IMDiffusion icon indicating copy to clipboard operation
IMDiffusion copied to clipboard

训练过程中的掩码策略问题

Open songxueh opened this issue 11 months ago • 0 comments

作者您好，您在论文中提出了栅格掩码，但在代码中，扩散模型训练时使用的掩码似乎是随机掩码，只有验证和测试时才使用栅格掩码。请问实际采用的是哪种掩码策略？从下面的代码看，`self.target_strategy` 在分支判断之前被硬编码为 `"random"`。因此只要 `is_train != 0`，`get_hist_mask` 分支都不会被执行，训练时始终调用 `get_randmask`；只有 `is_train == 0` 时才使用 `gt_mask`。相关代码如下：

```python
def forward(self, batch, is_train=1):
    (
        observed_data,
        observed_mask,
        observed_tp,
        gt_mask,
        for_pattern_mask,
        _,
        strategy_type
    ) = self.process_data(batch)
    # print("observed data shape is")
    # print(observed_data.shape)
    # print("observed mask shape is")
    # print(observed_mask.shape)
    # print("observed tp is")
    # print(observed_tp)
    # 强制使用0作为cond_mask
    self.target_strategy = "random"
    if is_train == 0:
        cond_mask = gt_mask
    elif self.target_strategy != "random":
        cond_mask = self.get_hist_mask(
            observed_mask, for_pattern_mask=for_pattern_mask
        )
    else:
        cond_mask = self.get_randmask(observed_mask,ratio=self.ratio)
    #
    # cond_mask = torch.zeros_like(observed_mask)
    # cond_mask = self.get_random_mask(observed_mask)
    #
    side_info = self.get_side_info(observed_tp, cond_mask)

    loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid

    return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train, strategy_type = strategy_type)
```

songxueh avatar May 07 '25 11:05 songxueh