Torch-Pruning

About L1-Norm Sort

Open • wxy1234567 opened this issue 5 years ago • 3 comments

Thank you for your wonderful work! It seems that there is no sparse training to decide which channels to select; the weights are simply sorted. But if a filter's weights contain both positive and negative values, the code `L1_norm = np.sum(weight, axis=(1, 2, 3))` can give a value close to 0, because the positive and negative entries cancel out.
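
To illustrate the cancellation, here is a small NumPy sketch with two made-up filters (the values are purely illustrative, not taken from any real model). The signed sum ranks the large-magnitude filter as the least important one, while the true L1 norm does not:

```python
import numpy as np

# Two made-up filters, shape (out_channels=2, in_channels=1, kH=2, kW=2).
# Filter 0 has large weights of opposite sign; filter 1 has small positive weights.
weight = np.array([
    [[[1.0, -1.0],
      [0.9, -0.9]]],
    [[[0.1, 0.1],
      [0.1, 0.1]]],
])

signed_sum = np.sum(weight, axis=(1, 2, 3))        # positive and negative entries cancel
l1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))   # per-filter L1 norm

# Index of the filter each criterion would prune first (smallest score)
print("signed sum prunes filter", np.argsort(signed_sum)[0])  # → 0
print("L1 norm prunes filter", np.argsort(l1_norm)[0])        # → 1
```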

wxy1234567, Jul 23 '20 03:07

@wxy1234567 Thanks for raising the issue. This is a bug; the correct line should be:

```python
L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
```
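
For reference, the corrected selection step can be wrapped in a small NumPy helper. This is a minimal sketch, assuming a conv kernel laid out as (out_channels, in_channels, kH, kW) like PyTorch's `conv.weight`; the name `smallest_l1_filters` is hypothetical and not part of Torch-Pruning:

```python
from typing import List

import numpy as np


def smallest_l1_filters(weight: np.ndarray, num_pruned: int) -> List[int]:
    """Indices of the `num_pruned` output filters with the smallest L1 norm.

    `weight` must be laid out as (out_channels, in_channels, kH, kW),
    the same layout as `conv.weight` of a PyTorch Conv2d.
    """
    if weight.ndim != 4:
        raise ValueError("expected a 4-D conv weight (out, in, kH, kW)")
    l1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))  # one score per output filter
    # argsort is ascending, so the first entries are the weakest filters
    return np.argsort(l1_norm)[:num_pruned].tolist()


# Usage: prune 10% of a random 64-filter kernel.
rng = np.random.default_rng(0)
weight = rng.standard_normal((64, 3, 3, 3))
num_pruned = int(weight.shape[0] * 0.1)  # int(6.4) == 6
prune_index = smallest_l1_filters(weight, num_pruned)
print(prune_index)
```

Taking `np.abs` before summing is what makes the score a real L1 norm; without it, filters with large weights of mixed sign can score near zero and be removed first.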

VainF, Jul 23 '20 04:07

Thank you for your reply! I found another problem while pruning. I fixed the number of channels to prune, but the number actually removed is not what I asked for: it is doubled. For example, I want to cut 10 channels, but 20 are cut in the end.

```python
import torch
from torchvision.models import resnet18
import torchvision
import torch_pruning as pruning
import numpy as np

def prune_model(model):
    model.cpu()
    # Trace the model once with a dummy input so that layer dependencies can be found
    DG = pruning.DependencyGraph().build_dependency(model, torch.randn(1, 3, 224, 224))

    def prune_conv(conv, num_pruned):
        weight = conv.weight.detach().cpu().numpy()
        #out_channels = weight.shape[0]
        L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
        #num_pruned = int(out_channels * pruned_prob)
        prune_index = np.argsort(L1_norm)[:num_pruned].tolist()  # remove filters with small L1-Norm
        # The plan also prunes the layers that depend on `conv`
        plan = DG.get_pruning_plan(conv, pruning.prune_conv, prune_index)
        plan.exec()

    block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]  # not used below: a fixed count of 10 is pruned
    blk_id = 0
    for m in model.modules():
        if isinstance(m, torchvision.models.resnet.BasicBlock):
            prune_conv(m.conv1, 10)
            prune_conv(m.conv2, 10)
            blk_id += 1
    return model

model = resnet18(pretrained=True)
prune_model(model)
print(model)
```

wxy1234567, Jul 23 '20 10:07

This phenomenon may be caused by the skip connections in ResNet. Maybe you can refer to this issue.
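
A rough way to see where a factor of two could come from: in ResNet-18 each stage has two BasicBlocks, and the output of each block's `conv2` is added to the same residual stream (which, in `layer1`, also carries the output of the stem `conv1`). Layers whose outputs are summed must keep identical channels, so a dependency-aware pruner removes the same indices from all of them. The pure-Python toy below assumes exactly that coupling; the layer names mimic torchvision's, but the bookkeeping (`build_layer1`, `prune`) is hypothetical and is not Torch-Pruning's actual implementation:

```python
# Toy model of the stem conv + layer1 of ResNet-18 (2 BasicBlocks, 64 channels).
# Assumption: all layers in a coupled group must keep the same output channels,
# so pruning any member removes the same number of channels from every member.

def build_layer1():
    out_channels = {
        "conv1": 64,  # stem conv; its output is the input of layer1's residual stream
        "layer1.0.conv1": 64,
        "layer1.0.conv2": 64,
        "layer1.1.conv1": 64,
        "layer1.1.conv2": 64,
    }
    # conv2 outputs are summed with the skip connection, which also carries the stem output
    residual_group = {"conv1", "layer1.0.conv2", "layer1.1.conv2"}
    groups = {name: residual_group for name in residual_group}
    # conv1 inside a block only feeds its own conv2, so it is not coupled to other outputs
    for name in ("layer1.0.conv1", "layer1.1.conv1"):
        groups[name] = {name}
    return out_channels, groups


def prune(out_channels, groups, layer, num_pruned):
    """Remove `num_pruned` channels from `layer` and every layer coupled to it."""
    for name in groups[layer]:
        out_channels[name] -= num_pruned


out_channels, groups = build_layer1()
for blk in ("layer1.0", "layer1.1"):  # same loop shape as prune_model()
    prune(out_channels, groups, f"{blk}.conv1", 10)
    prune(out_channels, groups, f"{blk}.conv2", 10)

for name, c in out_channels.items():
    print(name, 64 - c, "channels removed")
```

Under this assumption each block-internal `conv1` loses exactly the requested 10 channels, while every `conv2` (and the stem `conv1`) loses 10 per block in the stage, i.e. 20 for two blocks. Later stages have the same two-block structure, with the first block's downsample conv joining the coupled group.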

VainF, Jul 23 '20 10:07