torch_maml
torch_maml copied to clipboard
关于模型的微调。
博主你好,我想问一下,在这里从验证集取数据,对模型进行微调,但是我们并没有更新元模型权重。这里的目的是什么呢?只是说单纯的验证一下模型的性能吗? for support_images, support_labels, query_images, query_labels in val_loader:
# Get variables
support_images = support_images.float().to(dev)
support_labels = support_labels.long().to(dev)
query_images = query_images.float().to(dev)
query_labels = query_labels.long().to(dev)
loss, acc = maml_train(model, support_images, support_labels, query_images, query_labels,
3, args, optimizer, is_train=False)
# Must use .item() to add total loss, or will occur GPU memory leak.
# Because dynamic graph is created during forward, collect in backward.
val_loss.append(loss.item())
val_acc.append(acc)