CURL
kl_loss
Dear @Roll920,
Despite its computational cost, KL-divergence is a very interesting criterion for assigning filter importance. However, I'm a bit confused by the KL loss implemented in your code (e.g., lines 128-129 in generate_mask.py). Based on Equation 1 in the paper, shouldn't

`kl_loss = torch.mean(torch.sum(softmax(output) * (logsoftmax(output) - logsoftmax(logits)), dim=1))`

be

`kl_loss = torch.mean(torch.sum(softmax(logits) * (logsoftmax(logits) - logsoftmax(output)), dim=1))`?
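For what it's worth, the two orderings are not equivalent, since KL-divergence is asymmetric. Here is a minimal sketch (with made-up random tensors standing in for `logits` and `output`) that computes both versions and checks the second one against PyTorch's built-in `F.kl_div`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)  # placeholder for the reference model's outputs
output = torch.randn(4, 10)  # placeholder for the masked model's outputs

# Version as implemented: KL(softmax(output) || softmax(logits))
kl_impl = torch.mean(torch.sum(
    F.softmax(output, dim=1)
    * (F.log_softmax(output, dim=1) - F.log_softmax(logits, dim=1)),
    dim=1))

# Version suggested by Equation 1: KL(softmax(logits) || softmax(output))
kl_eq1 = torch.mean(torch.sum(
    F.softmax(logits, dim=1)
    * (F.log_softmax(logits, dim=1) - F.log_softmax(output, dim=1)),
    dim=1))

# kl_eq1 matches F.kl_div(log_q, p) with 'batchmean' reduction,
# i.e. the standard KL(p || q) with p = softmax(logits)
kl_builtin = F.kl_div(F.log_softmax(output, dim=1),
                      F.softmax(logits, dim=1),
                      reduction='batchmean')

print(kl_impl.item(), kl_eq1.item(), kl_builtin.item())
```

Both quantities are non-negative, but they generally take different values, which is why the argument order matters here.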
Thanks in advance,