Implementation of a Masked Autoencoder for representation learning
This follows a previous PR (#7598).
The previous PR was based on the official implementation, which is under a non-compatible license. This is a clean-sheet implementation I developed. The code is fairly straightforward, involving a transformer, encoder, and decoder. The primary changes are in how masks are selected and how patches are organized as they pass through the model.
In the official masked autoencoder implementation, noise is first generated and then sorted twice using torch.argsort. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices.
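For reference, the double-argsort idea described above can be sketched as follows. This is an illustrative reconstruction written from the description, not the official code; the function name `argsort_masking` and the convention that the mask is 1 for masked tokens are assumptions.

```python
import torch


def argsort_masking(x: torch.Tensor, mask_ratio: float):
    """Hypothetical helper. x has shape (B, N, D)."""
    B, N, D = x.shape
    len_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)        # first sort: a random permutation per sample
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # second sort: the inverse permutation

    # Keep only the first len_keep shuffled indices.
    ids_keep = ids_shuffle[:, :len_keep]
    x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    # Build the mask in shuffled order, then unshuffle it (assumed: 1 = masked, 0 = kept).
    mask = torch.ones(B, N, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_kept, mask, ids_restore
```

Note that the kept tokens come out in shuffled order here, which is why `ids_restore` has to be carried through to the decoder.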
In our implementation, we use torch.multinomial to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder.
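A minimal sketch of this approach, assuming uniform sampling weights and a mask ratio strictly between 0 and 1. The function names (`multinomial_masking`, `restore_with_mask_tokens`) and the `mask_token` shape of (1, 1, D) are hypothetical and may not match the code in this PR.

```python
import torch


def multinomial_masking(x: torch.Tensor, mask_ratio: float):
    """Hypothetical helper. x has shape (B, N, D); returns kept tokens and a boolean mask."""
    B, N, D = x.shape
    n_masked = int(N * mask_ratio)

    # Sample n_masked distinct token positions per sample (uniform weights assumed).
    weights = torch.ones(B, N, device=x.device)
    masked_idx = torch.multinomial(weights, n_masked, replacement=False)  # (B, n_masked)

    mask = torch.zeros(B, N, dtype=torch.bool, device=x.device)
    mask[torch.arange(B, device=x.device).unsqueeze(1), masked_idx] = True  # True = masked

    # Boolean indexing flattens the batch; every sample keeps the same number
    # of tokens, so the result can be reshaped back to (B, N - n_masked, D).
    x_kept = x[~mask].reshape(B, N - n_masked, D)
    return x_kept, mask


def restore_with_mask_tokens(encoded: torch.Tensor, mask: torch.Tensor, mask_token: torch.Tensor):
    """Place encoded tokens back at their original positions and fill the rest with mask_token."""
    B, N = mask.shape
    D = encoded.shape[-1]
    full = mask_token.to(encoded.dtype).expand(B, N, D).clone()
    full[~mask] = encoded.reshape(-1, D)
    return full
```

Because boolean indexing preserves the original token order, no separate restore indices are needed in this sketch; the mask alone is enough to rebuild the full sequence for the decoder.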
Let me know if you need a detailed, line-by-line explanation of the new code, including how it works and how it differs from the previous version.
Description
Implementation of the Masked Autoencoder as described in the paper Masked Autoencoders Are Scalable Vision Learners by Kaiming He et al.
Its effectiveness has already been demonstrated in the literature for medical tasks in the paper Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation. The PR contains the architecture and associated unit tests.
Note: The output includes the prediction, which is a tensor of size ($BS$, $N_{tokens}$, $D$), and the associated mask of size ($BS$, $N_{tokens}$). The mask is used to apply the loss only to masked patches, but I'm not sure it's the “best” output format. What do you think?
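To illustrate how the returned mask could be consumed, here is a sketch of a reconstruction loss restricted to masked patches. It assumes the mask is nonzero (or True) where a patch was masked and that `target` holds the patchified input with the same shape as the prediction; `masked_mse` is a hypothetical name, not part of this PR.

```python
import torch


def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """pred, target: (BS, N_tokens, D); mask: (BS, N_tokens), nonzero where masked (assumed)."""
    mask = mask.to(pred.dtype)
    per_token = ((pred - target) ** 2).mean(dim=-1)  # (BS, N_tokens)
    # Average over masked tokens only; the clamp avoids dividing by zero if nothing is masked.
    return (per_token * mask).sum() / mask.sum().clamp(min=1)
```

With this convention, errors on visible patches do not contribute to the loss at all.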
Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.
Hi @Lucas-rbnt, thanks for the effort on this follow-up PR. @atbenmurray, could you please re-review the content here?
@Lucas-rbnt @atbenmurray I shall do so
I think this is fine now, though the comments should be looked at and the conflict resolved; then we can trigger the blossom tests. Thanks!
/build