
extending VLMO with MIM (Masked Image Modeling) loss

Open jinxixiang opened this issue 3 years ago • 10 comments

Thank you for sharing the source code of VLMO recently.

We took a stab at this and pretrained a large (1024 hidden dim) multiway Transformer with MIM loss, MLM loss, and contrastive loss.

BEiT-3 pretrains with MLM + MIM losses and then does intermediate fine-tuning with a contrastive loss for image-text retrieval. Our general idea is to merge the two stages into one. We implemented this based on VLMO and BEiT-2, but the results were surprising.

The finding is that the masked losses seem to be at odds with the contrastive loss.

Settings: We use the vision expert branch of the multiway Transformer for evaluation. Neither model starts from scratch; both start from the same weights. Each epoch contains 1 million images, 1 million texts, and 1 million image-text pairs.

We post the ImageNet-1k KNN classification results here.

using MIM + MLM + contrastive loss (does not converge):

  • 9 epoch: Top1: 63.72, Top5: 85.37
  • 19 epoch: Top1: 62.768, Top5: 84.488
  • 29 epoch: Top1: 62.25, Top5: 84.302
  • 39 epoch: Top1: 62.4, Top5: 84.318
  • 129 epoch: Top1: 62.928, Top5: 85.136

using MIM + MLM loss (same as BEiT-3):

  • epoch 9: Top1: 63.92, Top5: 85.316
  • epoch 19: Top1: 65.546, Top5: 86.154
  • epoch 39: Top1: 66.416, Top5: 86.966
  • epoch 49: Top1: 67.6, Top5: 87.606

The results indicate that the masked losses converge as expected, whereas combining them with the contrastive loss does not help.
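For context, the KNN metric above can be sketched as a cosine-similarity weighted k-NN over frozen backbone features (an illustrative sketch; the function and variable names are mine, not from the actual evaluation code):

```python
import torch

def knn_classify(train_feats, train_labels, test_feats, k=20, num_classes=1000):
    """Weighted k-NN on L2-normalized features (cosine similarity)."""
    train_feats = torch.nn.functional.normalize(train_feats, dim=1)
    test_feats = torch.nn.functional.normalize(test_feats, dim=1)
    sim = test_feats @ train_feats.t()          # (n_test, n_train) cosine sims
    topk_sim, topk_idx = sim.topk(k, dim=1)     # k nearest training samples
    topk_labels = train_labels[topk_idx]        # (n_test, k) their labels
    # accumulate similarity-weighted votes per class, then take the argmax
    votes = torch.zeros(test_feats.size(0), num_classes)
    votes.scatter_add_(1, topk_labels, topk_sim)
    return votes.argmax(dim=1)
```

Top-1/Top-5 are then just whether the true class is the top vote / among the top-5 votes.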

I wonder whether you have encountered similar problems before, or could you provide any insights into these results?

Thank you.

Best regards

jinxixiang avatar Jan 06 '23 04:01 jinxixiang

@jinxixiang Could you also post the loss curves (e.g., TensorBoard screenshots) of the MIM + MLM + contrastive run that does not converge?

donglixp avatar Jan 06 '23 04:01 donglixp

Thank you for your help!

The training loss and accuracy of masked prediction are attached.

acc_with_contrastive

notes:

  • i2t_train_acc, t2i_train_acc: contrastive top1 acc
  • mim_image_train_acc: monomodal image acc
  • mim_train_acc: vl-ffn image acc
  • mlm_train_acc: vl-ffn text acc
  • mlm_text_train_acc: monomodal text acc
loss_with_contrastive

notes:

  • itc_loss: contrastive loss
  • loss: total loss, i.e., sum of mim+mlm+contrastive
  • mim_image_loss: mim loss of the monomodal image branch
  • mim_loss: mim loss of the vl-ffn branch
  • mlm_loss: mlm loss of vl-ffn branch
  • mlm_text_loss: mlm loss of the monomodal text branch
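Given the naming above, the total objective appears to be a plain sum of the masked and contrastive terms, with the contrastive part being an in-batch InfoNCE loss. A hedged sketch (the temperature value and the exact loss form are my assumptions, not the repo's code):

```python
import torch
import torch.nn.functional as F

def itc_loss(img_emb, txt_emb, temperature=0.07):
    """In-batch image-text contrastive (InfoNCE) loss, symmetric i2t + t2i."""
    img_emb = F.normalize(img_emb, dim=1)
    txt_emb = F.normalize(txt_emb, dim=1)
    logits = img_emb @ txt_emb.t() / temperature    # (B, B) similarity matrix
    targets = torch.arange(logits.size(0))          # matched pairs on diagonal
    return (F.cross_entropy(logits, targets) +
            F.cross_entropy(logits.t(), targets)) / 2

def total_loss(mim_image_loss, mim_loss, mlm_loss, mlm_text_loss, itc):
    # "loss" as logged above: unweighted sum of all terms
    return mim_image_loss + mim_loss + mlm_loss + mlm_text_loss + itc
```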

jinxixiang avatar Jan 06 '23 05:01 jinxixiang

And here is the plot for MIM + MLM loss (same as BEiT-3):

loss_no_contrastive acc_no_contrastive

jinxixiang avatar Jan 06 '23 05:01 jinxixiang

Hi @jinxixiang, may I know the batch size you used for training? You could also remove the contrastive loss on the VL-FFN to simplify things.

wenhui0924 avatar Jan 06 '23 05:01 wenhui0924

We set batch size = 1024.

How does the contrastive loss on the VL-FFN help? We only use the V-FFN and L-FFN outputs to compute cosine similarity for retrieval.

jinxixiang avatar Jan 06 '23 06:01 jinxixiang

(screenshot) From your TensorBoard, I found vl_i2t and vl_t2i. They can slightly improve the model, but they are not very important.

wenhui0924 avatar Jan 06 '23 06:01 wenhui0924

OK, thank you for the advice. I followed the contrastive-loss implementation from VLMO.

But maybe vl_i2t and vl_t2i are not the main reason the model fails to converge?

Also, I found the contrastive training accuracy is probably too high (>0.8), maybe due to the small batch size.

For reference, what is the contrastive training accuracy at the intermediate fine-tuning stage with batch size = 65536?
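One thing worth noting: in-batch contrastive top-1 accuracy is measured against only B-1 negatives, so it rises mechanically as the batch shrinks, which may partly explain the >0.8 figure at batch size 1024. A sketch of the metric (illustrative, not VLMO's actual logging code):

```python
import torch

def contrastive_top1_acc(img_emb, txt_emb):
    """i2t top-1 accuracy over in-batch candidates: each image must rank
    its paired text above the other B-1 texts in the same batch."""
    img_emb = torch.nn.functional.normalize(img_emb, dim=1)
    txt_emb = torch.nn.functional.normalize(txt_emb, dim=1)
    sim = img_emb @ txt_emb.t()                     # (B, B) cosine similarities
    pred = sim.argmax(dim=1)                        # most similar text per image
    return (pred == torch.arange(sim.size(0))).float().mean().item()
```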

jinxixiang avatar Jan 06 '23 06:01 jinxixiang

You could try https://github.com/microsoft/torchscale if the issue is training stability (i.e., loss divergence).

The Multiway architecture can be enabled by setting multiway=True. https://github.com/microsoft/torchscale#key-features

donglixp avatar Jan 06 '23 09:01 donglixp

Thank you for your reply.

torchscale looks like a helpful toolkit for large-model training, and we are happy to try it out later.

But I don't think the issue is training stability, as the loss does not diverge. Also, the model with MIM + MLM loss alone works just fine.

jinxixiang avatar Jan 06 '23 10:01 jinxixiang

The code and pre-trained models of BEiT-3 can be found at aka.ms/beit3.

donglixp avatar Mar 13 '23 13:03 donglixp