MDD
Total loss becomes NaN during training
Hi, thanks for sharing your nice code. I found that at some point during training the total loss becomes NaN, and at the same moment the accuracy on the target domain drops to zero.
I think logloss_tgt causes this problem: it computes log(1 - softmax(...)), and when the classifier is very confident the argument of the log reaches zero, so the log diverges to negative infinity.
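Here is a minimal sketch that reproduces this outside the training loop. It assumes float32 logits. The logit values are made up, and feeding the term to F.nll_loss with a pseudo-label is my assumption about how logloss_tgt is consumed:

```python
import torch
import torch.nn.functional as F

# Hypothetical target-sample logits from the adversarial classifier:
# class 0 dominates so strongly that softmax rounds to exactly 1.0 in float32.
logits = torch.tensor([[100.0, 0.0, 0.0]], requires_grad=True)

# Unclamped form of the target log-loss term.
logloss = torch.log(1 - F.softmax(logits, dim=1))
print(logloss)  # first entry is -inf because 1 - 1.0 == 0

# Assumption: the term is fed to F.nll_loss with a pseudo-label (class 0 here).
loss = F.nll_loss(logloss, torch.tensor([0]))
loss.backward()
print(loss.item())  # inf
print(logits.grad)  # contains nan: inf - inf appears in the softmax backward
```

Once a single NaN gradient reaches the optimizer step, the weights become NaN, which matches the accuracy collapsing to zero.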
Based on that, I clamped the argument of the log to a small minimum value (1e-15) so that the log stays finite:
logloss_tgt = torch.log(torch.clamp(1 - F.softmax(outputs_adv.narrow(0, labels_source.size(0), inputs.size(0) - labels_source.size(0)), dim = 1), min=1e-15))
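To check that the clamp keeps the loss and the gradients finite, here is a small sketch. The names inputs, labels_source and outputs_adv mirror the line above, but their shapes and values are made up, and pseudo_labels is a hypothetical stand-in for whatever target labels the loss actually uses:

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: 2 source + 2 target samples, 3 classes.
inputs = torch.zeros(4, 8)
labels_source = torch.tensor([0, 1])
outputs_adv = torch.tensor([[1.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0],
                            [100.0, 0.0, 0.0],  # saturated target sample
                            [0.0, 0.0, 2.0]], requires_grad=True)

source_size = labels_source.size(0)
target_size = inputs.size(0) - source_size

# Target rows of the adversarial classifier output.
tgt_logits = outputs_adv.narrow(0, source_size, target_size)

# Clamp 1 - softmax away from zero before taking the log.
logloss_tgt = torch.log(torch.clamp(1 - F.softmax(tgt_logits, dim=1), min=1e-15))

# Hypothetical pseudo-labels for the two target samples.
pseudo_labels = torch.tensor([0, 2])
loss = F.nll_loss(logloss_tgt, pseudo_labels)
loss.backward()

print(loss.item())                              # finite, about 18.04
print(torch.isfinite(outputs_adv.grad).all().item())  # True
```

One side effect: torch.clamp passes no gradient to entries below min, so a saturated sample stops contributing gradient through that entry. This removes the NaN, but that sample's signal is also silently dropped while it stays saturated.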