LightViT

A question about load_state_dict

Open shenyehui opened this issue 2 years ago • 5 comments

pretrained_weights_path = 'lightvit_tiny_78.7.ckpt'
pretrained_state_dict = torch.load(pretrained_weights_path)
lightvit = lightvit_tiny(pretrained=False)
lightvit.load_state_dict(pretrained_state_dict)

I use the above code to load the pretrained lightvit_tiny model, but it doesn't work. How can I load the pretrained lightvit_tiny model correctly?

shenyehui avatar Mar 28 '24 07:03 shenyehui

You may use:

lightvit.load_state_dict(pretrained_state_dict["state_dict"])
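The suggestion above implies that the checkpoint file stores the weights under a "state_dict" key rather than being the weights dictionary itself. A minimal sketch of a loader that tolerates both layouts is shown below. The helper name unwrap_state_dict and the stand-in dictionaries (including the "epoch" entry) are hypothetical and only illustrate the idea; they are not part of LightViT.

```python
def unwrap_state_dict(checkpoint):
    """Return the model weights from a checkpoint that may wrap them.

    Assumption: a wrapped checkpoint is a dict that keeps the weights
    under the "state_dict" key; anything else is treated as the weights.
    """
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


# Stand-in dictionaries so the sketch runs without the real checkpoint file.
wrapped = {"epoch": 300, "state_dict": {"head.weight": [0.1, 0.2]}}
raw = {"head.weight": [0.1, 0.2]}

# Printing the top-level keys is a quick way to see which layout a file uses.
print(sorted(wrapped.keys()))

weights_from_wrapped = unwrap_state_dict(wrapped)
weights_from_raw = unwrap_state_dict(raw)

# With the real file, the call would look like this (not executed here):
#   checkpoint = torch.load('lightvit_tiny_78.7.ckpt', map_location='cpu')
#   lightvit.load_state_dict(unwrap_state_dict(checkpoint))
```

Checking the top-level keys before calling load_state_dict usually makes a key mismatch obvious without reading a long error message.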

hunto avatar Mar 28 '24 07:03 hunto

Can you provide more detailed error logs?

hunto avatar Mar 28 '24 07:03 hunto

Thank you! I successfully loaded the pre-trained model, but I get an error when I use the following code to inspect the shape of the tensor output before the pooling layer:

pretrained_weights_path = 'lightvit_tiny_78.7.ckpt'
pretrained_state_dict = torch.load(pretrained_weights_path)
lightvit = lightvit_tiny(pretrained=False)
lightvit.load_state_dict(pretrained_state_dict["state_dict"])
featureslightViT = list(lightvit.children())[:-1]
self.backbone = nn.Sequential(*featureslightViT)

The error is:

TypeError: forward() missing 2 required positional arguments: 'H' and 'W'

My input is inputs = torch.rand((1, 3, 224, 224)).

shenyehui avatar Mar 28 '24 07:03 shenyehui

You cannot simply wrap the children in nn.Sequential, since some blocks require H and W as inputs. Please refer to forward_features: https://github.com/hunto/image_classification_sota/blob/36539b63cc8b851bd3fc93251bba60528813bb36/lib/models/lightvit.py#L384
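To illustrate why this fails, here is a small sketch on a toy model. ToyBlock and ToyViT are hypothetical stand-ins, not LightViT's actual classes; only the method name forward_features comes from this thread. nn.Sequential passes a single tensor from one child to the next, so any child whose forward takes extra arguments such as H and W raises a TypeError, while a forward_features method can compute H and W itself and pass them along.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical block whose forward needs the token grid size (H, W)."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        # Tokens -> feature map, depthwise conv, then back to tokens.
        y = x.transpose(1, 2).reshape(B, C, H, W)
        y = self.conv(y)
        return x + y.flatten(2).transpose(1, 2)


class ToyViT(nn.Module):
    """Hypothetical model, only loosely shaped like a vision transformer."""

    def __init__(self, dim=8, num_classes=10):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=16, stride=16)
        self.block = ToyBlock(dim)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward_features(self, x):
        x = self.patch_embed(x)
        H, W = x.shape[-2:]                # grid size is computed here
        x = x.flatten(2).transpose(1, 2)   # (B, H*W, C)
        x = self.block(x, H, W)            # and passed explicitly to the block
        return self.norm(x)                # tokens before any pooling

    def forward(self, x):
        return self.head(self.forward_features(x).mean(dim=1))


model = ToyViT().eval()
inputs = torch.rand((1, 3, 224, 224))

# Wrapping the children drops the custom forward logic, so the block
# is called with one argument and complains about the missing H and W.
backbone = nn.Sequential(*list(model.children())[:-1])
try:
    backbone(inputs)
    sequential_error = None
except TypeError as exc:
    sequential_error = exc
print(type(sequential_error).__name__)

# Calling forward_features keeps the H/W bookkeeping intact.
with torch.no_grad():
    tokens = model.forward_features(inputs)
print(tuple(tokens.shape))
```

For the real model, the analogous step would be to call lightvit.forward_features(inputs) and inspect the shape of the result. Whether that output is taken before or after pooling depends on the linked implementation, so it is worth checking the code at that line.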

hunto avatar Mar 28 '24 08:03 hunto

I see in the paper's diagram that the average pooling layer is in the head section. Does the following code mean that the pooling operation is not necessarily performed?

self.head = nn.Linear(neck_dim, num_classes) if num_classes > 0 else nn.Identity()

shenyehui avatar Mar 28 '24 08:03 shenyehui