captum

Detaching attributions from computation graph

Open aarzchan opened this issue 4 years ago • 6 comments

Hi! I'm interested in regularizing the model with an attribution-based loss, as described in a previous post. As a baseline, I would like to train a model without such regularization (i.e., using only the task loss), but still use the attributions to compute an evaluation metric.

The desired workflow for this baseline is as follows:

  1. Compute the gradient-based attributions (e.g., using Integrated Gradients, IG) for the model, without keeping the gradients used to obtain them beyond Step 1. That is, all I want is the final attribution vector; Step 2 should then start from an empty computation graph.
  2. Perform a forward pass through the model, use the task labels to compute the loss, then backpropagate this loss only through the graph built in Step 2. Crucially, the attributions from Step 1 are not used in Step 2. In other words, when I perform the backward pass here, I don't want any connection to the computation in Step 1.
  3. Use the attributions obtained in Step 1 to compute the evaluation metric.
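The three steps above can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not Captum's implementation: `integrated_gradients` is a hypothetical hand-rolled Riemann-sum approximation of IG with a zero baseline, the model is a hypothetical toy classifier without dropout, and the Step 3 metric is a placeholder. Because the gradients are taken w.r.t. the inputs and the result is detached, Step 1 leaves no graph behind and populates no parameter `.grad`.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy classifier standing in for the real model (no dropout).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
optimizer = torch.optim.Adam(model.parameters())


def integrated_gradients(model, inputs, target, n_steps=20):
    """Hand-rolled Riemann-sum IG with a zero baseline (hypothetical helper,
    not Captum's implementation)."""
    baseline = torch.zeros_like(inputs)
    total_grads = torch.zeros_like(inputs)
    for alpha in torch.linspace(0.0, 1.0, n_steps):
        # Each interpolated point is a new leaf tensor that requires grad.
        point = (baseline + alpha * (inputs - baseline)).requires_grad_(True)
        score = model(point).gather(1, target.unsqueeze(1)).sum()
        # Gradient w.r.t. the input only; parameter .grad fields stay untouched.
        (grad,) = torch.autograd.grad(score, point)
        total_grads += grad
    # Average gradient times (input - baseline); detach so no graph is kept.
    return ((inputs - baseline) * total_grads / n_steps).detach()


inputs = torch.randn(5, 4)
targets = torch.randint(0, 3, (5,))

# Step 1: attributions only, no graph retained afterwards.
attrs = integrated_gradients(model, inputs, targets)
grads_after_step1 = [p.grad for p in model.parameters()]  # all None here

# Step 2: ordinary task loss on a fresh graph, then an update.
optimizer.zero_grad()
loss = F.cross_entropy(model(inputs), targets)
loss.backward()
optimizer.step()

# Step 3: a hypothetical evaluation metric computed from the attributions.
metric = attrs.abs().sum(dim=1).mean().item()
```

With a deterministic model like this one, Step 1 has no way to influence Step 2; the rest of the thread is about what changes once the model contains randomness.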

So far, I've tried: (a) using detach() and torch.no_grad() in Step 1, (b) performing Step 1 on a deepcopy of the model, and (c) removing Step 3. However, the train loss in Step 2 is somehow still affected by the attribution computation in Step 1.
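One way to see why attempts (a) and (b) may not help when the model contains dropout: neither `torch.no_grad()` nor a `copy.deepcopy` of the model changes where dropout gets its randomness. In train mode, the copy's dropout still samples its mask from PyTorch's global generator, the same one Step 2's dropout uses. A minimal sketch on a hypothetical toy model:

```python
import copy

import torch
from torch import nn

# Hypothetical toy model with dropout, standing in for the real one.
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
model.train()
x = torch.ones(8, 16)


def step2_output(run_step1):
    torch.manual_seed(0)  # identical RNG state at the start of every run
    if run_step1:
        shadow = copy.deepcopy(model)  # attempt (b): a separate copy
        with torch.no_grad():          # attempt (a): no graph is recorded
            shadow(x)                  # dropout still samples a mask
    with torch.no_grad():
        return model(x)                # the "Step 2" forward pass


without_step1 = step2_output(run_step1=False)
with_step1 = step2_output(run_step1=True)
print(torch.equal(without_step1, with_step1))  # False
```

Running the same seed without Step 1 reproduces the same output, so the difference comes only from the extra draws made by the copy's dropout.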

I'd appreciate any advice on how to resolve this. Thanks!

aarzchan, Oct 19 '21 23:10

@aarzchan, do you have a toy model demonstrating the use case?

IG computes gradients w.r.t. the input tensors, not the weight matrices, and since only_inputs=True, according to the autograd documentation the gradients will not be accumulated into other leaf nodes. In Step 2, do you see that the gradients from Step 1 are still there? You could use model.zero_grad(), but I'm not sure I understand the problem correctly. If you show us an example on a toy model, that will help us understand the problem better.
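A quick check of the point about input-only gradients, on a hypothetical single-layer model: `torch.autograd.grad` returns gradients for the tensors it is asked about and does not accumulate anything into the parameters' `.grad` fields. (In recent PyTorch releases the `only_inputs` argument is documented as deprecated and ignored, with this input-only behavior always on.)

```python
import torch
from torch import nn

# Hypothetical toy model: a single linear layer.
model = nn.Linear(3, 2)
x = torch.randn(4, 3, requires_grad=True)

out = model(x)[:, 0].sum()
# Gradient w.r.t. the input only, as gradient-based attribution methods need.
(input_grad,) = torch.autograd.grad(out, x)

# Parameter gradients were never accumulated.
print([p.grad for p in model.parameters()])  # [None, None]
```

Here each row of `input_grad` equals the first row of `model.weight`, since the summed output for class 0 is linear in each input row.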

NarineK, Oct 20 '21 05:10

Hi @NarineK, thanks for the reply!

Below is some toy code illustrating the situation:

```python
import torch
from torch import nn
import torch.nn.functional as F
from captum.attr import IntegratedGradients, InputXGradient, GradientShap

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.arch = ...  # architecture elided

    def forward(self, input):
        return self.arch(input)

def run_step(input, target):
    logits = model(input)
    loss = F.cross_entropy(logits, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # update the parameters (training only)
    return loss

model = Net()
optimizer = torch.optim.Adam(model.parameters())
attr_func = IntegratedGradients(model)

# Train or eval loop
for (input, target) in batches:
    # Step 1: Compute attributions. Pass target by keyword: the second
    # positional parameter of IntegratedGradients.attribute is `baselines`.
    attrs = attr_func.attribute(input, target=target).detach()

    # Step 2: Compute task loss
    loss = run_step(input, target)
```

If we are training the model, I find that the loss differs between IntegratedGradients, InputXGradient, GradientShap, and skipping Step 1. However, if we are evaluating the model, the loss is the same for IntegratedGradients, InputXGradient, and skipping Step 1, but not for GradientShap. Ideally, for both training and evaluation, the loss would be the same for any attribution algorithm or when skipping Step 1, since the model should not be updated w.r.t. the attribution computation.

Also, I tried optimizer.zero_grad() and model.zero_grad() before and/or after Step 1, but it didn't fix the issue. After zeroing, I can see that the gradients are indeed all zero before Step 2, yet the train loss is still affected.

aarzchan, Oct 20 '21 20:10

Thank you for the example, @aarzchan! Do you have Dropout or BatchNorm in the model? Dropout, for instance, randomly chooses which neurons to drop when the model runs in train mode, which could explain the different losses. In eval mode this randomization is turned off, which is why you don't see the same effect there. These are just some thoughts; I don't know whether you are using Dropout or batch norm.
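A small sketch of the train/eval difference on a hypothetical toy model: in train mode each forward pass samples a fresh dropout mask from PyTorch's global generator, while in eval mode dropout is the identity and the generator state is not advanced at all.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical toy model with dropout.
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
x = torch.ones(8, 16)

with torch.no_grad():
    model.train()
    a, b = model(x), model(x)  # each call samples a fresh dropout mask

    model.eval()
    state = torch.get_rng_state()
    c, d = model(x), model(x)  # dropout is the identity in eval mode
    rng_untouched = torch.equal(state, torch.get_rng_state())

print(torch.equal(a, b), torch.equal(c, d), rng_untouched)  # False True True
```

This is why extra train-mode forward passes in Step 1 can shift the masks that Step 2 later draws, while eval-mode passes cannot.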

NarineK, Oct 21 '21 01:10

@NarineK

That's a good point! My model is using dropout, so let me check what happens when I turn off dropout in Step 1.
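A minimal sketch of turning dropout off only for Step 1, assuming a hypothetical toy model with both dropout and batch norm: Step 1 switches the model to eval mode and restores the previous mode afterwards. The attribution is a simple InputXGradient-style product computed with `torch.autograd.grad`, standing in for the Captum call. In eval mode dropout draws no mask and batch norm reads, without updating, its running statistics, so Step 1 leaves both the global RNG state and the running statistics untouched.

```python
import torch
from torch import nn

# Hypothetical toy model with both batch norm and dropout.
model = nn.Sequential(
    nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 3)
)
x = torch.randn(16, 4)


def attribution_step(model, inputs):
    """Hypothetical Step 1: an InputXGradient-style attribution in eval mode."""
    was_training = model.training
    model.eval()  # dropout off; batch norm reads but does not update running stats
    try:
        inputs = inputs.clone().requires_grad_(True)
        score = model(inputs)[:, 0].sum()
        (grad,) = torch.autograd.grad(score, inputs)
        return (inputs * grad).detach()
    finally:
        model.train(was_training)  # restore the previous mode for Step 2


model.train()
bn = model[1]
stats_before = bn.running_mean.clone()
rng_before = torch.get_rng_state()

attrs = attribution_step(model, x)

print(torch.equal(stats_before, bn.running_mean))      # True
print(torch.equal(rng_before, torch.get_rng_state()))  # True
print(model.training)                                  # True
```

In this toy setup, Step 2 then runs in train mode from exactly the state it would have had if Step 1 had been skipped.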

However, in eval mode, why does GradientShap behave differently from the other attribution algorithms?

aarzchan, Oct 21 '21 20:10

GradientShap uses randomization: it selects baselines randomly and also randomly selects points between the input and the baseline. That's the only big difference compared to IG. GradientShap's attributions won't be deterministic unless the seeds are fixed, but that shouldn't affect the loss. If there is no randomization during eval, I wonder whether the model's forward pass still depends on some seed value.

https://github.com/pytorch/captum/blob/master/captum/attr/_core/gradient_shap.py#L412
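The linked file is Captum's own sampling code. As a hedged illustration of the general mechanism (not of Captum's actual implementation), the sketch below assumes that Step 1's random sampling draws from PyTorch's global CPU generator, the same one dropout uses. Under that assumption the sampling never touches the loss directly, but it advances the generator, so Step 2's dropout masks change. Wrapping Step 1 in `torch.random.fork_rng` restores the generator state on exit. `sample_like_gradient_shap` is a hypothetical helper; note that `fork_rng` covers only PyTorch's generators, so randomness drawn through NumPy or Python's `random` module would need separate handling.

```python
import torch
from torch import nn

# Hypothetical toy model with dropout, left in train mode for "Step 2".
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
model.train()
x = torch.ones(8, 16)
baselines = torch.zeros(4, 16)


def sample_like_gradient_shap(inputs, baselines, n_samples=5):
    """Hypothetical stand-in for GradientShap-style sampling (not Captum's code):
    random baselines plus random points between each input and its baseline."""
    n = n_samples * inputs.shape[0]
    idx = torch.randint(0, baselines.shape[0], (n,))  # draws from global RNG
    coef = torch.rand(n, 1)                           # draws from global RNG
    expanded = inputs.repeat_interleave(n_samples, dim=0)
    chosen = baselines[idx]
    return chosen + coef * (expanded - chosen)


def step2_output(step1):
    """Seed, run an optional "Step 1", then the "Step 2" forward pass."""
    torch.manual_seed(0)
    if step1 == "plain":
        sample_like_gradient_shap(x, baselines)
    elif step1 == "isolated":
        # fork_rng saves the CPU RNG state and restores it on exit.
        with torch.random.fork_rng(devices=[]):
            sample_like_gradient_shap(x, baselines)
    with torch.no_grad():
        return model(x)


reference = step2_output("skip")
print(torch.equal(step2_output("plain"), reference))     # False
print(torch.equal(step2_output("isolated"), reference))  # True
```

Forking the generator keeps the attribution's randomness intact while making Step 2 independent of how many random numbers Step 1 consumed.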

NarineK, Oct 22 '21 18:10

Hi @NarineK, apologies for the delayed follow-up, and thanks for your feedback!

Case 1: When turning off dropout for both Steps 1 and 2, I'm able to get the same train/dev loss for IntegratedGradients, InputXGradient, GradientShap, and skipping Step 1.

Case 2: Meanwhile, if I only turn off dropout for Step 2, then the train/dev loss for GradientShap will be different. As you mentioned, it seems the GradientShap randomization will affect the dropout in Step 2.

Case 3: Interestingly, if I set the random seed at the beginning of Step 2, then the train/dev loss will be the same for IntegratedGradients, InputXGradient, and GradientShap, but not for skipping Step 1.

Since it's not really necessary for me to use GradientShap, I think I'll just skip GradientShap for now and go with the result of Case 2. I'll spend some time later figuring out the issue with Case 3.

aarzchan, Oct 27 '21 21:10