ResNet20 Cifar10 weird pruning result!
Hi @VainF,
Thank you for the nice work. It is a clean and very helpful pruning framework, especially for unstructured pruning. However, I have trouble pruning the ResNet20 model with the Cifar10 dataset. The problem is that even if I only want to prune the Conv layers of a single basic block, the pruning propagates in a chain effect through the entire network. I am using the Cifar10 version of the slim ResNet20 model, which is different from the one in your example.
ResNet model to prune: resnet_cifar10.zip
Pruning Script (Pruning applied only in 9th ResNetBasicblock)
import torch
import torch_pruning as tp

import resnet_cifar10 as resnet
from resnet_cifar10 import resnet20
def prune_model(model):
    """Prune ``conv_a``/``conv_b`` of the 9th ResNetBasicblock in-place.

    Builds a dependency graph from a dummy CIFAR-10 sized input, then removes
    a fraction of the lowest-L1-norm filters from the two conv layers of the
    9th basic block. Returns the (mutated) model.
    """
    model.cpu()  # the dependency graph is traced on CPU
    # Trace the model once to record inter-layer dependencies.
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2):
        # Select the `amount` fraction of output filters with the smallest L1 norm.
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    blk_id = 0
    for m in model.modules():
        if isinstance(m, resnet.ResNetBasicblock):
            if blk_id == 8:  # 9th basic block (0-indexed)
                prune_conv(m.conv_a, 0.3)
                prune_conv(m.conv_b, 0.2)
            blk_id += 1
    return model
def main():
    """Build ResNet20, prune it, and print the model before and after."""
    model = resnet20(num_classes=10)
    print("Before", model)
    prune_model(model)  # mutates `model` in place
    print("After", model)


if __name__ == '__main__':
    main()
Before Pruning:
CifarResNet(
(conv_1_3x3): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(stage_1): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(stage_2): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): DownsampleA(
(avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(stage_3): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): DownsampleA(
(avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0)
(classifier): Linear(in_features=64, out_features=10, bias=True)
)
After Pruning only 9th ResNetBasicblock
CifarResNet(
(conv_1_3x3): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(stage_1): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(stage_2): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(6, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): DownsampleA(
(avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(12, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(12, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(stage_3): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(12, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): DownsampleA(
(avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(24, 45, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(45, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(45, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0)
(classifier): Linear(in_features=24, out_features=10, bias=True)
)
Issue: You can see that after pruning, filters were removed from every block of the network like a chain effect. I am not sure if this is expected. Any help will be greatly appreciated.
Thanks.
This is an expected outcome. This is described in the documentation. The library automatically handles the pruning of residual connections. Because you pruned a layer just before the skip connection it needs to be pruned from all layers in that block. To avoid this you can simply not prune conv_b. Pruning conv_a will still reduce the parameters and flops of b as all those input channels will be reduced.
Hi @codestar12,
Let's say we prune stage_3, 2nd ResNetBasicblock, conv_b, which is the last Conv layer in the whole network. Doing so prunes every Conv layer all the way up to the very first layer, conv_1_3x3, as well. This happens only with the network I posted above (Cifar10), but not with this one. As you said, pruning all Conv layers within the block is expected due to the skip connection, but I am confused because it is pruning layers in other blocks as well.
Unfortunately, this package has the problem you are running into. I already hit a similar issue before. If you prune a neural network that has residual connections, this package prunes the layers of the network excessively. The deeper the layer, the more of its channels get pruned.
Hi @VainF, I see you made some updates. By any chance did you check the issue I explained above? Thanks.
Hello @ghimiredhikura, I think the result is desired because different blocks are residually connected. They must be pruned together to avoid mismatched "Element-wise Add" in residual connections. Could you print the pruning plan in your example as it can provide more information about the pruning chain.
BTW, if you do not like this behaviour, please replace the AvgPool2d (#L11 of your resnet_cifar10.py) with a standard Conv2d. The chained pruning will stop at this conv layer.
Hello @ghimiredhikura, I think the result is desired because different blocks are residually connected. They must be pruned together to avoid mismatched "Element-wise Add" in residual connections. Could you print the pruning plan in your example as it can provide more information about the pruning chain.
# Pruning Code
import torch
from models.resnet_cifar import resnet20
import models.resnet_cifar as resnet
import torch_pruning_tool_v1.torch_pruning as tp
def prune_model(model):
    """Prune ``conv_b`` of the 9th ResNetBasicblock, printing the pruning plan.

    Builds a dependency graph from a dummy CIFAR-10 sized input, prints the
    full dependency chain that would be pruned, then executes it. Returns the
    (mutated) model.
    """
    model.cpu()  # the dependency graph is traced on CPU
    # Trace the model once to record inter-layer dependencies.
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2):
        # Select the `amount` fraction of output filters with the smallest L1 norm.
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv_out_channel, pruning_index)
        print(plan)  # show the whole coupled pruning chain before executing it
        plan.exec()

    blk_id = 0
    for m in model.modules():
        if isinstance(m, resnet.ResNetBasicblock):
            if blk_id == 8:  # 9th basic block (0-indexed)
                # prune_conv(m.conv_a, 0.3)  # intentionally disabled for this experiment
                prune_conv(m.conv_b, 0.2)
            blk_id += 1
    return model
def main():
    """Build ResNet20, prune it, and print the model before and after."""
    model = resnet20(num_classes=10)
    print("Before", model)
    prune_model(model)  # mutates `model` in place
    print("After", model)


if __name__ == '__main__':
    main()
# Original net
CifarResNet : Depth : 20 , Layers for each block : 3
CifarResNet(
(conv_1_3x3): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(stage_1): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(stage_2): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): DownsampleA(
(avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(stage_3): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): DownsampleA(
(avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0)
(classifier): Linear(in_features=64, out_features=10, bias=True)
)
--------------------------------
Pruning Plan
--------------------------------
User pruning:
[ [DEP] ConvOutChannelPruner on stage_3.2.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => ConvOutChannelPruner on stage_3.2.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]
Coupled pruning:
[ [DEP] ConvOutChannelPruner on stage_3.2.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => BatchnormPruner on stage_3.2.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 48}]
[ [DEP] BatchnormPruner on stage_3.2.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0) => ElementWiseOpPruner on _ElementWiseOp(ViewBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ViewBackward0) => LinearInChannelPruner on classifier (Linear(in_features=64, out_features=10, bias=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 240}]
[ [DEP] LinearInChannelPruner on classifier (Linear(in_features=64, out_features=10, bias=True)) => ElementWiseOpPruner on _ElementWiseOp(TBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_3.2.conv_a (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_3.1.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 48}]
[ [DEP] BatchnormPruner on stage_3.1.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_3.1.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_3.1.conv_a (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ConcatPruner on _ConcatOp([0, 32, 64]), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_3.0.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 48}]
[ [DEP] BatchnormPruner on stage_3.0.bn_b (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_3.0.conv_b (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31, 32, 34, 35, 37, 46, 47, 48, 50, 51, 53, 62, 63], metric={'#params': 13824}]
[ [DEP] ConcatPruner on _ConcatOp([0, 32, 64]) => ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ConcatPruner on _ConcatOp([0, 32, 64]) => ElementWiseOpPruner on _ElementWiseOp(MulBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_3.0.conv_a (Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 6912}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_2.2.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 24}]
[ [DEP] BatchnormPruner on stage_2.2.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_2.2.conv_b (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_2.2.conv_a (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_2.1.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 24}]
[ [DEP] BatchnormPruner on stage_2.1.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_2.1.conv_b (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_2.1.conv_a (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ConcatPruner on _ConcatOp([0, 16, 32]), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_2.0.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 24}]
[ [DEP] BatchnormPruner on stage_2.0.bn_b (BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_2.0.conv_b (Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15, 16, 18, 19, 21, 30, 31], metric={'#params': 3456}]
[ [DEP] ConcatPruner on _ConcatOp([0, 16, 32]) => ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ConcatPruner on _ConcatOp([0, 16, 32]) => ElementWiseOpPruner on _ElementWiseOp(MulBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AvgPool2DBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_2.0.conv_a (Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 1728}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_1.2.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 12}]
[ [DEP] BatchnormPruner on stage_1.2.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_1.2.conv_b (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_1.2.conv_a (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_1.1.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 12}]
[ [DEP] BatchnormPruner on stage_1.1.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_1.1.conv_b (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ElementWiseOpPruner on _ElementWiseOp(AddBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_1.1.conv_a (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => ElementWiseOpPruner on _ElementWiseOp(ReluBackward0), Index=[0, 2, 3, 5, 14, 15], metric={}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(AddBackward0) => BatchnormPruner on stage_1.0.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 12}]
[ [DEP] BatchnormPruner on stage_1.0.bn_b (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on stage_1.0.conv_b (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => BatchnormPruner on bn_1 (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 12}]
[ [DEP] ElementWiseOpPruner on _ElementWiseOp(ReluBackward0) => ConvInChannelPruner on stage_1.0.conv_a (Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 864}]
[ [DEP] BatchnormPruner on bn_1 (BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => ConvOutChannelPruner on conv_1_3x3 (Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[0, 2, 3, 5, 14, 15], metric={'#params': 162}]
Metric Sum: {'#params': 100890}
--------------------------------
# Output net
CifarResNet(
(conv_1_3x3): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(stage_1): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(10, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(10, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(10, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(16, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(stage_2): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(10, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): DownsampleA(
(avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(20, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(20, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(32, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(stage_3): Sequential(
(0): ResNetBasicblock(
(conv_a): Conv2d(20, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): DownsampleA(
(avg): AvgPool2d(kernel_size=1, stride=2, padding=0)
)
)
(1): ResNetBasicblock(
(conv_a): Conv2d(40, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResNetBasicblock(
(conv_a): Conv2d(40, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_a): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv_b): Conv2d(64, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn_b): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0)
(classifier): Linear(in_features=40, out_features=10, bias=True)
)
BTW, if you do not like this behaviour, please replace the AvgPool2d (#L11 of your resnet_cifar10.py) with a standard Conv2d. The chained pruning will stop at this conv layer.
Thanks. I will try this one.
BTW, if you do not like this behaviour, please replace the AvgPool2d (#L11 of your resnet_cifar10.py) with a standard Conv2d. The chained pruning will stop at this conv layer.
Yes, using standard Conv2d it works as expected. Thank you for the good work.