Unclear error message for RuntimeError in LRP

Open · katjahauser opened this issue · 1 comment

Dear developers,

I encountered an error in LRP (and LayerLRP) that is caused by using the same layer (in this case a pooling layer) twice in the model. The error message is not very helpful for debugging, though: "Function ThnnConv2DBackward returned an invalid gradient at index 0 - got [1, 16, 15, 15] but expected shape compatible with [1, 32, 6, 6]". When using DeepLift on the same model, a more helpful error message is provided: "A Module MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) was detected that does not contain some of the input/output attributes that are required for DeepLift computations. This can occur, for example, if your module is being used more than once in the network.Please, ensure that module is being used only once in the network."

The error itself is, in both cases, fixed by introducing a second pooling layer (a sketch of this fix follows the example below). I would therefore kindly ask if you could adapt the LRP and LayerLRP error messages accordingly to make debugging easier.

Below is a minimal working example that reproduces both error messages.

Best wishes, Katja Hauser


import torch.nn as nn
import torch.nn.functional as F
import torch
from captum.attr import LRP, DeepLift


class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)  # a single pooling module, used twice in forward()
        self.ff = nn.Linear(32*6*6, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)  # second use of the same module triggers the error
        x = torch.flatten(x, 1)
        x = F.softmax(self.ff(x), dim=1)
        return x


if __name__ == "__main__":
    net = ImageClassifier()
    input = torch.randn(1, 1, 32, 32)
    try:
        lrp = LRP(net)
        attribution = lrp.attribute(input, target=5)
    except RuntimeError as e:
        print("LRP: ", e)
    try:
        dl = DeepLift(net)
        attribution = dl.attribute(input, target=5)
    except RuntimeError as e:
        print("DeepLift: ", e)

katjahauser · Feb 10 '22

Thanks for reporting this issue @katjahauser! We have added a similar warning to LRP in #911.

vivekmig · Mar 24 '22