Torch-Pruning icon indicating copy to clipboard operation
Torch-Pruning copied to clipboard

ResNet20 Cifar10 weird pruning result!

Open ghimiredhikura opened this issue 3 years ago • 9 comments

Hi @VainF,

Thank you for the nice work. It is a clean and very helpful pruning framework, especially for unstructured pruning. However, I have trouble pruning the ResNet20 model with the Cifar10 dataset. The problem is even if I just want to prune Conv layers in the single basic block it has a chain effect of pruning the entire network. I am using Cifar10 version of the slim ResNet20 model which is different from the one in your example.

ResNet model to prune: resnet_cifar10.zip

Pruning Script (Pruning applied only in 9th ResNetBasicblock)

import torch
import torch_pruning as tp

from resnet_cifar10 import resnet20
import resnet_cifar10 as resnet

def prune_model(model):
    """Prune conv_a/conv_b of the 9th ResNetBasicblock by L1 filter magnitude.

    Builds a dependency graph for the CIFAR-10 input shape, then executes the
    coupled pruning plan (the library follows residual connections, so more
    layers than the targeted conv may be affected).
    """
    model.cpu()
    dep_graph = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2):
        # Select the lowest-L1-norm filters, then run the coupled plan.
        idxs = tp.strategy.L1Strategy()(conv.weight, amount=amount)
        dep_graph.get_pruning_plan(conv, tp.prune_conv, idxs).exec()

    basic_blocks = (m for m in model.modules()
                    if isinstance(m, resnet.ResNetBasicblock))
    for block_index, block in enumerate(basic_blocks):
        if block_index == 8:
            prune_conv(block.conv_a, 0.3)
            prune_conv(block.conv_b, 0.2)
    return model

def main():
    """Build ResNet-20, prune the 9th basic block, and print both models."""
    net = resnet20(num_classes=10)
    print("Before", net)
    prune_model(net)
    print("After", net)


if __name__ == '__main__':
    main()

Before Pruning:

CifarResNet(
  (conv_1_3x3): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (stage_1): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stage_2): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): DownsampleA(
        (avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
      )
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stage_3): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): DownsampleA(
        (avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
      )
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0)
  (classifier): Linear(in_features=64, out_features=10, bias=True)
)

After Pruning only 9th ResNetBasicblock

 CifarResNet(
  (conv_1_3x3): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (stage_1): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stage_2): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(6, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): DownsampleA(
        (avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
      )
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(12, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(12, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stage_3): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(12, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): DownsampleA(
        (avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
      )
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(24, 45, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(45, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(45, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0)
  (classifier): Linear(in_features=24, out_features=10, bias=True)
)

Issue: You can see that after pruning, filters in all blocks were removed like a chain effect. I am not sure if this is expected. Any help will be greatly appreciated.

Thanks.

ghimiredhikura avatar May 18 '22 04:05 ghimiredhikura

This is an expected outcome, and it is described in the documentation. The library automatically handles the pruning of residual connections. Because you pruned a layer just before the skip connection, the same channels need to be pruned from all layers coupled to that connection. To avoid this you can simply not prune conv_b. Pruning conv_a will still reduce the parameters and FLOPs of conv_b, since all of its input channels will be reduced.

codestar12 avatar May 20 '22 19:05 codestar12

Hi @codestar12,

Let's say we prune stage_3, 2nd ResNetBasicblock, conv_b, which is the last conv layer in the whole network. Doing so prunes every conv layer all the way up to the very first layer, conv_1_3x3, as well. This happens only with the network I posted above (CIFAR-10), but not with this. As you said, due to the skip connection, pruning all conv layers within the block is expected, but I am confused because it is pruning in other blocks as well.

ghimiredhikura avatar May 20 '22 20:05 ghimiredhikura

Unfortunately, this package has the problem you encountered. I already had a similar issue before. If you prune a neural network that has residual connections, this package excessively prunes the layers of the network: the deeper the layer, the more of its channels are pruned.

planemanner avatar Jun 21 '22 01:06 planemanner

Hi @VainF, I see you made some updates. By any chance did you check the issue I explained above? Thanks.

ghimiredhikura avatar Aug 10 '22 08:08 ghimiredhikura

Hello @ghimiredhikura, I think the result is desired because different blocks are residually connected. They must be pruned together to avoid mismatched "Element-wise Add" in residual connections. Could you print the pruning plan in your example as it can provide more information about the pruning chain.

VainF avatar Aug 10 '22 08:08 VainF

BTW, if you do not like this behaviour, please replace the AvgPool2d (#L11 of your resnet_cifar10.py) with a standard Conv2d. The chained pruning will stop at this conv layer.

VainF avatar Aug 10 '22 08:08 VainF

Hello @ghimiredhikura, I think the result is desired because different blocks are residually connected. They must be pruned together to avoid mismatched "Element-wise Add" in residual connections. Could you print the pruning plan in your example as it can provide more information about the pruning chain.

# Pruning Code

import torch

from models.resnet_cifar import resnet20
import models.resnet_cifar as resnet
import torch_pruning_tool_v1.torch_pruning as tp

def prune_model(model):
    """Prune conv_b of the 9th ResNetBasicblock and print the pruning plan.

    The printed plan shows every coupled layer the dependency graph touches,
    which is useful for inspecting the pruning chain across residual blocks.
    """
    model.cpu()
    dep_graph = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2):
        # Pick the lowest-L1-norm output channels and apply the coupled plan.
        idxs = tp.strategy.L1Strategy()(conv.weight, amount=amount)
        plan = dep_graph.get_pruning_plan(conv, tp.prune_conv_out_channel, idxs)
        print(plan)
        plan.exec()

    block_count = 0
    for module in model.modules():
        if not isinstance(module, resnet.ResNetBasicblock):
            continue
        if block_count == 8:
            # prune_conv(module.conv_a, 0.3)
            prune_conv(module.conv_b, 0.2)
        block_count += 1
    return model

def main():
    """Instantiate ResNet-20, prune it, and show the architecture before/after."""
    net = resnet20(num_classes=10)
    print("Before", net)
    prune_model(net)
    print("After", net)


if __name__ == '__main__':
    main()

# Original net 

CifarResNet : Depth : 20 , Layers for each block : 3
CifarResNet(
  (conv_1_3x3): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (stage_1): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stage_2): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): DownsampleA(
        (avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
      )
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stage_3): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): DownsampleA(
        (avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
      )
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0)
  (classifier): Linear(in_features=64, out_features=10, bias=True)
)

--------------------------------
          Pruning Plan
--------------------------------
User pruning:
[ [DEP] ConvOutChannelPruner on stage_3.2.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => ConvOutChannelPruner on stage_3.2.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]

Coupled pruning:
[ [DEP] ConvOutChannelPruner on stage_3.2.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => BatchnormPruner on stage_3.2.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 48}]
[ [DEP] BatchnormPruner on stage_3.2.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0) => ElementWiseOpPruner on _ElementWiseOp(ViewBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ViewBackward0) => LinearInChannelPruner on classifier (Linear(in_features=64, out_features=10, bias=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 240}]
[ [DEP] LinearInChannelPruner on classifier (Linear(in_features=64, out_features=10, bias=True)) => ElementWiseOpPruner on _ElementWiseOp(TBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_3.2.conv_a (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_3.1.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 48}]    
[ [DEP] BatchnormPruner on stage_3.1.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_3.1.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_3.1.conv_a (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ConcatPruner on _ConcatOp([0, 32, 64]), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_3.0.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 48}]    
[ [DEP] BatchnormPruner on stage_3.0.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_3.0.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]
[ [DEP] ConcatPruner on _ConcatOp([0, 32, 64]) => ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ConcatPruner on _ConcatOp([0, 32, 64]) => ElementWiseOpPruner on _ElementWiseOp(MulBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_3.0.conv_a (Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 6912}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_2.2.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 24}]
[ [DEP] BatchnormPruner on stage_2.2.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_2.2.conv_b (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_2.2.conv_a (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_2.1.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 24}]
[ [DEP] BatchnormPruner on stage_2.1.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_2.1.conv_b (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_2.1.conv_a (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ConcatPruner on _ConcatOp([0, 16, 32]), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_2.0.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 24}]
[ [DEP] BatchnormPruner on stage_2.0.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_2.0.conv_b (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ConcatPruner on _ConcatOp([0, 16, 32]) => ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ConcatPruner on _ConcatOp([0, 16, 32]) => ElementWiseOpPruner on _ElementWiseOp(MulBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_2.0.conv_a (Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 1728}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_1.2.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 12}]
[ [DEP] BatchnormPruner on stage_1.2.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_1.2.conv_b (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]      
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_1.2.conv_a (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_1.1.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 12}]
[ [DEP] BatchnormPruner on stage_1.1.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_1.1.conv_b (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]      
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_1.1.conv_a (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_1.0.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 12}]
[ [DEP] BatchnormPruner on stage_1.0.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_1.0.conv_b (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]      
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => BatchnormPruner on bn_1 (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 12}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_1.0.conv_a (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]
[ [DEP] BatchnormPruner on bn_1 (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on conv_1_3x3 (Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 162}]

Metric Sum: {'#params': 100890}
--------------------------------

# Pruned model (output network)
 CifarResNet(
  (conv_1_3x3): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (stage_1): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(10, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(10, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(10, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(16, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stage_2): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(10, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): DownsampleA(
        (avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
      )
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(20, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(20, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(32, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (stage_3): Sequential(
    (0): ResNetBasicblock(
      (conv_a): Conv2d(20, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): DownsampleA(
        (avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
      )
    )
    (1): ResNetBasicblock(
      (conv_a): Conv2d(40, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ResNetBasicblock(
      (conv_a): Conv2d(40, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b): Conv2d(64, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0)
  (classifier): Linear(in_features=40, out_features=10, bias=True)
)

ghimiredhikura avatar Aug 10 '22 09:08 ghimiredhikura

BTW, if you do not want this behaviour, please replace the AvgPool2d (line 11 of your resnet_cifar10.py) with a standard Conv2d. The chained pruning will then stop at that conv layer.

Thanks. I will try this one.

ghimiredhikura avatar Aug 10 '22 09:08 ghimiredhikura

BTW, if you do not want this behaviour, please replace the AvgPool2d (line 11 of your resnet_cifar10.py) with a standard Conv2d. The chained pruning will then stop at that conv layer.

Yes, using a standard Conv2d it works as expected. Thank you for the good work.

ghimiredhikura avatar Aug 19 '22 08:08 ghimiredhikura