PyTorch < 2.1.0 can't load the checkpoints correctly
When I use an older version of PyTorch, such as 2.0.0 or 1.13.0, the checkpoint "bootstapir_checkpoint.pt" does not load correctly.
The error is shown below:
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 61
     59 #model = torch.nn.DataParallel(model)
     60 torch.backends.cudnn.benchmark = True
---> 61 model.load_state_dict(torch.load('tapnet/checkpoints/bootstapir_checkpoint.pt'))
     62 model = model.to(device)
     63 model

File D:\Program Files\Anaconda3\envs\py38\lib\site-packages\torch\nn\modules\module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036 error_msgs.insert(
   2037     0, 'Missing key(s) in state_dict: {}. '.format(
   2038         ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042         self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for TAPIR: Missing key(s) in state_dict: "resnet_torch.initial_conv.bias", "resnet_torch.block_groups.0.blocks.0.proj_conv.bias", "resnet_torch.block_groups.0.blocks.0.conv_0.bias", "resnet_torch.block_groups.0.blocks.0.conv_1.bias", "resnet_torch.block_groups.0.blocks.1.conv_0.bias", "resnet_torch.block_groups.0.blocks.1.conv_1.bias", "resnet_torch.block_groups.1.blocks.0.proj_conv.bias", "resnet_torch.block_groups.1.blocks.0.conv_0.bias", "resnet_torch.block_groups.1.blocks.0.conv_1.bias", "resnet_torch.block_groups.1.blocks.1.conv_0.bias", "resnet_torch.block_groups.1.blocks.1.conv_1.bias", "resnet_torch.block_groups.2.blocks.0.proj_conv.bias", "resnet_torch.block_groups.2.blocks.0.conv_0.bias", "resnet_torch.block_groups.2.blocks.0.conv_1.bias", "resnet_torch.block_groups.2.blocks.1.conv_0.bias", "resnet_torch.block_groups.2.blocks.1.conv_1.bias", "resnet_torch.block_groups.3.blocks.0.proj_conv.bias", "resnet_torch.block_groups.3.blocks.0.conv_0.bias", "resnet_torch.block_groups.3.blocks.0.conv_1.bias", "resnet_torch.block_groups.3.blocks.1.conv_0.bias", "resnet_torch.block_groups.3.blocks.1.conv_1.bias".
As a workaround, I added the "bias" parameter to the torch.nn.Conv2d and torch.nn.LayerNorm calls, following the official PyTorch documentation, and changed
model.load_state_dict(torch.load('tapnet/checkpoints/bootstapir_checkpoint.pt'))
to
model.load_state_dict(torch.load('tapnet/checkpoints/bootstapir_checkpoint.pt'), False)
With these changes the program runs, but the outputs are wrong.
What should I do next? My guess is that the loading function in older versions of PyTorch differs from the one in versions >= 2.1.0.
One subtle change I found in the PyTorch 2.1.0 source is that it adds the following call at the end:
torch._C._log_api_usage_metadata(
    "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
)
So far I have not been able to find anything else in the underlying code.
Could that be the reason?
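One way to narrow this down is to compare the parameter names the model expects with the names stored in the checkpoint, rather than reading the loader source. The sketch below is only an illustration: `diff_state_dict_keys` is a hypothetical helper, not part of PyTorch or tapnet, and it works on plain key collections. The logging call quoted above appears to record usage metadata only, so it seems unlikely to change the loaded tensors.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Compare the keys a model expects with the keys a checkpoint provides.

    Returns (missing, unexpected):
      missing    - keys the model expects but the checkpoint lacks
      unexpected - keys the checkpoint has but the model does not use
    Both lists are sorted so the output is stable between runs.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Tiny stand-in example: the model has a conv bias, the checkpoint does not.
missing, unexpected = diff_state_dict_keys(
    ["initial_conv.weight", "initial_conv.bias"],
    ["initial_conv.weight"],
)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model this could be called as `diff_state_dict_keys(model.state_dict().keys(), torch.load(path).keys())`. If every missing key ends in `.bias`, as in the error above, the checkpoint was probably saved from layers built without a bias term.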
I solved this problem. If you run into errors while loading the checkpoint, you can refer to the following code.
The approach is shown below:
import torch
from tapnet.torch import tapir_model  # adjust this import to match your tapnet checkout

model = tapir_model.TAPIR(pyramid_level=1)
state_dict = torch.load('tapnet/checkpoints/bootstapir_checkpoint.pt')
try:
    model.load_state_dict(state_dict)
except RuntimeError:
    # Bias keys that the model expects but the checkpoint does not contain.
    TAPIR_state_dict_keys = ["resnet_torch.initial_conv.bias",
                             "resnet_torch.block_groups.0.blocks.0.proj_conv.bias",
                             "resnet_torch.block_groups.0.blocks.0.conv_0.bias",
                             "resnet_torch.block_groups.0.blocks.0.conv_1.bias",
                             "resnet_torch.block_groups.0.blocks.1.conv_0.bias",
                             "resnet_torch.block_groups.0.blocks.1.conv_1.bias",
                             "resnet_torch.block_groups.1.blocks.0.proj_conv.bias",
                             "resnet_torch.block_groups.1.blocks.0.conv_0.bias",
                             "resnet_torch.block_groups.1.blocks.0.conv_1.bias",
                             "resnet_torch.block_groups.1.blocks.1.conv_0.bias",
                             "resnet_torch.block_groups.1.blocks.1.conv_1.bias",
                             "resnet_torch.block_groups.2.blocks.0.proj_conv.bias",
                             "resnet_torch.block_groups.2.blocks.0.conv_0.bias",
                             "resnet_torch.block_groups.2.blocks.0.conv_1.bias",
                             "resnet_torch.block_groups.2.blocks.1.conv_0.bias",
                             "resnet_torch.block_groups.2.blocks.1.conv_1.bias",
                             "resnet_torch.block_groups.3.blocks.0.proj_conv.bias",
                             "resnet_torch.block_groups.3.blocks.0.conv_0.bias",
                             "resnet_torch.block_groups.3.blocks.0.conv_1.bias",
                             "resnet_torch.block_groups.3.blocks.1.conv_0.bias",
                             "resnet_torch.block_groups.3.blocks.1.conv_1.bias"]
    for bias_key in TAPIR_state_dict_keys:
        # Each missing bias belongs to the layer whose weight is in the checkpoint.
        weight_key = bias_key[:-len('.bias')] + '.weight'
        if weight_key in state_dict:
            # One zero per output channel (dimension 0 of the weight tensor).
            state_dict[bias_key] = torch.zeros(state_dict[weight_key].size(0))
    # You can overwrite the checkpoint after adding the keys.
    # torch.save(state_dict, 'tapnet/checkpoints/bootstapir_checkpoint.pt')
    model.load_state_dict(state_dict)
That's it!
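A note on why this helps: with strict=False, parameters that are missing from the checkpoint keep the values the layer constructor gave them, and nn.Conv2d initializes its bias randomly, which would explain the wrong outputs. A zero bias instead behaves like a convolution built without a bias. The same idea can be written without the hard-coded key list. This is a minimal sketch, assuming the only missing keys are biases; `fill_missing_biases` is a hypothetical helper, not part of PyTorch or tapnet.

```python
def fill_missing_biases(expected, loaded):
    """Return a copy of `loaded` with zero tensors for any missing '.bias' keys.

    expected - the model's own state dict (model.state_dict()), used only to
               learn the shape, dtype and device of each missing bias
    loaded   - the state dict read from the checkpoint; it is not modified
    Also returns the list of keys that were added, so the patch can be checked.
    """
    patched = dict(loaded)
    added = []
    for key, param in expected.items():
        if key.endswith(".bias") and key not in patched:
            # Tensor.new_zeros keeps the dtype and device of the model parameter.
            patched[key] = param.new_zeros(param.shape)
            added.append(key)
    return patched, added
```

Usage would look like `patched, added = fill_missing_biases(model.state_dict(), state_dict)` followed by `model.load_state_dict(patched)`. Printing `added` before loading is a cheap check that only the expected bias keys were filled in; any other missing key still raises the usual error.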