Running backward on a loss function that contains an explanation

Open ahmadajal opened this issue 4 years ago • 7 comments

Hello,

I am trying to run a backward pass on an objective function that contains an explanation. The goal is to run a manipulation attack on explanations. However, the explanation returned by Captum's attribute method does not require gradients, even though the input for which the explanation is computed does. For example, in the following code snippet:

sm = Saliency(vgg_model)
expl = sm.attribute(x_adv, target=17)

The expl tensor is a leaf tensor and doesn't require gradient.

There is also a separate issue with LRP in this setting. When computing the backward pass, I encounter this error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4096, 1000]], which is output 0 of TBackward, is at version 7; expected version 6 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I would really appreciate it if you could help me with these issues.

Best, -Ahmad

ahmadajal avatar Apr 14 '21 10:04 ahmadajal

Hi @ahmadajal, here is a workaround to obtain explanations that require gradients. We essentially need to override the default gradient function so that it passes the additional parameter create_graph=True, which enables higher-order derivatives.

from captum._utils.common import _run_forward
from typing import Any, Callable, Union, Tuple
from torch import Tensor
import torch

# This is the same as the default compute_gradients
# function in captum._utils.gradient, except
# setting create_graph=True when calling
# torch.autograd.grad
def compute_gradients(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind = None,
    additional_forward_args: Any = None,
) -> Tuple[Tensor, ...]:
    with torch.autograd.set_grad_enabled(True):
        # runs forward pass
        outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
        assert outputs[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )
        grads = torch.autograd.grad(torch.unbind(outputs), inputs, create_graph=True)
    return grads

from captum.attr import Saliency
sal = Saliency(model)
sal.gradient_func = compute_gradients
attr = sal.attribute(inp, target=1)

We will look into approaches to expose this option more easily, but this approach should work in the meantime, with attr requiring gradients here.
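To illustrate why create_graph=True matters, here is a minimal sketch in plain PyTorch, independent of Captum. The toy function and tensor values are assumptions chosen only for illustration; the point is that a gradient returned by torch.autograd.grad can be backpropagated through only when it was computed with create_graph=True.

```python
import torch

# Toy stand-in for a model output: a scalar function of the input.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x ** 3).sum()

# Default call: the returned gradient is detached from the graph.
(g_plain,) = torch.autograd.grad(y, x, retain_graph=True)

# With create_graph=True the gradient is itself differentiable.
(g_graph,) = torch.autograd.grad(y, x, create_graph=True)

print(g_plain.requires_grad)  # False
print(g_graph.requires_grad)  # True

# A loss defined on the gradient can now be backpropagated to x.
# Here g_graph = 3 * x**2, so loss = sum(9 * x**4) and d(loss)/dx = 36 * x**3.
loss = g_graph.pow(2).sum()
loss.backward()
print(x.grad)
```

The same reasoning applies to a saliency map: it is a gradient with respect to the input, so a loss built on it needs the second-order graph.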

We will also look into the LRP issue further. Would you be able to provide an example that reproduces it? It seems this may be related to in-place operations. Could you also try replacing any in-place operations in your model (e.g., setting inplace to False on ReLUs, if applicable) and see whether that resolves the issue?
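One way to try the suggestion above is to walk the model's modules and switch off the inplace flag on every ReLU. This is only a sketch: the helper name disable_inplace_relu and the toy network are hypothetical, and other layers with an inplace option (for example Dropout) would need the same treatment.

```python
import torch.nn as nn

def disable_inplace_relu(model: nn.Module) -> int:
    """Set inplace=False on every nn.ReLU in the model; return how many were changed."""
    changed = 0
    for module in model.modules():
        if isinstance(module, nn.ReLU) and module.inplace:
            module.inplace = False
            changed += 1
    return changed

# Toy network standing in for the real model.
toy = nn.Sequential(
    nn.Linear(4, 4),
    nn.ReLU(inplace=True),
    nn.Linear(4, 2),
    nn.ReLU(inplace=True),
)
changed = disable_inplace_relu(toy)
print(changed)
```

This only removes in-place activations inside the model itself; it does not touch in-place operations performed elsewhere, such as inside hooks.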

vivekmig avatar Apr 15 '21 16:04 vivekmig

Dear @vivekmig,

Thanks a lot for your answer. Indeed the workaround you suggested worked and now the attributions require gradients.

Regarding the LRP issue, here's a small example to reproduce the error. As you said, I also tried to set inplace to False on ReLU activations of the model but that didn't solve the error.

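# reproduce the error of LRP
# x: input image, x_target: target image to generate target expl map (both from ImageNet)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from captum.attr import LRP

# model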
vgg_model = torchvision.models.vgg16(pretrained=True)
######
x_index = vgg_model(x).argmax()
x_target_index = vgg_model(x_target).argmax()
lrp = LRP(vgg_model)
org_expl = lrp.attribute(x, target=x_index)
target_expl = lrp.attribute(x_target, target=x_target_index)
###############
delta = torch.zeros_like(x, requires_grad=True)  # additive noise
nn.init.normal_(delta, mean=0.0, std=1e-6)
x_adv = x + delta
optimizer = torch.optim.Adam([delta], lr=0.001)

for i in range(20):
    optimizer.zero_grad()
    adv_expl = lrp.attribute(x_adv, target=x_index)
    expl_loss = F.mse_loss(adv_expl, target_expl)
    delta_loss = torch.norm(delta)
    
    loss = 10*expl_loss + 0.01*delta_loss
    print("total: {}, expl: {}, delta: {}".format(loss, expl_loss, delta_loss))
    loss.backward()
    optimizer.step()
    x_adv = x+delta

The error happens on the first backward pass and the message is:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4096, 1000]], which is output 0 of TBackward, is at version 18; expected version 17 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
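The hint in the message can be followed by wrapping the failing code in torch.autograd.set_detect_anomaly(True), which makes PyTorch print the forward-pass traceback of the operation whose backward failed. The sketch below triggers the same class of error on a toy tensor; it does not reproduce the LRP case, and the function name inplace_error_message is hypothetical.

```python
import torch

def inplace_error_message() -> str:
    """Trigger an in-place autograd error under anomaly detection and return its message."""
    w = torch.ones(3, requires_grad=True)
    with torch.autograd.set_detect_anomaly(True):
        h = w.exp()   # exp saves its output for the backward pass
        h.add_(1.0)   # in-place edit bumps the saved tensor's version counter
        try:
            h.sum().backward()
        except RuntimeError as err:
            return str(err)
    return ""

msg = inplace_error_message()
print(msg)
```

In the reproduction above, the equivalent would be to place the loop body (or just loss.backward()) inside the same context manager and read the extra traceback to see which forward operation produced the modified tensor.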

I hope the example helps you to reproduce the error. Thanks for your time.

-Ahmad

ahmadajal avatar Apr 16 '21 16:04 ahmadajal

Hello again,

Regarding this workaround: even adding the parameter create_graph=True in the compute_gradients function does not solve the issue for smoothed explanations, e.g., SmoothGrad. I think this might be because of line 337 in the attribute function of the NoiseTunnel class, where there is a with torch.no_grad(): block in which the mean of the attributions is computed (please correct me if I am wrong). Therefore, I think replacing that line with, e.g., with torch.enable_grad(): should solve the issue and give explanations that require gradients. Does this make sense to you?
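The effect described here can be seen in plain PyTorch without relying on Captum internals. The sketch below is a hand-written SmoothGrad-style average, not the NoiseTunnel implementation; the function smooth_grad, the toy forward function, and the sample count and noise level are all assumptions for illustration. It shows that averaging differentiable gradients under torch.no_grad() yields a tensor that no longer requires gradients, while averaging with gradients enabled keeps the graph.

```python
import torch

def smooth_grad(forward_fn, x, n_samples=8, stdev=0.1):
    """Return per-sample input gradients on noisy copies of x, keeping the graph."""
    grads = []
    for _ in range(n_samples):
        noisy = x + stdev * torch.randn_like(x)
        out = forward_fn(noisy)  # must return a scalar
        (g,) = torch.autograd.grad(out, noisy, create_graph=True)
        grads.append(g)
    return torch.stack(grads)

def toy_forward(t):
    # Toy scalar "model output" standing in for a target logit.
    return (t ** 3).sum()

torch.manual_seed(0)
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
per_sample = smooth_grad(toy_forward, x)

# Averaging inside no_grad drops the graph.
with torch.no_grad():
    mean_detached = per_sample.mean(dim=0)

# Averaging with gradients enabled keeps it.
mean_attached = per_sample.mean(dim=0)

print(mean_detached.requires_grad)  # False
print(mean_attached.requires_grad)  # True

# The attached mean can be used inside a loss and backpropagated to x.
mean_attached.pow(2).sum().backward()
```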

Thanks, -Ahmad

ahmadajal avatar Apr 19 '21 10:04 ahmadajal

@ahmadajal, @vivekmig, is this still an open issue that needs to be addressed?

NarineK avatar Jul 09 '21 05:07 NarineK

Hi @NarineK ,

I couldn't figure out what the problem with LRP is when doing the backward pass. I do not understand which in-place operation is causing the error. I also set inplace=False for the ReLU activations of the network, but it didn't solve the issue.

ahmadajal avatar Jul 09 '21 15:07 ahmadajal

Do we have a solution now? I get exactly the same error when calling backward on a loss that contains LRP explanations.

tangli0305 avatar Sep 15 '22 16:09 tangli0305

I am also looking for a solution to the problem with LRP.

b-turan avatar Nov 30 '23 13:11 b-turan