Error when pruning a group with GLU Layer
Pruning a model that contains a GLU layer raises an error when computing importance. GLU has no parameters but halves its input along the given dimension, and this is not accounted for during tracing, index assignment, and importance estimation.
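To illustrate the halving, a quick shape check in plain PyTorch:

```python
import torch
import torch.nn as nn

# GLU splits the given dimension in half and computes a * sigmoid(b),
# so the output has half as many channels as the input.
x = torch.randn(1, 24, 100)
y = nn.GLU(dim=1)(x)
print(x.shape, y.shape)  # torch.Size([1, 24, 100]) torch.Size([1, 12, 100])
```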
Here's a minimal example with a simple model:

import torch
import torch.nn as nn
import torch_pruning as tp

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv1d(in_channels=96, out_channels=24, kernel_size=3, padding=1, dilation=1),
            nn.GLU(1),
            nn.Conv1d(in_channels=12, out_channels=48, kernel_size=3, padding=2, dilation=2),
        )

    def forward(self, x):
        return self.layer(x)

model = MyModel()
example_inputs = torch.randn(1, 96, 100)
I then prune using GroupNormPruner:

imp = tp.importance.GroupNormImportance()
pruner = tp.pruner.GroupNormPruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=1,
    pruning_ratio=0.5,
    ignored_layers=[],
)
pruner.step()
This gives an index out of bounds error:
Exception has occurred: IndexError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
index 12 is out of bounds for dimension 0 with size 12
File "/masked_path/lib/python3.9/site-packages/torch_pruning/pruner/importance.py", line 205, in __call__
local_imp = local_imp[idxs]
File "/masked_path/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/masked_path/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 247, in estimate_importance
return self.importance(group)
File "/masked_path/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 362, in prune_local
imp = self.estimate_importance(group)
File "/masked_path/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 228, in step
for group in pruning_method():
File "masked_path/my_file.py", line 39, in <module>
pruner.step()
File "/masked_path/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/masked_path/lib/python3.9/runpy.py", line 197, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
IndexError: index 12 is out of bounds for dimension 0 with size 12
Note that this error is raised while the group is being returned, so setting interactive=True in pruner.step() does not help.
GLU is currently not supported, so it's treated as an element-wise operation. However, since split is supported, you can create your own GLU operation like this:
class CustomGLU(nn.Module):
    def __init__(self, dim=1):
        super(CustomGLU, self).__init__()
        self.dim = dim

    def forward(self, x):
        # split the input in half along `dim` and compute a * sigmoid(b), as nn.GLU does
        first_half, second_half = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
        return first_half * torch.sigmoid(second_half)
If you don't use it extensively, the performance degradation shouldn't be significant.
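For reference, here is the replacement dropped into the example model from the report above, with a sanity check that it matches nn.GLU numerically (same shapes as in the original post):

```python
import torch
import torch.nn as nn

class CustomGLU(nn.Module):  # the workaround module from above
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        first_half, second_half = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
        return first_half * torch.sigmoid(second_half)

# Sanity check: CustomGLU is numerically identical to nn.GLU.
x = torch.randn(1, 24, 100)
assert torch.allclose(CustomGLU(1)(x), nn.GLU(1)(x))

model = nn.Sequential(
    nn.Conv1d(96, 24, kernel_size=3, padding=1),
    CustomGLU(1),  # replaces nn.GLU(1)
    nn.Conv1d(12, 48, kernel_size=3, padding=2, dilation=2),
)
print(model(torch.randn(1, 96, 100)).shape)  # torch.Size([1, 48, 100])
```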
Hi @janthmueller, thanks for the workaround, I tried that but the network comes back with only the last conv layer pruned. No dep group with first conv layer is being returned for pruning.
After running the get_pruning_group method within the prune_local function of the MetaPruner class, you might notice that the group containing the first layer appears to have double the number of indices. This likely occurs to prevent shape mismatch errors. However, with a pruning ratio of 0.5, attempting to prune the entire output of the first layer becomes impossible. This is because a group is ignored for pruning if all its filters or channels are pruned, resulting in nothing being pruned in your case.
To handle this, the number of pruned indices (n_pruned) needs a targeted adjustment before the pruning_idxs are gathered: for groups that involve the custom GLU's split operation, halve n_pruned so that the pruning correctly reflects the intended ratio.
To implement this adjustment, insert the following code snippet before collecting pruning_idxs within both the prune_local and prune_global methods:
for dep, _ in group:
    if isinstance(dep.target.module, ops._SplitOp):
        n_pruned = n_pruned // 2
        break
With this adjustment, groups involving the custom GLU operation are pruned in the intended proportion.
I think it might be best to fix this for all possible scenarios including a split, maybe similar to _is_attn_group with a _is_split_group check @VainF.
Great, thanks for the workaround and the explanation!
It would be great to have this merged so that the library works directly with GLU!