LwF gradient backpropagation
Hi,
I think that in the LwF plugin the gradient does not flow correctly through the penalty computation in some cases:
```python
with torch.no_grad():
    if isinstance(self.prev_model, MultiTaskModule):
        # output from previous output heads.
        y_prev = avalanche_forward(self.prev_model, x, None)
        # in a multitask scenario we need to compute the output
        # from all the heads, so we need to call forward again.
        # TODO: can we avoid this?
        y_curr = avalanche_forward(curr_model, x, None)
    else:  # no task labels
        y_prev = {"0": self.prev_model(x)}
        y_curr = {"0": out}
```
- It seems to be OK for the `else: # no task labels` branch, because `out` is already computed outside the `no_grad` block, but if I did something like `y_curr = {"0": curr_model(x)}` there, it would not work.
- So this is actually a problem for the `if isinstance(self.prev_model, MultiTaskModule):` branch, where `y_curr['0'].requires_grad` returns `False` (see the small check below).
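A minimal stand-alone check in plain PyTorch (not Avalanche code) that reproduces the behaviour:

```python
import torch
import torch.nn as nn

# Outputs computed inside torch.no_grad() are detached from the graph,
# so a penalty built on them cannot backpropagate into the current model.
model = nn.Linear(4, 2)
x = torch.randn(3, 4)

with torch.no_grad():
    y_curr_no_grad = model(x)        # mirrors the MultiTaskModule branch above
print(y_curr_no_grad.requires_grad)  # False -> penalty gradient is lost

y_curr_ok = model(x)                 # forward pass outside no_grad()
print(y_curr_ok.requires_grad)       # True  -> penalty can reach the model
```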
Please let me know if I am missing something here and the computation is actually correct.
Cheers, Woj
You are correct, this is a bug.
I changed that part of the code to the following, but the final accuracy is still lower than expected.
```python
if isinstance(self.prev_model, MultiTaskModule):
    # output from previous output heads.
    with torch.no_grad():
        y_prev = avalanche_forward(self.prev_model, x, None)
    # in a multitask scenario we need to compute the output
    # from all the heads, so we need to call forward again.
    # TODO: can we avoid this?
    y_curr = avalanche_forward(curr_model, x, None)
else:  # no task labels
    with torch.no_grad():
        y_prev = {"0": self.prev_model(x)}
    y_curr = {"0": out}
```
Oh right, the change above applies only to `MultiTaskModule`, so I guess it is still a useful fix. There may be another problem for single-headed models, too.
Another thing is whether `self.prev_model` should be in `.train()` or `.eval()` mode: currently it is in training mode, which would affect batch norm and dropout, for example, so the distilled knowledge would be different. A sketch of one way to handle this is below.
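One option would be to switch the previous model to eval mode just for the distillation forward pass and restore its previous mode afterwards. A sketch with a hypothetical `eval_mode` helper (not part of Avalanche):

```python
import torch
from contextlib import contextmanager

@contextmanager
def eval_mode(model: torch.nn.Module):
    """Temporarily put `model` in eval mode so batch norm uses running
    statistics and dropout is disabled while producing the distillation
    targets; restore the original mode on exit."""
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        if was_training:
            model.train()

# Inside the plugin one could then write, e.g.:
# with eval_mode(self.prev_model), torch.no_grad():
#     y_prev = avalanche_forward(self.prev_model, x, None)
```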