Error while trying to prune Lenet-5 architecture with MNIST dataset
DG.build_dependency(model, example_inputs=th.randn(1,1,28,28))
  File "/usr/local/lib/python3.6/dist-packages/torch_pruning/dependency.py", line 309, in build_dependency
    self.update_index()
  File "/usr/local/lib/python3.6/dist-packages/torch_pruning/dependency.py", line 315, in update_index
    self._set_fc_index_transform( node )
  File "/usr/local/lib/python3.6/dist-packages/torch_pruning/dependency.py", line 437, in _set_fc_index_transform
    feature_channels = _get_in_node_out_channels(fc_node.inputs[0])
IndexError: list index out of range
Have you solved it? I ran into the same error.
I have the same problem on version 0.2.7. For now I just use torch_pruning==0.2.6.
I am facing the same issue with another model. I tried both 0.2.6 and 0.2.7. Does anyone know how to fix it?
Hi @MPiotr, does this problem still exist with the latest commit?
This script works for me:
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import torch
import torch.nn as nn
import torch_pruning as tp

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)  # flatten: 16 * 4 * 4 = 256 features
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y

model = Model()

# pruning according to L1 Norm
strategy = tp.strategy.L1Strategy()  # or tp.strategy.RandomStrategy()

# build layer dependency for LeNet-5
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1, 1, 28, 28))

# get a pruning plan according to the dependency graph; idxs holds the indices of the output channels to prune
pruning_idxs = [0, 2, 6]  # manually selected; or strategy(model.conv1.weight, amount=0.4)
pruning_plan = DG.get_pruning_plan(model.fc2, tp.prune_linear_out_channel, idxs=pruning_idxs)
print(pruning_plan)

# execute this plan (prune the model)
if DG.check_pruning_plan(pruning_plan):
    pruning_plan.exec()
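The `nn.Linear(256, 120)` input size in the script follows from the feature-map size after the two conv/pool stages. Below is a small sketch of that arithmetic. It assumes stride 1 and no padding for the convolutions and kernel 2 with stride 2 for the pooling, which are the PyTorch defaults used in the script. The `conv_out` helper is illustrative only and is not part of torch_pruning.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size of a Conv2d / MaxPool2d along one spatial dimension (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

size = 28
size = conv_out(size, 5)        # conv1: 28 -> 24
size = conv_out(size, 2, 2)     # pool1: 24 -> 12
size = conv_out(size, 5)        # conv2: 12 -> 8
size = conv_out(size, 2, 2)     # pool2: 8 -> 4
flat = 16 * size * size         # 16 channels * 4 * 4
print(size, flat)               # 4 256
```

If this number does not match the first `nn.Linear`, the forward pass fails before pruning even starts. Checking it rules out a shape mismatch as the cause of the traceback above.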
Hi @VainF, thank you for the reply. In the end it was my mistake: some of the model's parameters were frozen (requires_grad=False), and that broke the graph builder.
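Here is a minimal sketch of a workaround, assuming that the dependency graph in these torch_pruning versions is built by tracing the autograd graph of a forward pass. That is consistent with frozen parameters breaking it, but it is an assumption about the library internals. The idea is to make every parameter trainable while the graph is built and then restore the original flags. `all_params_trainable` is a hypothetical helper, not a torch_pruning API. The small `nn.Sequential` model is only a stand-in used to show the effect.

```python
import contextlib

import torch
import torch.nn as nn

@contextlib.contextmanager
def all_params_trainable(model):
    # Hypothetical helper: temporarily set requires_grad=True on every parameter,
    # then restore each parameter's original flag on exit.
    saved = [(p, p.requires_grad) for p in model.parameters()]
    try:
        for p, _ in saved:
            p.requires_grad_(True)
        yield model
    finally:
        for p, flag in saved:
            p.requires_grad_(flag)

# Stand-in model with every parameter frozen
model = nn.Sequential(nn.Conv2d(1, 6, 5), nn.ReLU(), nn.Flatten(), nn.Linear(6 * 24 * 24, 10))
for p in model.parameters():
    p.requires_grad_(False)

x = torch.randn(1, 1, 28, 28)

# No trainable parameters and an input without requires_grad: no autograd graph is recorded
frozen_out = model(x)
print(frozen_out.grad_fn)               # None

with all_params_trainable(model):
    traced_out = model(x)
    print(traced_out.grad_fn is not None)  # True
    # In the thread's script, DG.build_dependency(model, example_inputs=x) would go here

# Original (frozen) flags are restored afterwards
print(all(not p.requires_grad for p in model.parameters()))  # True
```

Restoring the flags afterwards keeps any freezing scheme used for fine-tuning intact. Only the graph-building step sees a fully trainable model.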