Implementing captum with pytorch-lightning
❓ Questions and Help
Hi there, I am a new user to captum, and I am trying to use LayerGradCam to interpret a particular layer in my model.
Part of the problem/complication seems to be that my model and forward method are defined in a pytorch-lightning module.
My pytorch-lightning module is:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from captum.attr import LayerGradCam

# cnnBlock1, cnnBlock2, cnnBlock3 and linearBlock are custom modules defined elsewhere.
class model(pl.LightningModule):
    def __init__(self, learning_rate: float):
        super().__init__()
        self.learning_rate = learning_rate
        self.criterion = nn.BCEWithLogitsLoss()
        self.cam = LayerGradCam(self.forward, 'model.5')
        self.model = nn.Sequential(cnnBlock1(), cnnBlock2(), cnnBlock3(), linearBlock())

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        train_loss = self.criterion(y_hat, y)
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        val_loss = self.criterion(y_hat, y)
        self.log('val_loss', val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        attr = self.cam.attribute(x)
        return attr

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
However, when I run the test step I am getting the error:
AttributeError: 'str' object has no attribute 'register_forward_hook'
I have two questions:
- What does this error mean, and how do I fix it?
- What is the best practice for implementing captum with pytorch-lightning?
Thanks for your help!
I have also cross-posted this question on the PyTorch forums: https://discuss.pytorch.org/t/implementing-captum-with-pytorch-lightning/129292
Hi @ik362, sorry for the late reply.
I believe pytorch-lightning is not the problem here. Captum will work as long as your model exposes a forward-like callable to pass in.
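For example, since a LightningModule is just a regular nn.Module, you can hand it to Captum outside the Trainer loop entirely. A minimal sketch, assuming a trained instance of your model class above, a made-up learning rate and input shape (the layer argument here already uses the fix explained below):

    import torch
    from captum.attr import LayerGradCam

    net = model(learning_rate=1e-3)  # hypothetical learning rate, just for the sketch
    net.eval()

    # Wrap the model's forward like any plain PyTorch model; the layer is a module object.
    cam = LayerGradCam(net.forward, net.model[2])

    x = torch.randn(4, 3, 64, 64)  # dummy batch; your input shape will differ
    attr = cam.attribute(x)        # pass target=<class index> if the model has several outputs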
The issue is caused by the 2nd argument in the following line:

    self.cam = LayerGradCam(self.forward, 'model.5')
What is the string 'model.5'? Is it the name of a module? Is it defined in one of your blocks, e.g., linearBlock?
In any case, the 2nd argument, layer, should be the module itself, not its name. You can refer to our documentation for details: https://github.com/pytorch/captum/blob/4faf1ea49fbff90af92b759c1f763dda1d8be705/captum/attr/_core/layer/grad_cam.py#L64-L67
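Concretely, here is a hedged sketch of the fix inside __init__ (the index 2 is only illustrative; pick whichever submodule you actually want to inspect, and note this line must come after self.model is assigned):

    # Pass the module object itself, not its dotted name.
    self.cam = LayerGradCam(self.forward, self.model[2])

    # If all you have is a dotted name, resolve it to the module first;
    # note that 'model.5' only exists if self.model has at least six children.
    layer = dict(self.named_modules())['model.2']  # or self.get_submodule('model.2') on newer PyTorch
    self.cam = LayerGradCam(self.forward, layer)

Either way, Captum receives a real nn.Module it can call register_forward_hook on, which is exactly what the AttributeError is complaining about: the string 'model.5' has no such method.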