torch_maml
torch_maml copied to clipboard
Very simple pytorch maml implement
以CPU模式运行的话,helper.py中 > model = Classifier(1, args.n_way).cuda() 要把.cuda()调用改为.to(dev),其中dev是在args.py中设置的设备变量,确保dev被设置为'cpu'。 arg.py中需要: - 注释--gpu参数,因为在使用CPU进行训练时不需要指定GPU设备。 - 注释设置环境变量CUDA_VISIBLE_DEVICES的代码,因为在使用CPU时不需要这一步骤。 - 修改dev变量的赋值,直接将其设置为torch.device('cpu'),确保所有操作都在CPU上执行。
In maml.py line 105 and 106, the code is ```python fast_weights = collections.OrderedDict((name, param - args.inner_lr * grads) for ((name, param), grads) in zip(fast_weights.items(), grads)) ``` is there any wrong...
跑你的例程都没动 val_loss: nan val_acc: nan 为啥这2个值都是nan呢?
首先非常感谢分享代码,本人刚开始学meta leanring , 找资料的时候找到了您写的blog,按照提示找到了这个pytorch 版本的代码 但是发现这个代码只有train.py 没有 test 的代码 请问有 pytorch 版本的 test.py 代码吗? 有的话可以分享吗 ?
关于测试集的问题
1. 训练集中有120种字母集,所以训练集大小是120,测试集中只有20个字母集,为什么测试集大小是41啊,这个41是怎么来的呢 2. 训练集supportset和queryset获取的时候是直接在3856个(120个字母集的3856个字母)字母中随机抽5个字母?每个字母再去抽一个样本吗?这样的话任务与任务之间的区别不就没了吗?我不是不用分成120个字母集了啊,直接放3856个字母去抽样不就行了? 以上的问题不是很懂,哪位大佬可以帮忙解答解答啊
博主你好,我想问一下,在这里从验证集取数据,对模型进行微调,但是我们并没有更新元模型权重。这里的目的是什么呢?只是说单纯的验证一下模型的性能吗? for support_images, support_labels, query_images, query_labels in val_loader: # Get variables support_images = support_images.float().to(dev) support_labels = support_labels.long().to(dev) query_images = query_images.float().to(dev) query_labels = query_labels.long().to(dev) loss, acc = maml_train(model, support_images, support_labels, query_images,...
maml实现中,必须把模型参数摘出来,重写forward,吗?对于复杂一点的模型太麻烦了,因为在qry阶段模型已经经过内循环更新,得到loss也只能对更新过的model用参数更新,没办法直接对初始模型init_model参数更新。 您采用的方法是把模型参数摘出来,并在模型中用了function_forward代替模型本身的forward进行损失计算和摘出来的参数更新,用function_forward代替模型本身forward,从而实现的最后的对模型初始参数的更新。 对于复杂模型,有无简单方法?
关于标签分配问题
img_dirs = random.sample(self.file_list, self.n_way) for label, img_dir in enumerate(img_dirs): print('此次的标签为:',label) img_list = [f for f in glob.glob(img_dir + "**/*.png", recursive=True)] images = random.sample(img_list, self.k_shot + self.q_query) 据我所知,这个火星数据集大约有1000多个类别,但是这里分配真实标签的时候似乎没有什么依据,而是按照索引顺序分配的。关键是,这个标签数值分配似乎是错误的。只能分配0-5
作者你好,请问我在训练自己的数据集时,频繁报错读取不到dataset,具体如下: Traceback (most recent call last): File "E:\torch_maml\train.py", line 23, in train_loader = DataLoader(train_dataset, batch_size=args.task_num, shuffle=True, num_workers=args.num_workers) File "D:\Anaconda3\envs\MAML\lib\site-packages\torch\utils\data\dataloader.py", line 388, in __init__ sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]...