ProDA
The kl_div loss in self-distillation
The following code calculates the kl_div loss between the stage-1 teacher and the student model. However, the student output goes through softmax instead of log_softmax. Is this a mistake?
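```python
# student output: softmax gives probabilities here, not log-probabilities
student = F.softmax(target_out['out'], dim=1)
with torch.no_grad():
    # stage-1 teacher prediction, upsampled to the same spatial size as threshold_arg
    teacher_out = self.teacher_DP(target_imageS)
    teacher_out['out'] = F.interpolate(teacher_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
    teacher = F.softmax(teacher_out['out'], dim=1)
# element-wise KL terms between the teacher and student distributions
loss_kd = F.kl_div(student, teacher, reduction='none')
mask = (teacher != 250).float()
loss_kd = (loss_kd * mask).sum() / mask.sum()
loss = loss + self.opt.distillation * loss_kd
```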
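Yeah, it should be log_softmax. F.kl_div expects its input (the student) as log-probabilities and, by default, its target (the teacher) as probabilities.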
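A minimal corrected sketch, assuming both models produce raw logits of shape (N, C, H, W). `kd_loss` is a hypothetical helper, not the ProDA code; apart from the function wrapper, it only swaps softmax for log_softmax on the student side. Since `teacher` holds probabilities in [0, 1], the original mask `teacher != 250` is all ones, so the masked average reduces to a plain `.mean()`. The demo compares the result against KL computed directly from its definition.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """KL(teacher || student) terms for segmentation logits of shape (N, C, H, W).

    Hypothetical helper (not the ProDA implementation): same as the snippet in
    the question except that the student uses log_softmax.
    """
    # F.kl_div expects `input` as log-probabilities...
    student = F.log_softmax(student_logits, dim=1)
    with torch.no_grad():
        # ...and, with the default log_target=False, `target` as probabilities.
        teacher = F.softmax(teacher_logits, dim=1)
    # Element-wise terms: teacher * (log(teacher) - student).
    loss = F.kl_div(student, teacher, reduction="none")
    # The original mask (teacher != 250) is all ones for probabilities,
    # so (loss * mask).sum() / mask.sum() is just the mean over all elements.
    return loss.mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    s = torch.randn(2, 19, 8, 8)  # student logits
    t = torch.randn(2, 19, 8, 8)  # teacher logits

    fixed = kd_loss(s, t)
    # Reference: KL computed by hand from its definition.
    p_t = F.softmax(t, dim=1)
    manual = (p_t * (p_t.log() - F.log_softmax(s, dim=1))).mean()
    print(torch.allclose(fixed, manual, atol=1e-6))  # True

    # Identical logits give (numerically) zero divergence.
    print(kd_loss(s, s).item())

    # The variant from the question: softmax output passed as `input`.
    buggy = F.kl_div(F.softmax(s, dim=1), p_t, reduction="none").mean()
    print(buggy.item())
```

With softmax as the input, each pixel's sum over classes becomes -H(teacher) - Σ teacher · student_prob, which is always negative. So the buggy loss is not a divergence, and its minimum is generally not at student == teacher.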