
Question on Masking ratio

Open sukun1045 opened this issue 2 years ago • 5 comments

Hi, thanks for sharing the PyTorch implementation! I am curious about how you chose the statistics for the variable masking ratio. In the paper, you mentioned 'a truncated Gaussian distribution centered at 0.55, left truncated by 0.5, and right truncated by 1.' What is the motivation for using such a distribution? Why not use the cosine schedule as done in MaskGIT? Thank you!

sukun1045 avatar Dec 15 '23 22:12 sukun1045

Thanks for your interest! The masking ratio is left truncated at 0.5 so that we can always drop at least 50% of the input tokens in the ViT encoder, which saves a large amount of computation (a similar idea to MAE). In Table 5 of the paper, we show ablations on the center and std of the Gaussian distribution. We also tried a cosine masking ratio schedule similar to MaskGIT, and the performance was slightly worse.

LTH14 avatar Dec 16 '23 22:12 LTH14
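
For readers following along, the two sampling schemes discussed above can be sketched roughly as follows. This is a minimal illustration, not the repository's code: the function names are hypothetical, `std=0.25` is taken from the follow-up question in this thread, the sequence length of 256 tokens is an assumption, and the truncation is done with plain rejection sampling (`scipy.stats.truncnorm` would be a natural alternative). The cosine variant assumes the MaskGIT-style form cos(π/2 · r) with r drawn uniformly from [0, 1).

```python
import math
import random


def sample_truncated_gaussian(mu=0.55, std=0.25, low=0.5, high=1.0, rng=random):
    """Draw one masking ratio from N(mu, std^2) restricted to [low, high].

    Uses rejection sampling: redraw until the sample falls inside the bounds.
    """
    while True:
        x = rng.gauss(mu, std)
        if low <= x <= high:
            return x


def sample_cosine(rng=random):
    """Draw one masking ratio as cos(pi/2 * r), r ~ Uniform[0, 1) (assumed MaskGIT-style form)."""
    r = rng.random()
    return math.cos(math.pi / 2 * r)


def num_masked_tokens(ratio, seq_len=256):
    """Number of tokens to mask for a given ratio; seq_len=256 is an assumption."""
    return math.ceil(seq_len * ratio)


if __name__ == "__main__":
    rng = random.Random(0)
    ratio = sample_truncated_gaussian(rng=rng)
    print("truncated Gaussian ratio:", ratio, "masked tokens:", num_masked_tokens(ratio))
    print("cosine ratio:", sample_cosine(rng=rng))
```

Because the truncated Gaussian never returns a ratio below 0.5, at least half of the tokens are masked in every draw, which is what makes a fixed 50% drop in the encoder possible. The cosine sampler can return ratios anywhere in (0, 1], so it gives no such guarantee.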

Thanks for your reply! I am also curious about training convergence and how to find the best model among variants. I think the best eval loss could vary with the masking strategy, so it may not be directly comparable. How can I find the best masking strategy when I run these experiments? For example, if I choose the truncnorm with mu=0.55 and std=0.25, should I run the training until it converges, check the FID score, and then run another experiment?

sukun1045 avatar Dec 19 '23 09:12 sukun1045

Our evaluation protocol is based on both FID and linear probing accuracy -- once we train a model with certain hyper-parameters, we evaluate it on ImageNet and pick the best hyper-parameters based on FID and linear probing.

LTH14 avatar Dec 19 '23 09:12 LTH14

Thanks again for your reply. Regarding linear probing, have you tried using the CLS token output instead of average pooling the rest of the encoder output features? I saw that option in your code, and I wondered how it might affect the performance.

sukun1045 avatar Dec 30 '23 19:12 sukun1045

We tried using the CLS token. However, the performance is not very stable -- normally it achieves similar performance to average pooled features, but occasionally it gets very poor accuracy (~10%). Therefore we chose the global average pooled feature for stability.

LTH14 avatar Jan 02 '24 16:01 LTH14
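
As a rough illustration of the two linear-probing features compared above, the sketch below extracts both from a batch of encoder outputs. It is not the repository's code: the function names are hypothetical, and it assumes the encoder output has shape `(batch, 1 + num_patches, dim)` with the CLS token at index 0. The batch size, token count, and feature dimension are placeholder values.

```python
import numpy as np


def cls_feature(tokens):
    """Return the CLS token output; assumes the CLS token sits at index 0."""
    return tokens[:, 0]


def avg_pool_feature(tokens):
    """Global average pooling over all non-CLS tokens."""
    return tokens[:, 1:].mean(axis=1)


# Placeholder encoder output: 4 images, 1 CLS + 128 visible tokens, dim 768.
rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 1 + 128, 768))

cls_feat = cls_feature(tokens)        # shape (4, 768)
avg_feat = avg_pool_feature(tokens)   # shape (4, 768)
```

Both features have the same shape, so the same linear classifier head can be trained on either one when comparing them.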