Torch-Pruning
Torch-Pruning copied to clipboard
Can't Save Pruned Model: " TypeError: can't pickle SigmoidBackward0 objects"
Example:
class MyMLP(torch.nn.Module):
def __init__(self, out_len):
super().__init__()
self.model = torch.nn.Sequential(torch.nn.Linear(1000,out_len), torch.nn.Sigmoid())
def forward(self, x):
return self.model(x)
class MyModel(torch.nn.Module):
def __init__(self, out_len):
super().__init__()
self.cnn = torchvision.models.quantization.resnet18()
self.mlp = MyMLP(2)
def forward(self, x):
return self.mlp(self.cnn(x))
def prune(self, amount):
self.DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 224, 224))
for m in model.modules():
if isinstance(m, torchvision.models.quantization.resnet.QuantizableBasicBlock):
self.prune_conv(m.conv1, amount)
self.prune_conv(m.conv2, amount)
def prune_conv(self, conv, amount):
strategy = tp.strategy.L1Strategy()
pruning_index = strategy(conv.weight, amount=amount)
plan = self.DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
plan.exec()
if __name__ == '__main__':
model = MyModel(2)
model(torch.rand((1,3,224,224)))
model.prune(0.2)
torch.save(model, 'model.pth')
Error:
TypeError: can't pickle SigmoidBackward0 objects
Any ideas on how to fix this? Thanks!