About L1-Norm Sort
Thank you for your wonderful work! It seems there is no sparsity training to determine which channels to select; the weights are simply sorted. But if a filter's weights contain both positive and negative values, the result of `L1_norm = np.sum(weight, axis=(1, 2, 3))` can also come out close to 0 even for an important filter.
@wxy1234567 thanks for raising the issue. This is a bug, and the correct version should be:

`L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))`
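A small NumPy sketch of why the absolute value matters (the toy weight values below are made up for illustration): a filter whose positive and negative entries cancel gets a signed sum near zero and would be ranked as unimportant, while its true L1 norm is large.

```python
import numpy as np

# Toy conv weight of shape (out_channels, in_channels, kH, kW).
weight = np.zeros((2, 1, 2, 2))
weight[0] = [[[1.0, -1.0], [1.0, -1.0]]]      # signed entries cancel out
weight[1] = [[[0.25, 0.25], [0.25, 0.25]]]    # small but uniformly positive

signed = np.sum(weight, axis=(1, 2, 3))        # buggy importance score
l1 = np.sum(np.abs(weight), axis=(1, 2, 3))    # correct L1 norm

print(signed)  # [0. 1.]  -> filter 0 looks "least important"
print(l1)      # [4. 1.]  -> filter 0 actually has the largest L1 norm
```

With the signed sum, `np.argsort` would select filter 0 for removal first; with the L1 norm, it correctly selects filter 1.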
Thank you for your reply! I found a problem during pruning. I fixed the number of filters to prune, but the actual number removed is not correct; it is doubled. For example, I want to prune 10 filters, but 20 end up pruned:
```python
import torch
from torchvision.models import resnet18
import torchvision
import torch_pruning as pruning
import numpy as np

def prune_model(model):
    model.cpu()
    DG = pruning.DependencyGraph().build_dependency(model, torch.randn(1, 3, 224, 224))

    def prune_conv(conv, num_pruned):
        weight = conv.weight.detach().cpu().numpy()
        L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        # remove the filters with the smallest L1 norm
        prune_index = np.argsort(L1_norm)[:num_pruned].tolist()
        plan = DG.get_pruning_plan(conv, pruning.prune_conv, prune_index)
        plan.exec()

    for m in model.modules():
        if isinstance(m, torchvision.models.resnet.BasicBlock):
            prune_conv(m.conv1, 10)
            prune_conv(m.conv2, 10)
    return model

model = resnet18(pretrained=True)
prune_model(model)
print(model)
```
This phenomenon may be caused by the skip connections in resnet: pruning an output channel of one convolution forces the same channel index to be removed from every layer coupled to it through the residual add. Maybe you can refer to this issue.
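A rough intuition, sketched with plain NumPy rather than Torch-Pruning's actual dependency graph (the array names here are hypothetical): when two convolutions feed the same elementwise add, their output channels must stay aligned, so removing filter `i` from one forces removing channel `i` from the other as well. Requesting 10 filters on one conv therefore removes 10 filters from each coupled layer, 20 in total.

```python
import numpy as np

# Two hypothetical conv weights whose outputs are summed by a skip connection.
# Shape: (out_channels, in_channels, kH, kW); the residual add requires both
# layers to keep exactly the same set of output channels.
conv_a = np.random.randn(16, 8, 3, 3)
conv_b = np.random.randn(16, 8, 3, 3)

num_pruned = 10
l1 = np.sum(np.abs(conv_a), axis=(1, 2, 3))
prune_idx = np.argsort(l1)[:num_pruned]

keep = np.setdiff1d(np.arange(16), prune_idx)
conv_a = conv_a[keep]  # 10 filters removed here...
conv_b = conv_b[keep]  # ...and 10 matching filters removed here: 20 in total

print(conv_a.shape[0], conv_b.shape[0])  # 6 6
```

So a request to "prune 10" on a conv that sits on a residual path can legitimately show up as 20 removed filters once the dependency graph propagates the plan.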