How to compute the second-order Hessian for a custom module?
Hi, I want to compute the second-order Hessian matrix for my custom nn.Module, which replaces nn.Conv2d with a CP decomposition, and I get this warning:
/data/miniforge3/envs/fairmae/lib/python3.9/site-packages/backpack/custom_module/graph_utils.py:86: UserWarning: Encountered node that may break second-order extensions: op=get_attr, target=V.1. If you encounter this problem, please open an issue at https://github.com/f-dangel/backpack/issues.
The architecture of my module is defined below:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CPConv2d(nn.Module):
    """Conv2d whose weight is assembled from a rank-R CP decomposition."""

    def __init__(self, in_channels, out_channels, kernel_size, rank,
                 stride=1, padding=0, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (
            (kernel_size, kernel_size) if isinstance(kernel_size, int)
            else kernel_size
        )
        self.rank = rank
        self.stride = stride
        self.padding = padding
        self.bias = bias
        # U: (rank,) holds the per-component scales.
        # V: factors of shape (out_channels, rank), (in_channels, rank),
        #    (kernel_height, rank), (kernel_width, rank).
        self.U = nn.Parameter(torch.randn(rank))
        self.V = nn.ParameterList([
            nn.Parameter(torch.randn(out_channels, rank)),
            nn.Parameter(torch.randn(in_channels, rank)),
            # Index the normalized tuple, not the raw argument, so an
            # int kernel_size also works.
            nn.Parameter(torch.randn(self.kernel_size[0], rank)),
            nn.Parameter(torch.randn(self.kernel_size[1], rank)),
        ])
        if bias:
            self.b = nn.Parameter(torch.randn(out_channels))
        else:
            self.register_parameter('b', None)

    def forward(self, x):
        # Reassemble the dense kernel from the CP factors, then convolve.
        W = torch.einsum('r,or,ir,kr,lr->oikl',
                         self.U, self.V[0], self.V[1], self.V[2], self.V[3])
        return F.conv2d(x, W, self.b, self.stride, self.padding)
```
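For reference, the einsum in `forward` builds the dense kernel as a sum of rank-one terms, `W[o,i,k,l] = sum_r U[r] * V0[o,r] * V1[i,r] * V2[k,r] * V3[l,r]`. A minimal standalone sanity check of that contraction (the sizes below are made up for illustration):

```python
import torch

rank, out_c, in_c, kh, kw = 4, 8, 3, 3, 3
U = torch.randn(rank)
V = [torch.randn(out_c, rank), torch.randn(in_c, rank),
     torch.randn(kh, rank), torch.randn(kw, rank)]

# Contraction used in CPConv2d.forward.
W = torch.einsum('r,or,ir,kr,lr->oikl', U, *V)
assert W.shape == (out_c, in_c, kh, kw)

# Explicit sum over rank-one components for comparison.
W_ref = torch.zeros(out_c, in_c, kh, kw)
for r in range(rank):
    W_ref += (U[r]
              * V[0][:, r][:, None, None, None]
              * V[1][:, r][None, :, None, None]
              * V[2][:, r][None, None, :, None]
              * V[3][:, r][None, None, None, :])

assert torch.allclose(W, W_ref, atol=1e-5)
```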
Could you tell me how to compute diag_h for self.U? Thank you!
Hi,
thanks for the clear explanation. We have a tutorial in the docs which explains how to implement second-order extensions for new layers in BackPACK (see here). The tutorial explains how to support the GGN diagonal, which is slightly easier to implement than the Hessian diagonal.
Could you try following the tutorial and implement support for the GGN diagonal first? I can then help you to generalize it to the Hessian diagonal.
Best, Felix
Thanks for your reply! I will try it :)