
coord_check for models whose forward() returns the loss directly

Open ad8e opened this issue 2 years ago • 0 comments

Some transformer implementations (like x-transformers) take in a sequence of length seq_len+1, split it into input=x[:-1] and target=x[1:], and compute the loss directly in forward(). This is efficient because the input and target overlap in all but one token, so only a single tensor of seq_len+1 tokens has to be passed in. As a result, forward() returns the loss rather than the model's output logits.
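A minimal sketch of this pattern, assuming a toy bigram model in plain Python rather than a real transformer (the class name `ToyBigramLM` is hypothetical and does not reproduce x-transformers' code): forward() receives seq_len+1 tokens, shifts them into inputs and targets itself, and returns the mean cross-entropy instead of logits.

```python
import math
from typing import List, Sequence


class ToyBigramLM:
    """Hypothetical toy LM whose forward() returns the loss, not logits."""

    def __init__(self, logits: List[List[float]]):
        # logits[i][j]: unnormalized score for next token j given current token i
        self.logits = logits

    @staticmethod
    def _log_softmax(row: Sequence[float]) -> List[float]:
        # numerically stable log-softmax over one row of scores
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        return [v - log_z for v in row]

    def forward(self, x: Sequence[int]) -> float:
        # x has length seq_len + 1; inputs and targets share seq_len - 1 tokens
        inputs, targets = x[:-1], x[1:]
        nll = 0.0
        for cur, nxt in zip(inputs, targets):
            nll -= self._log_softmax(self.logits[cur])[nxt]
        return nll / len(targets)  # mean cross-entropy over seq_len positions

    __call__ = forward


# With uniform scores over 4 tokens, the loss is log(4) for any sequence.
model = ToyBigramLM([[0.0] * 4 for _ in range(4)])
loss = model([0, 1, 2, 3, 0])
```

Calling the model yields a single float, which is why a coord-check loop that expects to unpack `(data, target)` pairs and apply its own loss function cannot consume such a model as-is.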

It would be nice if coord_check had an option to support this use case, where forward() returns the loss directly, for example by adding a loss_from_forward argument to the function signatures and inserting this:

                elif loss_from_forward:
                    # forward() splits the (seq_len+1)-token batch and returns the loss itself
                    if cuda:
                        batch = batch.cuda()
                    loss = model(batch)

at https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/coord_check.py#L317
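To make the proposed branch concrete, here is a hedged, framework-free sketch of the per-step loss dispatch with the suggested `loss_from_forward` flag. The helper `compute_step_loss` and its parameters are hypothetical stand-ins for the body of mup's coord-check training loop, not its actual code, and device placement (`.cuda()`) is omitted.

```python
from typing import Any, Callable, Optional


def compute_step_loss(
    model: Callable[..., Any],
    batch: Any,
    loss_from_forward: bool = False,
    lossfn: Optional[Callable[[Any, Any], Any]] = None,
) -> Any:
    """Hypothetical helper: return the loss for one coord-check step."""
    if loss_from_forward:
        # forward() consumes the whole batch and returns the loss directly
        return model(batch)
    # default path: batch is an (input, target) pair and the loss is computed here
    data, target = batch
    if lossfn is None:
        raise ValueError("lossfn is required when loss_from_forward is False")
    return lossfn(model(data), target)
```

Either way the caller gets a loss value back, so the existing backward pass and optimizer step would not need to change.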

ad8e · Jan 08 '24 22:01