vision icon indicating copy to clipboard operation
vision copied to clipboard

Add convit into models

Open triple-mu opened this issue 3 years ago • 1 comments

Add convit into models

triple-mu avatar Jun 22 '22 03:06 triple-mu

对convit分类网络进行测试,不进行预训练,convit_tiny、convit_small、convit_base皆可正常运行。 当进行预训练时,由于网络是在imagenet数据集上进行的预训练,所以最后的分类输出种类是1000。采用的测试集案例为10分类,所以要修改最后一个Linear层的输出通道数为10。一般的分类网络vgg、alexnet等最后一层为classifier层。所以常规的使用预训练的代码为:

import oneflow as flow
from flowvision.models import ModelCreator     

net = ModelCreator.create_model("alexnet", pretrained=True)
num_fc = net.classifier[6].in_features
net.classifier[6] = flow.nn.Linear(in_features=num_fc, out_features=10)

在测试过程中,convit分类网络没有使用classifier层,取而代之的是head层,所以在使用时需要注意在加载预训练模型时需对以上常规代码进行修改,修改后可正常运行。 修改为:

net = ModelCreator.create_model("convit_tiny",pretrained = True)
num_fc = net.head.in_features
net.head = torch.nn.Linear(in_features=num_fc, out_features=10)

wzy9813125 avatar Aug 03 '22 10:08 wzy9813125