Could not prune GVP model
I have a model built with GVP layers. The model is defined as follows: `from gvp.models import LayerNorm, GVP, GVPConvLayer from torch import nn import torch import torch.nn.functional as F import torch_geometric
class GVPTransCond(torch.nn.Module):
    """GVP-based graph encoder used to reproduce the pruning failure.

    Node and edge features are passed as ``(scalar, vector)`` tuples. Each
    ``*_dim`` attribute is a pair of channel counts; judging by the example
    inputs used with this model (``nv`` of shape ``(1024, 4, 3)`` for
    ``node_in_dim = (8, 4)``), the pair is presumably
    ``(n_scalar_channels, n_vector_channels)`` -- confirm against the gvp
    package documentation.
    """

    def __init__(self):
        super().__init__()
        self.node_in_dim = (8, 4)
        self.node_h_dim = (512, 128)
        self.edge_in_dim = (32, 1)
        self.edge_h_dim = (128, 1)
        self.num_layers = 4
        drop_rate = 0
        activations = (F.relu, None)

        # Input embeddings: normalize, then project to the hidden dimensions.
        self.W_v = torch.nn.Sequential(
            LayerNorm(self.node_in_dim),
            GVP(self.node_in_dim, self.node_h_dim,
                activations=(None, None), vector_gate=True)
        )
        self.W_e = torch.nn.Sequential(
            LayerNorm(self.edge_in_dim),
            GVP(self.edge_in_dim, self.edge_h_dim,
                activations=(None, None), vector_gate=True)
        )

        # Encoder layers (supports multiple conformations)
        self.encoder_layers = nn.ModuleList(
            GVPConvLayer(self.node_h_dim, self.edge_h_dim,
                         activations=activations, vector_gate=True,
                         drop_rate=drop_rate)
            for _ in range(self.num_layers))

        # Output: keep the scalar channels only (vector output width is 0).
        self.W_out = GVP(self.node_h_dim, (self.node_h_dim[0], 0),
                         activations=(None, None))

    def forward(self, zz, nv, es, ev, eee):
        """Encode the graph and return the reshaped node-level output.

        Args:
            zz: Scalar node features. They are concatenated with an equally
                shaped zero tensor along the last axis, so the width seen by
                ``W_v`` is twice the width of ``zz``.
            nv: Vector node features.
            es: Scalar edge features.
            ev: Vector edge features.
            eee: Edge index tensor, forwarded to every ``GVPConvLayer``.

        Returns:
            The output of ``W_out`` reshaped to ``(1, 1024, -1)``.
        """
        # Zero placeholder that doubles the scalar feature width of zz.
        c = torch.zeros_like(zz)
        iss = torch.cat([zz, c], -1)
        h_V = (iss, nv)
        h_E = (es, ev)
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.encoder_layers:
            h_V = layer(h_V, eee, h_E)
        # NOTE(review): the node count 1024 is hard-coded here, so the
        # reshape only succeeds for graphs with exactly 1024 nodes.
        retn = self.W_out(h_V).reshape(1, 1024, -1)
        return retn
m = GVPTransCond() print(m(torch.randn(1024,4),torch.randn(1024,4,3),torch.randn(10240,32),torch.randn(10240,1,3),torch.ones(2,10240,dtype=torch.long)))
import torch_pruning as tp imp = tp.importance.RandomImportance() ignored_layers = [] batch_input = { 'zz':torch.randn(1024,4), 'nv':torch.randn(1024,4,3), 'es':torch.randn(10240,32), 'ev':torch.randn(10240,1,3), 'eee':torch.ones(2,10240,dtype=torch.long), }
pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required. m, batch_input, importance=imp, pruning_ratio=0.5, ignored_layers=None, round_to=8, ) base_macs, base_nparams = tp.utils.count_ops_and_params(m, batch_input) pruner.step() macs, nparams = tp.utils.count_ops_and_params(m, batch_input) print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")`
When I run the pruning, I get the following errors:
`Warning! No positional inputs found for a module, assuming batch size is 1.
torch.Size([1024, 448]) torch.Size([1024, 128, 3])
Traceback (most recent call last):
File "test.py", line 76, in
File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/gvp/init.py", line 361, in forward dh = self.conv(x, edge_index, edge_attr) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/gvp/init.py", line 267, in forward message = self.propagate(edge_index, ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "gvp_GVPConv_propagate_vwj7whf1.py", line 248, in propagate out = self.message( ^^^^^^^^^^^^^ File "site-packages/gvp/init.py", line 276, in message message = self.message_func(message) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/container.py", line 219, in forward input = module(input) ^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/gvp/init.py", line 122, in forward s = self.ws(torch.cat([s, vn], -1)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, 
**kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/module.py", line 1603, in _call_impl result = forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "site-packages/torch/nn/modules/linear.py", line 117, in forward return F.linear(input, self.weight, self.bias) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: mat1 and mat2 shapes cannot be multiplied (10240x1089 and 1217x448)`
Has anyone encountered this problem? When I use GroupNormalImportance instead, it raises an index error:
index 698 is out of bounds for dimension 0 with size 257
How can I solve this?