Torch-Pruning icon indicating copy to clipboard operation
Torch-Pruning copied to clipboard

Can't Save Pruned Model: " TypeError: can't pickle SigmoidBackward0 objects"

Open jonboypy opened this issue 3 years ago • 0 comments

Example:


class MyMLP(torch.nn.Module):

  def __init__(self, out_len):
    super().__init__()
    self.model = torch.nn.Sequential(torch.nn.Linear(1000,out_len), torch.nn.Sigmoid())

  def forward(self, x):
    return self.model(x)


class MyModel(torch.nn.Module):

  def __init__(self, out_len):
    super().__init__()
    self.cnn = torchvision.models.quantization.resnet18()
    self.mlp = MyMLP(2)

  def forward(self, x):
    return self.mlp(self.cnn(x))

  def prune(self, amount):
    self.DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 224, 224))
    for m in model.modules():
      if isinstance(m, torchvision.models.quantization.resnet.QuantizableBasicBlock):
        self.prune_conv(m.conv1, amount)
        self.prune_conv(m.conv2, amount)

  def prune_conv(self, conv, amount):
    strategy = tp.strategy.L1Strategy()
    pruning_index = strategy(conv.weight, amount=amount)
    plan = self.DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
    plan.exec()


if __name__ == '__main__':
    model = MyModel(2)
    model(torch.rand((1,3,224,224)))
    model.prune(0.2)
    torch.save(model, 'model.pth')

Error:

TypeError: can't pickle SigmoidBackward0 objects

Any ideas on how to fix this? Thanks!

jonboypy avatar May 22 '22 19:05 jonboypy