Add option to compute pruning groups based on in_channels in compute_all_groups
Current behavior
In dependency/graph.py#L276, pruning groups are always computed based on output channels when calling DependencyGraph.compute_all_groups, which seems like an arbitrary choice.
Issue
For many newer architectures (e.g., LLMs and their MLP blocks), pruning is often expressed more naturally in terms of input channels, for example pruning the in_features of a down_proj layer.
Currently, compute_all_groups does not provide a way to group dependencies based on input channels.
Why this hasn’t been a blocker
One can still compute pruning groups for the layers of interest "manually" by calling DependencyGraph.get_pruning_group with an in-channel pruning function when in_channels need to be pruned.
- However, this requires the caller to pass the right pruning function (e.g., tp.prune_linear_in_channels for an nn.Linear layer).
- In many cases, pruning out_channels already prunes the corresponding in_channels of downstream layers, because those dimensions are coupled in the dependency graph.
- Still, having groups built directly from in_channels would make them easier to use downstream, e.g., to locate the weights of the module whose in_channels are being pruned and compute new saliency scores from them.
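To make the coupling concrete, here is a minimal pure-Python toy (not Torch-Pruning code) of an MLP block. The COUPLED table and the collect_group helper are hypothetical stand-ins for the dependency graph: up_proj's out_features and down_proj's in_features index the same hidden channels, so a group rooted on either side covers the same channels, and only the root differs.

```python
from collections import deque

# Toy coupling table for an MLP block (an illustrative assumption, not
# Torch-Pruning's internal graph): up_proj's out_features and down_proj's
# in_features index the same hidden channels, so pruning one prunes the other.
COUPLED = {
    ("up_proj", "out"): [("down_proj", "in")],
    ("down_proj", "in"): [("up_proj", "out")],
}


def collect_group(root, coupled=COUPLED):
    """Return `root` followed by every (layer, dim) coupled to it, in BFS order."""
    group, seen, queue = [], {root}, deque([root])
    while queue:
        node = queue.popleft()
        group.append(node)
        for nxt in coupled.get(node, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return group


by_out = collect_group(("up_proj", "out"))
by_in = collect_group(("down_proj", "in"))
print(by_out)  # [('up_proj', 'out'), ('down_proj', 'in')]
print(by_in)   # [('down_proj', 'in'), ('up_proj', 'out')]
```

Both groups touch the same channels; what changes is the root, which is the entry a downstream step (e.g., recomputing saliency from the pruned module's weights) would look up first.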
Proposal
I’d like to add an option to compute_all_groups (e.g., a boolean flag) that lets users choose whether groups are computed with respect to out_channels (default, current behavior) or in_channels.
This would make the API more flexible for different pruning strategies.
Computing all groups only from output channels seems an arbitrary restriction; the API should let the caller choose which type of group to compute.
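A minimal sketch of the proposed behavior on the same kind of toy graph, assuming a hypothetical boolean flag named root_on_in_channels (the real parameter name and signature would be settled in the PR). It loops over layers, roots each group on the chosen dimension, and skips roots already covered by an earlier group. This is plain Python, not Torch-Pruning's compute_all_groups.

```python
from collections import deque

# Hypothetical toy coupling for an MLP block (not Torch-Pruning's graph):
# up_proj's out_features and down_proj's in_features share hidden channels.
COUPLED = {
    ("up_proj", "out"): [("down_proj", "in")],
    ("down_proj", "in"): [("up_proj", "out")],
}


def _collect(root, coupled):
    """Breadth-first walk over coupled (layer, dim) pairs; root stays first."""
    group, seen, queue = [], {root}, deque([root])
    while queue:
        node = queue.popleft()
        group.append(node)
        for nxt in coupled.get(node, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return group


def compute_all_groups(layers, coupled, root_on_in_channels=False):
    """Toy version of the proposal: root groups on in or out channels."""
    dim = "in" if root_on_in_channels else "out"  # False keeps current behavior
    groups, covered = [], set()
    for layer in layers:
        root = (layer, dim)
        if root in covered:  # already included in an earlier group
            continue
        group = _collect(root, coupled)
        covered.update(group)
        groups.append(group)
    return groups


layers = ["up_proj", "down_proj"]
print(compute_all_groups(layers, COUPLED))
# [[('up_proj', 'out'), ('down_proj', 'in')], [('down_proj', 'out')]]
print(compute_all_groups(layers, COUPLED, root_on_in_channels=True))
# [[('up_proj', 'in')], [('down_proj', 'in'), ('up_proj', 'out')]]
```

With the flag set, the hidden-dimension group is rooted on down_proj's in_features, which is the layer the pruning decision is naturally expressed against in the MLP case.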
I’m happy to open a PR for this, but wanted to get the maintainers’ feedback first.
Added PR: https://github.com/VainF/Torch-Pruning/pull/521