ProDA

The kl_div loss of self distillation

Open · luyvlei opened this issue 3 years ago · 1 comment

The following code calculates the kl_div loss between the teacher from stage 1 and the student model, but the student output is never passed through log_softmax, even though F.kl_div expects its first argument to be log-probabilities. Is this a mistake?

    student = F.softmax(target_out['out'], dim=1)
    with torch.no_grad():
        teacher_out = self.teacher_DP(target_imageS)
        teacher_out['out'] = F.interpolate(teacher_out['out'], size=threshold_arg.shape[2:], mode='bilinear', align_corners=True)
        teacher = F.softmax(teacher_out['out'], dim=1)

    loss_kd = F.kl_div(student, teacher, reduction='none')
    mask = (teacher != 250).float()
    loss_kd = (loss_kd * mask).sum() / mask.sum()
    loss = loss + self.opt.distillation * loss_kd   

luyvlei · Mar 08 '22 07:03

Yeah, it should be log_softmax.
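
For reference, a minimal self-contained sketch of the correct pattern (dummy tensor shapes chosen only for illustration; in the snippet above the one-line fix would be `student = F.log_softmax(target_out['out'], dim=1)`):

    import torch
    import torch.nn.functional as F

    # dummy logits: (batch, classes, height, width)
    student_logits = torch.randn(2, 19, 8, 8)
    teacher_logits = torch.randn(2, 19, 8, 8)

    # F.kl_div expects log-probabilities as the input and probabilities as the target
    student = F.log_softmax(student_logits, dim=1)
    with torch.no_grad():
        teacher = F.softmax(teacher_logits, dim=1)

    loss_kd = F.kl_div(student, teacher, reduction='none')
    loss_kd = loss_kd.sum() / loss_kd.numel()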

panzhang0104 · Apr 01 '22 02:04