
PyTorch < 2.1.0 can't load the checkpoints correctly

Open · shiyi099 opened this issue 1 year ago · 1 comment

When I use an older version of PyTorch, such as 2.0.0 or 1.13.0, the checkpoint `bootstapir_checkpoint.pt` does not load correctly.

The error is shown below:

```
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 61
     59 #model = torch.nn.DataParallel(model)
     60 torch.backends.cudnn.benchmark = True
---> 61 model.load_state_dict(torch.load('tapnet/checkpoints/bootstapir_checkpoint.pt'))
     62 model = model.to(device)
     63 model

File D:\Program Files\Anaconda3\envs\py38\lib\site-packages\torch\nn\modules\module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036     error_msgs.insert(
   2037         0, 'Missing key(s) in state_dict: {}. '.format(
   2038             ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for TAPIR:
    Missing key(s) in state_dict: "resnet_torch.initial_conv.bias", "resnet_torch.block_groups.0.blocks.0.proj_conv.bias", "resnet_torch.block_groups.0.blocks.0.conv_0.bias", "resnet_torch.block_groups.0.blocks.0.conv_1.bias", "resnet_torch.block_groups.0.blocks.1.conv_0.bias", "resnet_torch.block_groups.0.blocks.1.conv_1.bias", "resnet_torch.block_groups.1.blocks.0.proj_conv.bias", "resnet_torch.block_groups.1.blocks.0.conv_0.bias", "resnet_torch.block_groups.1.blocks.0.conv_1.bias", "resnet_torch.block_groups.1.blocks.1.conv_0.bias", "resnet_torch.block_groups.1.blocks.1.conv_1.bias", "resnet_torch.block_groups.2.blocks.0.proj_conv.bias", "resnet_torch.block_groups.2.blocks.0.conv_0.bias", "resnet_torch.block_groups.2.blocks.0.conv_1.bias", "resnet_torch.block_groups.2.blocks.1.conv_0.bias", "resnet_torch.block_groups.2.blocks.1.conv_1.bias", "resnet_torch.block_groups.3.blocks.0.proj_conv.bias", "resnet_torch.block_groups.3.blocks.0.conv_0.bias", "resnet_torch.block_groups.3.blocks.0.conv_1.bias", "resnet_torch.block_groups.3.blocks.1.conv_0.bias", "resnet_torch.block_groups.3.blocks.1.conv_1.bias".
```

As a workaround, I added the `bias` parameter to the `torch.nn.Conv2d` and `torch.nn.LayerNorm` calls, following the official PyTorch documentation, and changed `model.load_state_dict(torch.load('tapnet/checkpoints/bootstapir_checkpoint.pt'))` to `model.load_state_dict(torch.load('tapnet/checkpoints/bootstapir_checkpoint.pt'), False)` (i.e. `strict=False`). With these changes the program runs, but the outputs are wrong.
What should I do next? My guess is that the loading function in older PyTorch versions differs from the one in >= 2.1.0. I noticed a subtle change in PyTorch 2.1.0: `torch.load` now ends with `torch._C._log_api_usage_metadata("torch.load.metadata", {"serialization_id": zip_file.serialization_id()})`. So far I have not found anything more conclusive in the underlying code. Is that the reason?
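The traceback shows the error is raised by `load_state_dict`, after `torch.load` has already returned a dict, which is consistent with a key mismatch between the model definition and the checkpoint. As a minimal sketch of why `strict=False` lets the program run but produces wrong outputs, the toy example below uses a plain `nn.Conv2d` as a hypothetical stand-in for TAPIR's ResNet convs (it does not touch the real checkpoint): the absent bias is only reported as a missing key, and that parameter keeps its random initialization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a checkpoint saved from a conv created with bias=False.
saved = nn.Conv2d(3, 8, kernel_size=3, bias=False)
checkpoint = saved.state_dict()  # only contains 'weight'

# The model being loaded declares a bias (PyTorch's default is bias=True).
model = nn.Conv2d(3, 8, kernel_size=3)
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []

# The weight was copied, but the bias keeps its random initialization,
# so the two modules no longer compute the same function.
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(saved(x), model(x), atol=1e-6))  # False
```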

shiyi099 · Mar 28 '24 18:03

I solved this problem. If you hit this error when loading the checkpoint, you can use the code below, which fills the missing conv biases with zeros before loading:

```python
import torch
from tapnet.torch import tapir_model  # import path may differ between tapnet versions

model = tapir_model.TAPIR(pyramid_level=1)
state_dict = torch.load('tapnet/checkpoints/bootstapir_checkpoint.pt')
try:
    model.load_state_dict(state_dict)
except RuntimeError:
    # Conv biases the model expects but the checkpoint does not contain.
    TAPIR_state_dict_keys = ["resnet_torch.initial_conv.bias",
                             "resnet_torch.block_groups.0.blocks.0.proj_conv.bias",
                             "resnet_torch.block_groups.0.blocks.0.conv_0.bias",
                             "resnet_torch.block_groups.0.blocks.0.conv_1.bias",
                             "resnet_torch.block_groups.0.blocks.1.conv_0.bias",
                             "resnet_torch.block_groups.0.blocks.1.conv_1.bias",
                             "resnet_torch.block_groups.1.blocks.0.proj_conv.bias",
                             "resnet_torch.block_groups.1.blocks.0.conv_0.bias",
                             "resnet_torch.block_groups.1.blocks.0.conv_1.bias",
                             "resnet_torch.block_groups.1.blocks.1.conv_0.bias",
                             "resnet_torch.block_groups.1.blocks.1.conv_1.bias",
                             "resnet_torch.block_groups.2.blocks.0.proj_conv.bias",
                             "resnet_torch.block_groups.2.blocks.0.conv_0.bias",
                             "resnet_torch.block_groups.2.blocks.0.conv_1.bias",
                             "resnet_torch.block_groups.2.blocks.1.conv_0.bias",
                             "resnet_torch.block_groups.2.blocks.1.conv_1.bias",
                             "resnet_torch.block_groups.3.blocks.0.proj_conv.bias",
                             "resnet_torch.block_groups.3.blocks.0.conv_0.bias",
                             "resnet_torch.block_groups.3.blocks.0.conv_1.bias",
                             "resnet_torch.block_groups.3.blocks.1.conv_0.bias",
                             "resnet_torch.block_groups.3.blocks.1.conv_1.bias"]

    for bias_key in TAPIR_state_dict_keys:
        # The matching conv weight has shape (out_channels, in_channels, kH, kW).
        weight = state_dict[bias_key.replace('.bias', '.weight')]
        # A zero bias adds nothing, so the conv behaves as if it had no bias.
        state_dict[bias_key] = torch.zeros(weight.size(0), dtype=weight.dtype,
                                           device=weight.device)
    # You can overwrite the checkpoint after adding these keys:
    # torch.save(state_dict, 'tapnet/checkpoints/bootstapir_checkpoint.pt')
    model.load_state_dict(state_dict)
```

That's it!
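The zero-fill workaround is sound because a convolution with an all-zero bias computes the same output as one created with `bias=False`. Below is a minimal sketch of the same idea using a hypothetical helper `add_zero_biases` that derives the missing bias keys from the model instead of a hard-coded list. Like the code above, it assumes Conv2d-style weights whose first dimension is the number of output channels, and it is checked on a toy `nn.Conv2d` rather than on TAPIR itself.

```python
import torch
import torch.nn as nn


def add_zero_biases(model: nn.Module, state_dict: dict) -> dict:
    """Hypothetical helper: return a copy of state_dict in which every bias the
    model expects but the checkpoint lacks is filled with zeros."""
    patched = dict(state_dict)
    for key in model.state_dict():
        # Skip non-bias entries and biases the checkpoint already has.
        if key.split('.')[-1] != 'bias' or key in patched:
            continue
        weight_key = key[:-len('bias')] + 'weight'
        if weight_key in patched:
            weight = patched[weight_key]
            # Assumes dim 0 of the weight is the number of output channels;
            # new_zeros keeps the weight's dtype and device.
            patched[key] = weight.new_zeros(weight.shape[0])
    return patched


torch.manual_seed(0)
saved = nn.Conv2d(3, 8, kernel_size=3, bias=False)  # checkpoint without a bias
model = nn.Conv2d(3, 8, kernel_size=3)               # model that declares a bias
patched = add_zero_biases(model, saved.state_dict())
model.load_state_dict(patched)                       # strict load now succeeds

x = torch.randn(1, 3, 16, 16)
print(torch.allclose(saved(x), model(x), atol=1e-6))  # True
```

Deriving the keys from `model.state_dict()` avoids maintaining the list by hand, at the cost of also zero-filling any other missing bias, so it is worth printing `patched.keys() - state_dict.keys()` before relying on it.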

shiyi099 · Mar 29 '24 11:03