A question about load_state_dict
pretrained_weights_path = 'lightvit_tiny_78.7.ckpt'
pretrained_state_dict = torch.load(pretrained_weights_path)
lightvit = lightvit_tiny(pretrained=False)
lightvit.load_state_dict(pretrained_state_dict)
I use the code above to load the pretrained lightvit_tiny model, but it doesn't work. How can I load the pretrained lightvit_tiny weights correctly?
The checkpoint stores the model weights under the "state_dict" key, so you may use:
lightvit.load_state_dict(pretrained_state_dict["state_dict"])
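The fix works because the .ckpt file is a training checkpoint that nests the weights under "state_dict" rather than being a plain state dict. A minimal sketch of a helper that accepts either form follows; the name extract_state_dict and the optional stripping of the "module." prefix (which nn.DataParallel / DistributedDataParallel add to parameter names) are assumptions for illustration, not part of the LightViT code.

```python
from collections.abc import Mapping


def extract_state_dict(checkpoint, key="state_dict", strip_prefix="module."):
    """Return the model weights from a checkpoint (hypothetical helper).

    Assumes the checkpoint is either a plain state dict or a dict that nests
    the weights under `key`, as lightvit_tiny_78.7.ckpt does in this thread.
    """
    # Unwrap the nested weights if present, otherwise treat the input as the state dict.
    if isinstance(checkpoint, Mapping) and key in checkpoint:
        weights = checkpoint[key]
    else:
        weights = checkpoint

    def clean(name):
        # Drop the wrapper prefix left by DataParallel-style training, if any.
        if strip_prefix and name.startswith(strip_prefix):
            return name[len(strip_prefix):]
        return name

    return {clean(name): value for name, value in weights.items()}


ckpt = {"state_dict": {"module.head.weight": 1, "head.bias": 2}, "epoch": 10}
print(extract_state_dict(ckpt))  # {'head.weight': 1, 'head.bias': 2}
```

With PyTorch it would be called as lightvit.load_state_dict(extract_state_dict(torch.load(pretrained_weights_path))). The helper only manipulates dicts, so it does not depend on the checkpoint's tensor contents.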
Can you provide more detailed error logs?
Thank you! I successfully loaded the pretrained model, but when I use the following code to inspect the shape of the tensor output before the pooling layer:
pretrained_weights_path = 'lightvit_tiny_78.7.ckpt'
pretrained_state_dict = torch.load(pretrained_weights_path)
lightvit = lightvit_tiny(pretrained=False)
lightvit.load_state_dict(pretrained_state_dict["state_dict"])
featureslightViT = list(lightvit.children())[:-1]
self.backbone = nn.Sequential(*featureslightViT)
I get this error:
TypeError: forward() missing 2 required positional arguments: 'H' and 'W'
My input is inputs = torch.rand((1, 3, 224, 224)).
You cannot simply wrap the children in nn.Sequential, since some blocks require H and W as inputs. Please refer to forward_features:
https://github.com/hunto/image_classification_sota/blob/36539b63cc8b851bd3fc93251bba60528813bb36/lib/models/lightvit.py#L384
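A minimal sketch of why nn.Sequential fails, using a hypothetical toy model (PatchEmbed, Stage and ToyTokenModel are illustrative names, not the LightViT classes) whose stages take (x, H, W) as in this thread. nn.Sequential passes a single argument to each child, so H and W never arrive and the same TypeError appears. A forward hook on the last stage captures the tokens before pooling without rebuilding the model. For the real model, which submodule to hook, and whether forward_features already pools, should be checked in lightvit.py.

```python
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Toy patch embedding: returns tokens plus the (H, W) token grid size."""

    def __init__(self, dim=8, patch=16):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        x = self.proj(x)                                # (B, C, H, W)
        H, W = x.shape[2], x.shape[3]
        return x.flatten(2).transpose(1, 2), (H, W)     # tokens: (B, H*W, C)


class Stage(nn.Module):
    """Toy stage whose forward needs the spatial size explicitly."""

    def __init__(self, dim=8):
        super().__init__()
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x, H, W):
        return self.mlp(x), (H, W)


class ToyTokenModel(nn.Module):
    def __init__(self, dim=8, num_classes=10):
        super().__init__()
        self.patch_embed = PatchEmbed(dim)
        self.stages = nn.ModuleList([Stage(dim), Stage(dim)])
        self.head = nn.Linear(dim, num_classes)

    def forward_features(self, x):
        x, (H, W) = self.patch_embed(x)
        for stage in self.stages:
            x, (H, W) = stage(x, H, W)   # H and W are threaded through by hand
        return x                          # tokens before pooling: (B, H*W, C)

    def forward(self, x):
        return self.head(self.forward_features(x).mean(dim=1))  # pool, then classify


model = ToyTokenModel().eval()
inputs = torch.rand((1, 3, 224, 224))

# nn.Sequential calls each child with one argument, so the stage never gets H and W.
try:
    nn.Sequential(model.patch_embed, model.stages[0])(inputs)
except TypeError as err:
    print("nn.Sequential fails:", err)

# A forward hook receives whatever the stage returns, i.e. (tokens, (H, W)).
captured = {}

def save_tokens(module, args, output):
    tokens, (H, W) = output
    captured["tokens"], captured["hw"] = tokens.detach(), (H, W)

handle = model.stages[-1].register_forward_hook(save_tokens)
with torch.no_grad():
    logits = model(inputs)
handle.remove()
print(captured["tokens"].shape, captured["hw"])  # torch.Size([1, 196, 8]) (14, 14)
```

Calling model.forward_features(inputs) directly is the simpler alternative when that method returns the un-pooled tokens; the hook avoids depending on where pooling happens.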
I see in the paper's diagram that the average pooling layer is part of the HEAD section. Looking at the code below, does that mean the pooling operation is not necessarily performed?
self.head = nn.Linear(neck_dim, num_classes) if num_classes > 0 else nn.Identity()
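One common pattern, sketched below with a hypothetical PooledClassifier (not the LightViT code), is to perform the average pooling in forward (or forward_features) before calling self.head. In that case the "head" of a diagram covers pooling plus the linear layer, even though self.head itself is only nn.Linear, and with num_classes=0 self.head becomes nn.Identity() so the model returns the pooled features. Whether LightViT pools this way should be confirmed in forward_features / forward in lightvit.py.

```python
import torch
import torch.nn as nn


class PooledClassifier(nn.Module):
    """Hypothetical head: pooling lives in forward, self.head is only Linear/Identity."""

    def __init__(self, neck_dim=8, num_classes=10):
        super().__init__()
        self.head = nn.Linear(neck_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, tokens):
        x = tokens.mean(dim=1)   # global average pooling over tokens: (B, N, C) -> (B, C)
        return self.head(x)      # logits, or pooled features when head is Identity


tokens = torch.rand(2, 196, 8)
classifier = PooledClassifier(num_classes=10)
feature_extractor = PooledClassifier(num_classes=0)
print(classifier(tokens).shape)         # torch.Size([2, 10])
print(feature_extractor(tokens).shape)  # torch.Size([2, 8])
```

In this design the pooling always runs; num_classes only decides whether the linear classifier is applied afterwards.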