Torch-Pruning
Torch-Pruning copied to clipboard
Channel Sorting
@VainF, as a feature request: would it be possible to apply channel sorting based on channel importance? Any insights you can provide would be really helpful.
# functions to sort the channels from most important to least important
def get_input_channel_importance(weight):
    """Return the L2 norm of each input-channel slice of ``weight``.

    Dimension 1 of ``weight`` is treated as the input-channel axis. For every
    index ``i`` along that axis, the importance is the 2-norm of all elements
    of ``weight[:, i]``.

    Args:
        weight: tensor with at least two dimensions whose dimension 1 indexes
            input channels (e.g. the weight of a convolution layer).

    Returns:
        A 1-D tensor of length ``weight.shape[1]``, detached from the autograd
        graph, holding one importance value per input channel. The result is
        empty when ``weight`` has zero input channels.
    """
    # Move the input-channel axis to the front and flatten everything else so
    # that a single row-wise norm replaces the per-channel Python loop.
    per_channel = weight.detach().transpose(0, 1).flatten(1)
    return per_channel.norm(p=2, dim=1)
@torch.no_grad()
def apply_channel_sorting(model):
    """Return a copy of ``model`` whose backbone channels are sorted by importance.

    For every pair of consecutive ``nn.Conv2d`` layers found in
    ``model.backbone``, the input channels of the second conv are ranked by
    ``get_input_channel_importance`` (largest first). The resulting permutation
    is applied to:

    - the output dimension of the previous conv (weight and, if present, bias),
    - the per-channel tensors of the matching ``nn.BatchNorm2d``,
    - the input dimension of the next conv.

    Args:
        model: a module exposing an iterable ``backbone`` of layers.

    Returns:
        A deep copy of ``model`` with the permutation applied; the original
        model is left untouched.

    NOTE(review): the i-th BatchNorm2d in the backbone is assumed to directly
    follow the i-th Conv2d, and the convs are assumed to be non-grouped
    (``groups == 1``) -- confirm against the models this is used on.
    """
    model = copy.deepcopy(model)  # do not modify the original model

    # fetch all the conv and bn layers from the backbone
    all_convs = [m for m in model.backbone if isinstance(m, nn.Conv2d)]
    all_bns = [m for m in model.backbone if isinstance(m, nn.BatchNorm2d)]

    # walk over consecutive (previous conv, next conv) pairs
    for i_conv, (prev_conv, next_conv) in enumerate(zip(all_convs, all_convs[1:])):
        prev_bn = all_bns[i_conv]

        # importance is always computed from the input channels of the next conv
        importance = get_input_channel_importance(next_conv.weight)
        # sorting from large to small
        sort_idx = torch.argsort(importance, descending=True)

        # Reorder the output channels of the previous conv. The bias is
        # per-output-channel too, so it must follow the same permutation;
        # it is None when the conv was built with bias=False.
        for tensor_name in ('weight', 'bias'):
            tensor_to_apply = getattr(prev_conv, tensor_name)
            if tensor_to_apply is None:
                continue
            tensor_to_apply.copy_(
                torch.index_select(tensor_to_apply.detach(), 0, sort_idx)
            )

        # Reorder the per-channel tensors of the following BN. weight/bias are
        # None with affine=False and the running stats are None with
        # track_running_stats=False, so absent tensors are skipped.
        for tensor_name in ('weight', 'bias', 'running_mean', 'running_var'):
            tensor_to_apply = getattr(prev_bn, tensor_name)
            if tensor_to_apply is None:
                continue
            tensor_to_apply.copy_(
                torch.index_select(tensor_to_apply.detach(), 0, sort_idx)
            )

        # apply the same permutation to the input channels of the next conv
        next_conv.weight.copy_(
            torch.index_select(next_conv.weight.detach(), 1, sort_idx))

    return model
Hi @satabios, may I ask when we can use this feature? It would be great if there were some publications.
Found the paper: https://dl.acm.org/doi/abs/10.1145/3007787.3001163. It has also been shown in various other papers and has been tested here: https://github.com/satabios/sconce/blob/main/tutorials/Pruning.ipynb