Bad performance on ImageNet variants
I ran the FLYP code to compare against "Masked Images Are Counterfactual Samples for Robust Fine-tuning" (CVPR 2023), using the ViT-B/32 model. I expected FLYP to be competitive with other methods, but the OOD performance of the model trained with FLYP is significantly degraded.
Zero-shot CLIP performance with ViT-B/32 is as follows:
ImageNet Top-1 accuracy: 63.4
ImageNetV2 Top-1 accuracy: 55.9
ImageNetR Top-1 accuracy: 69.3
ImageNetSketch Top-1 accuracy: 42.3
ImageNetA Top-1 accuracy: 31.4
After training with FLYP for just one epoch, the performance is:
ImageNet Top-1 accuracy: 73.3
ImageNetV2 Top-1 accuracy: 62.6
ImageNetR Top-1 accuracy: 63.1
ImageNetSketch Top-1 accuracy: 40.9
ImageNetA Top-1 accuracy: 25.9
FLYP does not preserve robustness: accuracy on ImageNet-R, ImageNet-Sketch, and ImageNet-A drops compared to zero-shot CLIP, even after training for only one epoch. I used the same hyperparameters as in the ViT-B/16 experiments.
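To make sure we mean the same objective: my understanding is that FLYP fine-tunes with the same CLIP-style contrastive loss used in pretraining, pairing each image with a class-prompt caption. A minimal sketch of that loss (the function and variable names below are mine, not code from this repo):

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_features, text_features, logit_scale):
    """Symmetric InfoNCE loss over a batch of (image, caption) pairs.

    image_features, text_features: (B, D) tensors, assumed L2-normalized.
    logit_scale: CLIP's learned temperature (a scalar tensor), already exponentiated.
    """
    logits_per_image = logit_scale * image_features @ text_features.t()  # (B, B)
    logits_per_text = logits_per_image.t()
    labels = torch.arange(image_features.size(0), device=image_features.device)
    # Each image should match its own caption, and vice versa.
    return 0.5 * (F.cross_entropy(logits_per_image, labels)
                  + F.cross_entropy(logits_per_text, labels))
```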
Can you clarify this behavior? Is there anything wrong with my experiment?
Hi, can you check whether you are able to get the correct results with B/16? I have personally tried B/16, L/14, and L/14-336px, and many other people in the community have tried FLYP and reported good performance. I can try running B/32 after the ICLR deadline.
Thanks for sharing these numbers. Could you also compare them with standard fine-tuning, if you have those numbers as well?
It appears that weight ensembling is not applied in the current code; I cannot find the line where args.alpha is used. Is it normal for the results without weight ensembling to drop on some datasets, as shown above? Or is the result above strange even without weight ensembling applied?
As mentioned in your comment, I will also run the experiment with ViT-B/16.
The results should be much better even without weight ensembling. I am not sure what the baseline for, say, standard cross-entropy fine-tuning would be, but FLYP should still give better OOD accuracies than zero-shot (even without ensembling). I used the ensembling code from https://github.com/mlfoundations/wise-ft.
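For context, the weight ensembling there is just a linear interpolation between the zero-shot and fine-tuned weights, controlled by alpha. A rough sketch (illustrative only, not the exact wise-ft code):

```python
def interpolate_weights(zeroshot_sd, finetuned_sd, alpha=0.5):
    """WiSE-FT-style weight-space ensemble:
    theta = (1 - alpha) * theta_zeroshot + alpha * theta_finetuned."""
    assert set(zeroshot_sd.keys()) == set(finetuned_sd.keys())
    return {k: (1 - alpha) * zeroshot_sd[k] + alpha * finetuned_sd[k]
            for k in zeroshot_sd}

# Usage sketch: merge the two checkpoints, load the result, and evaluate.
# merged = interpolate_weights(zeroshot_model.state_dict(),
#                              finetuned_model.state_dict(), alpha=0.5)
# finetuned_model.load_state_dict(merged)
```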
Let me know the ViT-B/16 numbers once you get them, and we can debug from there.
Hi, were you able to check this with the other models?
Sorry for the late reply. I ran FLYP with CLIP ViT-B/16 without ensembling (i.e., without WiSE-FT), and the accuracies on the ImageNet variants are as follows:
ImageNet Top-1 accuracy: 82.4
ImageNetV2 Top-1 accuracy: 73.1
ImageNetR Top-1 accuracy: 70.7
ImageNetSketch Top-1 accuracy: 48.8
ObjectNet Top-1 accuracy: 54.2
ImageNetA Top-1 accuracy: 47.5
Avg OOD: 58.86
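For clarity, the Avg OOD above is the unweighted mean over the five OOD test sets (I am assuming this matches the paper's convention):

```python
# Assuming Avg OOD is the unweighted mean over the five OOD test sets.
ood = {"ImageNetV2": 73.1, "ImageNetR": 70.7, "ImageNetSketch": 48.8,
       "ObjectNet": 54.2, "ImageNetA": 47.5}
avg_ood = sum(ood.values()) / len(ood)
print(f"Avg OOD: {avg_ood:.2f}")  # 58.86
```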
Robustness is maintained better than in the ViT-B/32 experiment, but the results seem lower overall than those reported in the paper (Avg OOD ours: 58.9, reported: 60.2), especially on ImageNet-R, ImageNet-A, and ObjectNet.
In particular, the average OOD score I obtained is close to the zero-shot OOD performance, so robustness looks preserved on average. However, this is mostly driven by the gain on ImageNet-V2, which is relatively close to the in-distribution data, while there is a noticeable drop on the remaining OOD datasets.
What should I modify in my experiment to get the reported score?
Did you use the CLI arguments from the README? Could you please send me your logs?
Sorry for the late reply. Here are the arguments.
Namespace(data_location='./datasets/data/', eval_datasets=['ImageNet', 'ImageNetV2', 'ImageNetR', 'ImageNetSketch', 'ObjectNet', 'ImageNetA'], train_dataset='ImageNet', template='openai_imagenet_template', classnames='openai', alpha=[0.5], exp_name='ImageNet/flyp_loss_Re', results_db=None, model='ViT-B/16', batch_size=512, lr=1e-05, wd=0.1, warmup_length=500, num_classes=1000, epochs=10, load=None, save='expt_logs/ViT-B/16/ImageNet/flyp_loss_Re/_BS512_WD0.1_LR1e-05_run1', resume=None, cache_dir=None, fisher=None, fisher_floor=1e-08, ft_data='./datasets/csv/imagenet.csv', ce_ablation=None, dataset_type='auto', train_num_samples=None, k=None, seed=0, workers=4, csv_separator='\t', csv_img_key='filepath', csv_caption_key='title', precision='amp', clip_load=None, wise_save=None, run=1, get_labeled_csv=False, min_lr=1e-06, scheduler='cosine', device='cuda:0')
Below are the logs. The ObjectNet dataset is larger than the others, so I only evaluate on ObjectNet after epoch 8.
2023-10-23,14:31:17 | INFO | Train Epoch: 0 [ 512/1281167 (0%)] Data (t): 0.000 Batch (t): 5.934, 29.8708/s LR: 0.000000 Loss: 1.7685 (1.7685)
2023-10-23,14:32:17 | INFO | Train Epoch: 0 [ 51712/1281167 (4%)] Data (t): 0.000 Batch (t): 0.606, 211.832/s LR: 0.000002 Loss: 1.1258 (1.4472)
2023-10-23,14:33:16 | INFO | Train Epoch: 0 [ 102912/1281167 (8%)] Data (t): 0.000 Batch (t): 0.587, 217.514/s LR: 0.000004 Loss: 0.98057 (1.2916)
2023-10-23,14:34:15 | INFO | Train Epoch: 0 [ 154112/1281167 (12%)] Data (t): 0.000 Batch (t): 0.589, 214.717/s LR: 0.000006 Loss: 0.89656 (1.1929)
2023-10-23,14:35:14 | INFO | Train Epoch: 0 [ 205312/1281167 (16%)] Data (t): 0.000 Batch (t): 0.590, 216.794/s LR: 0.000008 Loss: 0.85278 (1.1248)
2023-10-23,14:36:13 | INFO | Train Epoch: 0 [ 256512/1281167 (20%)] Data (t): 0.000 Batch (t): 0.591, 216.612/s LR: 0.000011 Loss: 0.82030 (1.0741)
2023-10-23,14:37:12 | INFO | Train Epoch: 0 [ 307712/1281167 (24%)] Data (t): 0.000 Batch (t): 0.591, 216.511/s LR: 0.000011 Loss: 0.72410 (1.0241)
2023-10-23,14:38:11 | INFO | Train Epoch: 0 [ 358912/1281167 (28%)] Data (t): 0.000 Batch (t): 0.591, 216.918/s LR: 0.000011 Loss: 0.82420 (0.99911)
2023-10-23,14:39:13 | INFO | Train Epoch: 0 [ 410112/1281167 (32%)] Data (t): 0.000 Batch (t): 0.616, 216.763/s LR: 0.000011 Loss: 0.72191 (0.96831)
2023-10-23,14:40:17 | INFO | Train Epoch: 0 [ 461312/1281167 (36%)] Data (t): 0.000 Batch (t): 0.638, 216.873/s LR: 0.000011 Loss: 0.80310 (0.95179)
2023-10-23,14:41:16 | INFO | Train Epoch: 0 [ 512512/1281167 (40%)] Data (t): 0.000 Batch (t): 0.591, 217.051/s LR: 0.000011 Loss: 0.77836 (0.93602)
2023-10-23,14:42:15 | INFO | Train Epoch: 0 [ 563712/1281167 (44%)] Data (t): 0.000 Batch (t): 0.591, 216.691/s LR: 0.000011 Loss: 0.71075 (0.91725)
2023-10-23,14:43:14 | INFO | Train Epoch: 0 [ 614912/1281167 (48%)] Data (t): 0.000 Batch (t): 0.591, 216.219/s LR: 0.000011 Loss: 0.72596 (0.90253)
2023-10-23,14:44:13 | INFO | Train Epoch: 0 [ 666112/1281167 (52%)] Data (t): 0.000 Batch (t): 0.592, 216.635/s LR: 0.000011 Loss: 0.70613 (0.88850)
2023-10-23,14:45:13 | INFO | Train Epoch: 0 [ 717312/1281167 (56%)] Data (t): 0.000 Batch (t): 0.592, 216.582/s LR: 0.000011 Loss: 0.72703 (0.87774)
2023-10-23,14:46:12 | INFO | Train Epoch: 0 [ 768512/1281167 (60%)] Data (t): 0.000 Batch (t): 0.592, 216.336/s LR: 0.000011 Loss: 0.70736 (0.86709)
2023-10-23,14:47:11 | INFO | Train Epoch: 0 [ 819712/1281167 (64%)] Data (t): 0.000 Batch (t): 0.592, 216.345/s LR: 0.000011 Loss: 0.69498 (0.85697)
2023-10-23,14:48:15 | INFO | Train Epoch: 0 [ 870912/1281167 (68%)] Data (t): 0.000 Batch (t): 0.639, 216.354/s LR: 0.000011 Loss: 0.71259 (0.84894)
2023-10-23,14:49:16 | INFO | Train Epoch: 0 [ 922112/1281167 (72%)] Data (t): 0.000 Batch (t): 0.615, 216.367/s LR: 0.000011 Loss: 0.68842 (0.84050)
2023-10-23,14:50:15 | INFO | Train Epoch: 0 [ 973312/1281167 (76%)] Data (t): 0.000 Batch (t): 0.591, 216.571/s LR: 0.000011 Loss: 0.72542 (0.83474)
2023-10-23,14:51:15 | INFO | Train Epoch: 0 [1024512/1281167 (80%)] Data (t): 0.000 Batch (t): 0.591, 216.357/s LR: 0.000011 Loss: 0.64165 (0.82555)
2023-10-23,14:52:14 | INFO | Train Epoch: 0 [1075712/1281167 (84%)] Data (t): 0.000 Batch (t): 0.592, 215.842/s LR: 0.000011 Loss: 0.70868 (0.82024)
2023-10-23,14:53:13 | INFO | Train Epoch: 0 [1126912/1281167 (88%)] Data (t): 0.000 Batch (t): 0.592, 215.866/s LR: 0.000011 Loss: 0.69037 (0.81459)
2023-10-23,14:54:12 | INFO | Train Epoch: 0 [1178112/1281167 (92%)] Data (t): 0.000 Batch (t): 0.592, 216.268/s LR: 0.000011 Loss: 0.65542 (0.80796)
2023-10-23,14:55:11 | INFO | Train Epoch: 0 [1229312/1281167 (96%)] Data (t): 0.000 Batch (t): 0.592, 216.047/s LR: 0.000011 Loss: 0.64659 (0.80150)
2023-10-23,14:56:11 | INFO | Train Epoch: 0 [1280512/1281167 (100%)] Data (t): 0.000 Batch (t): 0.592, 216.380/s LR: 0.000011 Loss: 0.65731 (0.79596)
2023-10-23,14:58:32 | INFO | ImageNet Top-1 accuracy: 0.7842
2023-10-23,14:58:58 | INFO | ImageNetV2 Top-1 accuracy: 0.6980
2023-10-23,14:59:43 | INFO | ImageNetR Top-1 accuracy: 0.7086
2023-10-23,15:00:58 | INFO | ImageNetSketch Top-1 accuracy: 0.4644
2023-10-23,15:01:15 | INFO | ImageNetA Top-1 accuracy: 0.4428
...
2023-10-23,19:08:39 | INFO | Train Epoch: 9 [ 512/1281167 (0%)] Data (t): 0.000 Batch (t): 1.621, 121.220/s LR: 0.000001 Loss: 0.38631 (0.38631)
2023-10-23,19:09:39 | INFO | Train Epoch: 9 [ 51712/1281167 (4%)] Data (t): 0.000 Batch (t): 0.596, 217.674/s LR: 0.000001 Loss: 0.41418 (0.40024)
2023-10-23,19:10:43 | INFO | Train Epoch: 9 [ 102912/1281167 (8%)] Data (t): 0.000 Batch (t): 0.639, 216.902/s LR: 0.000001 Loss: 0.42194 (0.40748)
2023-10-23,19:11:48 | INFO | Train Epoch: 9 [ 154112/1281167 (12%)] Data (t): 0.000 Batch (t): 0.650, 217.253/s LR: 0.000001 Loss: 0.36600 (0.39711)
2023-10-23,19:12:47 | INFO | Train Epoch: 9 [ 205312/1281167 (16%)] Data (t): 0.000 Batch (t): 0.590, 217.111/s LR: 0.000001 Loss: 0.38669 (0.39503)
2023-10-23,19:13:46 | INFO | Train Epoch: 9 [ 256512/1281167 (20%)] Data (t): 0.000 Batch (t): 0.591, 216.462/s LR: 0.000001 Loss: 0.38174 (0.39281)
2023-10-23,19:14:45 | INFO | Train Epoch: 9 [ 307712/1281167 (24%)] Data (t): 0.000 Batch (t): 0.591, 216.506/s LR: 0.000001 Loss: 0.36417 (0.38872)
2023-10-23,19:15:44 | INFO | Train Epoch: 9 [ 358912/1281167 (28%)] Data (t): 0.000 Batch (t): 0.591, 216.277/s LR: 0.000001 Loss: 0.39689 (0.38974)
2023-10-23,19:16:43 | INFO | Train Epoch: 9 [ 410112/1281167 (32%)] Data (t): 0.000 Batch (t): 0.591, 217.085/s LR: 0.000001 Loss: 0.39641 (0.39048)
2023-10-23,19:17:42 | INFO | Train Epoch: 9 [ 461312/1281167 (36%)] Data (t): 0.000 Batch (t): 0.591, 216.068/s LR: 0.000001 Loss: 0.40005 (0.39144)
2023-10-23,19:18:41 | INFO | Train Epoch: 9 [ 512512/1281167 (40%)] Data (t): 0.000 Batch (t): 0.592, 216.362/s LR: 0.000001 Loss: 0.38854 (0.39117)
2023-10-23,19:19:45 | INFO | Train Epoch: 9 [ 563712/1281167 (44%)] Data (t): 0.000 Batch (t): 0.639, 217.170/s LR: 0.000001 Loss: 0.35694 (0.38832)
2023-10-23,19:20:49 | INFO | Train Epoch: 9 [ 614912/1281167 (48%)] Data (t): 0.000 Batch (t): 0.640, 216.749/s LR: 0.000001 Loss: 0.36357 (0.38642)
2023-10-23,19:21:48 | INFO | Train Epoch: 9 [ 666112/1281167 (52%)] Data (t): 0.000 Batch (t): 0.591, 216.520/s LR: 0.000001 Loss: 0.39252 (0.38685)
2023-10-23,19:22:47 | INFO | Train Epoch: 9 [ 717312/1281167 (56%)] Data (t): 0.000 Batch (t): 0.592, 216.614/s LR: 0.000001 Loss: 0.39923 (0.38768)
2023-10-23,19:23:47 | INFO | Train Epoch: 9 [ 768512/1281167 (60%)] Data (t): 0.000 Batch (t): 0.591, 216.704/s LR: 0.000001 Loss: 0.38745 (0.38766)
2023-10-23,19:24:46 | INFO | Train Epoch: 9 [ 819712/1281167 (64%)] Data (t): 0.000 Batch (t): 0.591, 216.514/s LR: 0.000001 Loss: 0.36276 (0.38620)
2023-10-23,19:25:45 | INFO | Train Epoch: 9 [ 870912/1281167 (68%)] Data (t): 0.000 Batch (t): 0.591, 216.088/s LR: 0.000001 Loss: 0.30742 (0.38182)
2023-10-23,19:26:44 | INFO | Train Epoch: 9 [ 922112/1281167 (72%)] Data (t): 0.000 Batch (t): 0.592, 216.054/s LR: 0.000001 Loss: 0.38636 (0.38206)
2023-10-23,19:27:43 | INFO | Train Epoch: 9 [ 973312/1281167 (76%)] Data (t): 0.000 Batch (t): 0.592, 216.018/s LR: 0.000001 Loss: 0.37076 (0.38150)
2023-10-23,19:28:52 | INFO | Train Epoch: 9 [1024512/1281167 (80%)] Data (t): 0.000 Batch (t): 0.690, 42.1471/s LR: 0.000001 Loss: 0.40817 (0.38277)
2023-10-23,19:29:51 | INFO | Train Epoch: 9 [1075712/1281167 (84%)] Data (t): 0.000 Batch (t): 0.591, 216.712/s LR: 0.000001 Loss: 0.36955 (0.38217)
2023-10-23,19:30:50 | INFO | Train Epoch: 9 [1126912/1281167 (88%)] Data (t): 0.000 Batch (t): 0.591, 216.513/s LR: 0.000001 Loss: 0.39451 (0.38270)
2023-10-23,19:31:49 | INFO | Train Epoch: 9 [1178112/1281167 (92%)] Data (t): 0.000 Batch (t): 0.591, 217.025/s LR: 0.000001 Loss: 0.39733 (0.38331)
2023-10-23,19:32:48 | INFO | Train Epoch: 9 [1229312/1281167 (96%)] Data (t): 0.000 Batch (t): 0.592, 216.189/s LR: 0.000001 Loss: 0.38191 (0.38326)
2023-10-23,19:33:48 | INFO | Train Epoch: 9 [1280512/1281167 (100%)] Data (t): 0.000 Batch (t): 0.592, 216.382/s LR: 0.000001 Loss: 0.34291 (0.38170)
2023-10-23,19:36:10 | INFO | ImageNet Top-1 accuracy: 0.8238
2023-10-23,19:36:33 | INFO | ImageNetV2 Top-1 accuracy: 0.7309
2023-10-23,19:37:21 | INFO | ImageNetR Top-1 accuracy: 0.7071
2023-10-23,19:38:36 | INFO | ImageNetSketch Top-1 accuracy: 0.4879
2023-10-23,19:41:51 | INFO | ObjectNet Top-1 accuracy: 0.5417
2023-10-23,19:42:08 | INFO | ImageNetA Top-1 accuracy: 0.4755
2023-10-23,19:42:08 | INFO | Saving model to expt_logs/ViT-B/16/ImageNet/flyp_loss_Re/_BS512_WD0.1_LR1e-05_run1/checkpoint_10.pt
2023-10-23,19:42:10 | INFO | Avg OOD Acc : 0.5886