TF2 weights shape wrong for 2 architectures

Open shkarupa-alex opened this issue 5 years ago • 2 comments

I've tried to load all provided TF2 model weights and found that 2 of them could not be loaded:

  • BiT-S-R50x3.h5 : ValueError: Cannot assign to variable standardized_conv2d/kernel:0 due to variable shape (7, 7, 3, 192) and value shape (7, 7, 3, 64) are incompatible
  • BiT-S-R101x1.h5 : Cannot assign to variable standardized_conv2d/kernel:0 due to variable shape (7, 7, 3, 64) and value shape (7, 7, 3, 256) are incompatible

All other weights loaded without problems. Sample code to reproduce: https://colab.research.google.com/drive/1s2QtVgrj2HrDs64xGMi_GsOaFR3i95v0?usp=sharing

shkarupa-alex, Jan 14 '21 11:01

I actually found that none of the provided BiT-S-* models are compatible: they all raise an incompatible-shape exception when, for example, fine-tuning on the CIFAR10 dataset (while the BiT-M-* models work flawlessly!). Is this a known issue? Is there anything that needs to be adjusted?

@shkarupa-alex when I use the suggested command to fine-tune on CIFAR10: python3 -m bit_tf2.train --name cifar10_`date +%F_%H%M%S` --model BiT-S-R50x1 --logdir /tmp/bit_logs --dataset cifar10, I get a different error from the ones you reported: ValueError: Shapes (2048, 21843) and (2048, 1000) are incompatible.

chrstn-hntschl, Jun 25 '21 13:06

My bad for not seeing the obvious: the two shapes differ because the BiT-M-* base models expect 21843 outputs (i.e. the number of classes in ImageNet-21k), whereas the BiT-S-* models expect 1000 outputs (i.e. the number of classes in ILSVRC2012). In the current implementation, num_outputs is hardcoded to 21843. It needs to be selected based on the pre-trained model in use, e.g. by adding

NUM_OUTPUTS = {
    # BiT-S-* heads have 1000 outputs (ILSVRC2012); the others have 21843 (ImageNet-21k).
    k: 1000 if "-S-" in k else 21843
    for k in KNOWN_MODELS
}

in https://github.com/google-research/big_transfer/blob/master/bit_tf2/train.py.
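
To make the idea concrete, here is a minimal standalone sketch of that lookup. It assumes KNOWN_MODELS is simply a collection of model names (the list below is a made-up subset for illustration, not the repo's actual definition), and num_outputs_for is a hypothetical helper that does not exist in train.py:

```python
# Assumption: a stand-in for the repo's KNOWN_MODELS; only the names matter here.
KNOWN_MODELS = [
    "BiT-S-R50x1",
    "BiT-S-R50x3",
    "BiT-S-R101x1",
    "BiT-M-R50x1",
    "BiT-M-R101x3",
]

# BiT-S-* heads have 1000 outputs (ILSVRC2012); the others have 21843 (ImageNet-21k).
NUM_OUTPUTS = {k: 1000 if "-S-" in k else 21843 for k in KNOWN_MODELS}


def num_outputs_for(model_name):
    """Hypothetical helper: return the head size of the pre-trained model."""
    try:
        return NUM_OUTPUTS[model_name]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name!r}") from None


if __name__ == "__main__":
    for name in ("BiT-S-R50x1", "BiT-M-R50x1"):
        print(name, num_outputs_for(name))
```

The value returned for the selected model would then replace the hardcoded 21843 wherever the model is built before the pre-trained weights are loaded; where exactly that happens in train.py is not shown here.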

chrstn-hntschl, Jun 28 '21 08:06