Torch-Pruning icon indicating copy to clipboard operation
Torch-Pruning copied to clipboard

Could not prune GVP model

Open LHucass opened this issue 10 months ago • 0 comments

I have a model with GVP. The GVP model could be described below: `from gvp.models import LayerNorm, GVP, GVPConvLayer from torch import nn import torch import torch.nn.functional as F import torch_geometric

class GVPTransCond(torch.nn.Module): def __init__(self): super().__init__() self.node_in_dim = (8, 4) self.node_h_dim = (512, 128) self.edge_in_dim = (32, 1) self.edge_h_dim = (128, 1) self.num_layers = 4 drop_rate = 0 activations = (F.relu, None) self.W_v = torch.nn.Sequential( LayerNorm(self.node_in_dim), GVP(self.node_in_dim, self.node_h_dim, activations=(None, None), vector_gate=True) ) self.W_e = torch.nn.Sequential( LayerNorm(self.edge_in_dim), GVP(self.edge_in_dim, self.edge_h_dim, activations=(None, None), vector_gate=True) )

    # Encoder layers (supports multiple conformations)
    self.encoder_layers = nn.ModuleList(
        GVPConvLayer(self.node_h_dim, self.edge_h_dim,
                     activations=activations, vector_gate=True,
                     drop_rate=drop_rate)
        for _ in range(self.num_layers))

    # Output
    self.W_out = GVP(self.node_h_dim, (self.node_h_dim[0], 0), activations=(None, None))

def forward(self,zz, nv,es,ev,eee):
    c = torch.zeros_like(zz)
    iss = torch.cat([zz, c], -1)
    h_V = (iss, nv)
    h_E = (es, ev)
    h_V = self.W_v(h_V)
    h_E = self.W_e(h_E)
    for layer in self.encoder_layers:
        h_V = layer(h_V, eee, h_E)
    
    retn = self.W_out(h_V).reshape(1, 1024, -1)
    return retn

m = GVPTransCond() print(m(torch.randn(1024,4),torch.randn(1024,4,3),torch.randn(10240,32),torch.randn(10240,1,3),torch.ones(2,10240,dtype=torch.long)))

import torch_pruning as tp imp = tp.importance.RandomImportance() ignored_layers = [] batch_input = { 'zz':torch.randn(1024,4), 'nv':torch.randn(1024,4,3), 'es':torch.randn(10240,32), 'ev':torch.randn(10240,1,3), 'eee':torch.ones(2,10240,dtype=torch.long), }

pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required. m, batch_input, importance=imp, pruning_ratio=0.5, ignored_layers=None, round_to=8, ) base_macs, base_nparams = tp.utils.count_ops_and_params(m, batch_input) pruner.step() macs, nparams = tp.utils.count_ops_and_params(m, batch_input) print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")`

When I run the pruning, it shows these errors: `Warning! No positional inputs found for a module, assuming batch size is 1. torch.Size([1024, 448]) torch.Size([1024, 128, 3]) Traceback (most recent call last): File "test.py", line 76, in print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/utils/_contextlib.py", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch_pruning/utils/op_counter.py", line 33, in count_ops_and_params _ = flops_model(**example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1603, in _call_impl result = forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "test.py", line 47, in forward

File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/gvp/init.py", line 361, in forward dh = self.conv(x, edge_index, edge_attr) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/gvp/init.py", line 267, in forward message = self.propagate(edge_index, ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "gvp_GVPConv_propagate_vwj7whf1.py", line 248, in propagate out = self.message( ^^^^^^^^^^^^^ File "site-packages/gvp/init.py", line 276, in message message = self.message_func(message) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/container.py", line 219, in forward input = module(input) ^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/gvp/init.py", line 122, in forward s = self.ws(torch.cat([s, vn], -1)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, 
**kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1603, in _call_impl result = forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/linear.py", line 117, in forward return F.linear(input, self.weight, self.bias) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: mat1 and mat2 shapes cannot be multiplied (10240x1089 and 1217x448)`

Has anyone met this problem? When I use GroupNormalImportance it causes an index error: index 698 is out of bounds for dimension 0 with size 257

How can I solve this?

LHucass avatar Mar 23 '25 00:03 LHucass