
How to compute second-order hessian for custom module?

Open XuZikang opened this issue 1 year ago • 2 comments

Hi, I want to compute the Hessian for my custom nn.Module, which replaces nn.Conv2d with a CP-decomposed convolution, but I get this warning:

/data/miniforge3/envs/fairmae/lib/python3.9/site-packages/backpack/custom_module/graph_utils.py:86: UserWarning: Encountered node that may break second-order extensions: op=get_attr, target=V.1. If you encounter this problem, please open an issue at https://github.com/f-dangel/backpack/issues.

The architecture of my module is defined below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CPConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, rank, stride=1, padding=0, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size) if isinstance(
            kernel_size, int) else kernel_size
        self.rank = rank
        self.stride = stride
        self.padding = padding
        self.bias = bias

        # U: (rank,)
        # V: [(out_channels, rank), (in_channels, rank), (kernel_size[0], rank), (kernel_size[1], rank)]

        self.U = nn.Parameter(torch.randn(rank))
        self.V = nn.ParameterList([
            nn.Parameter(torch.randn(out_channels, rank)),
            nn.Parameter(torch.randn(in_channels, rank)),
            # Use the normalized tuple so that an int kernel_size also works.
            nn.Parameter(torch.randn(self.kernel_size[0], rank)),
            nn.Parameter(torch.randn(self.kernel_size[1], rank))
        ])

        if bias:
            self.b = nn.Parameter(torch.randn(out_channels))
        else:
            self.register_parameter('b', None)

    def forward(self, x):
        # Rebuild the full (out, in, kh, kw) kernel from the CP factors on every call.
        W = torch.einsum('r,or,ir,kr,lr->oikl',
                         self.U, self.V[0], self.V[1], self.V[2], self.V[3])

        return F.conv2d(x, W, self.b, self.stride, self.padding)
```
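The warning's wording (op=get_attr, target=V.1) matches torch.fx node terminology, which suggests the model was traced with torch.fx. In such a trace, parameters that forward reads directly (instead of through a built-in layer such as nn.Conv2d) become get_attr nodes, and entries of an nn.ParameterList get dotted targets like V.1. The sketch below shows this on a hypothetical, stripped-down TinyCP module (not the module above); it only inspects the traced graph and does not call BackPACK.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class TinyCP(nn.Module):
    """Hypothetical stand-in for CPConv2d: composes a weight from factors in forward."""

    def __init__(self, rank=3, d_out=4, d_in=5):
        super().__init__()
        self.U = nn.Parameter(torch.randn(rank))
        self.V = nn.ParameterList([
            nn.Parameter(torch.randn(d_out, rank)),
            nn.Parameter(torch.randn(d_in, rank)),
        ])

    def forward(self, x):
        # The parameters are consumed by einsum, not by a built-in layer like nn.Linear.
        W = torch.einsum("r,or,ir->oi", self.U, self.V[0], self.V[1])
        return x @ W.t()


traced = fx.symbolic_trace(TinyCP())
# Each parameter read in forward becomes a graph node with op == "get_attr".
get_attr_targets = [node.target for node in traced.graph.nodes if node.op == "get_attr"]
print(get_attr_targets)
```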

Could you tell me how to compute diag_h (the Hessian diagonal) for self.U? Thank you!
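Before writing any extension, it can help to have a ground-truth value to test against. The sketch below computes the Hessian diagonal with respect to U by brute force with torch.autograd.functional.hessian. It assumes an MSE loss, random float64 data, and a functional rewrite of the forward pass, cp_conv2d, which is a hypothetical helper and not part of BackPACK. Since it materializes the full rank × rank Hessian, it is only practical for small ranks.

```python
import torch
import torch.nn.functional as F


def cp_conv2d(x, U, V, b, stride=1, padding=0):
    # Compose the conv kernel from the CP factors, as in CPConv2d.forward.
    W = torch.einsum("r,or,ir,kr,lr->oikl", U, *V)
    return F.conv2d(x, W, b, stride, padding)


def diag_hessian_wrt_U(x, y, U, V, b):
    """Reference Hessian diagonal of an MSE loss w.r.t. U via autograd (small rank only)."""

    def loss_fn(u):
        return F.mse_loss(cp_conv2d(x, u, V, b), y)

    H = torch.autograd.functional.hessian(loss_fn, U)  # shape (rank, rank)
    return torch.diagonal(H)


torch.manual_seed(0)
rank, c_in, c_out, k = 3, 2, 4, 3
U = torch.randn(rank, dtype=torch.float64)
V = [torch.randn(n, rank, dtype=torch.float64) for n in (c_out, c_in, k, k)]
b = torch.randn(c_out, dtype=torch.float64)
x = torch.randn(5, c_in, 8, 8, dtype=torch.float64)
y = torch.randn(5, c_out, 6, 6, dtype=torch.float64)  # 8x8 input, 3x3 kernel -> 6x6
print(diag_hessian_wrt_U(x, y, U, V, b))
```

A BackPACK DiagHessian result for U on the same data and loss should match these numbers up to floating-point error.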

XuZikang, Nov 18 '24 08:11

Hi,

thanks for the clear explanation. The docs contain a tutorial that explains how to add support for new layers to BackPACK's second-order extensions (see here). It covers the GGN diagonal, which is slightly easier to implement than the Hessian diagonal.
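Roughly speaking, the GGN diagonal only needs first derivatives (Jacobians) of each layer's output, whereas the Hessian diagonal also involves second derivatives of the outputs with respect to the parameters. The sketch below illustrates this on an assumed toy setup: a CP layer written as a hypothetical functional helper cp_conv2d, followed directly by a mean-reduced MSE loss, with random float64 data. It builds the GGN diagonal for U from the Jacobian alone and compares it with the autograd Hessian diagonal.

```python
import torch
import torch.nn.functional as F


def cp_conv2d(x, U, V, b):
    # Compose the conv kernel from the CP factors, then convolve.
    W = torch.einsum("r,or,ir,kr,lr->oikl", U, *V)
    return F.conv2d(x, W, b)


torch.manual_seed(0)
rank, c_in, c_out, k = 3, 2, 4, 3
U = torch.randn(rank, dtype=torch.float64)
V = [torch.randn(n, rank, dtype=torch.float64) for n in (c_out, c_in, k, k)]
b = torch.randn(c_out, dtype=torch.float64)
x = torch.randn(5, c_in, 8, 8, dtype=torch.float64)
y = torch.randn(5, c_out, 6, 6, dtype=torch.float64)

# GGN diagonal: only the Jacobian of the layer output w.r.t. U is needed.
J = torch.autograd.functional.jacobian(lambda u: cp_conv2d(x, u, V, b).flatten(), U)  # (M, rank)
M = J.shape[0]
# For mean-reduced MSE over M outputs, the loss Hessian w.r.t. the output is (2 / M) * I.
diag_ggn = (2.0 / M) * (J**2).sum(dim=0)

# Hessian diagonal via autograd, for comparison.
H = torch.autograd.functional.hessian(lambda u: F.mse_loss(cp_conv2d(x, u, V, b), y), U)
diag_h = torch.diagonal(H)
print(torch.allclose(diag_ggn, diag_h))
```

In this toy setup the two diagonals agree because every entry of U enters the layer output linearly, so the second-derivative term vanishes on the diagonal. Once nonlinear layers follow the CP layer, the Hessian diagonal generally picks up extra terms that the GGN drops.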

Could you try following the tutorial and implementing support for the GGN diagonal first? I can then help you generalize it to the Hessian diagonal.
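For the GGN diagonal of U, the per-parameter computation of such an extension boils down to: take a symmetric factorization of the GGN with respect to the layer output, apply the transposed Jacobian with respect to U to each column separately for each sample, and sum the squares. The sketch below does this in plain PyTorch with torch.nn.grad.conv2d_weight plus the CP chain rule. diag_ggn_U is a hypothetical helper, the [C, N, c_out, H_out, W_out] layout of the factor is an assumption about what such an extension receives, and the BackPACK registration code from the tutorial is omitted. For a mean-reduced MSE loss, the factor is sqrt(2/M) times the identity per sample, so the result can be checked against a brute-force autograd Hessian diagonal.

```python
import math

import torch
import torch.nn.functional as F
from torch.nn.grad import conv2d_weight


def diag_ggn_U(x, U, V, sqrt_ggn_out, stride=1, padding=0):
    """Hypothetical helper: GGN diagonal of U from a symmetric factor of the output GGN.

    sqrt_ggn_out has shape [C, N, c_out, H_out, W_out]; for sample n, the loss Hessian
    w.r.t. that sample's output is sum_c s_{c,n} s_{c,n}^T.
    """
    W_shape = (V[0].shape[0], V[1].shape[0], V[2].shape[0], V[3].shape[0])
    diag = torch.zeros_like(U)
    C, N = sqrt_ggn_out.shape[:2]
    for n in range(N):  # per sample: squares must not mix samples
        for c in range(C):
            # Transposed Jacobian of the conv output w.r.t. its kernel W, for one sample ...
            dW = conv2d_weight(x[n:n + 1], W_shape, sqrt_ggn_out[c, n:n + 1], stride, padding)
            # ... chained with dW/dU_r = outer(V0[:, r], V1[:, r], V2[:, r], V3[:, r]).
            jt_s = torch.einsum("oikl,or,ir,kr,lr->r", dW, *V)
            diag += jt_s**2
    return diag


torch.manual_seed(0)
rank, c_in, c_out, k, N = 3, 2, 4, 3, 5
U = torch.randn(rank, dtype=torch.float64)
V = [torch.randn(n, rank, dtype=torch.float64) for n in (c_out, c_in, k, k)]
b = torch.randn(c_out, dtype=torch.float64)
x = torch.randn(N, c_in, 8, 8, dtype=torch.float64)

# Symmetric factor for F.mse_loss(..., reduction="mean"): sqrt(2 / M) * I per sample.
out_shape = (c_out, 6, 6)
D = math.prod(out_shape)
M = N * D
eye = torch.eye(D, dtype=torch.float64).reshape(D, *out_shape)
S = math.sqrt(2.0 / M) * eye.unsqueeze(1).repeat(1, N, 1, 1, 1)  # [D, N, c_out, 6, 6]
print(diag_ggn_U(x, U, V, S))
```

The bias does not enter the Jacobian with respect to U, which is why diag_ggn_U does not take it.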

Best, Felix

f-dangel, Nov 18 '24 14:11

Thanks for your reply! I will try it :)

XuZikang, Nov 19 '24 02:11