Torch-Pruning

Default pruning_dimension is not supported for non-tensor example inputs

Open hovavalon opened this issue 4 years ago • 0 comments

From build_dependency:

if pruning_dim >= 0:
    pruning_dim = pruning_dim - len(example_inputs.size())

pruning_dim is 1 by default. Even though passing a list or a dictionary of inputs is supported, an exception occurs here because example_inputs is a list and has no size attribute.

I think it could be fixed this way:

if pruning_dim >= 0:
    # Convert the non-negative dim to a negative index relative to the input rank.
    if isinstance(example_inputs, torch.Tensor):
        pruning_dim = pruning_dim - len(example_inputs.size())
    elif isinstance(example_inputs, (tuple, list)):
        # Assumes the first element is a tensor whose rank is representative.
        pruning_dim = pruning_dim - len(example_inputs[0].size())
    else:
        raise TypeError("pruning with a non-negative dimension is not supported for input of type {}".format(type(example_inputs)))

If anyone familiar with the DependencyGraph's code has a better idea I would be glad to hear about it.
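
The fix above still raises for dictionary inputs and for nested containers. A more general variant could search the inputs for the first tensor and use its rank. The following is only a sketch: first_tensor_like and normalize_pruning_dim are hypothetical helper names, not part of Torch-Pruning. It assumes, like the fix above, that the first tensor found has a representative rank. It treats any object with a callable size() as a tensor so that the snippet runs without importing torch.

```python
from collections.abc import Mapping, Sequence


def first_tensor_like(example_inputs):
    """Return the first object with a callable .size(), searching depth-first.

    Hypothetical helper: torch.Tensor satisfies the check, and so does any
    other object that exposes size() as a method.
    """
    if callable(getattr(example_inputs, "size", None)):
        return example_inputs
    if isinstance(example_inputs, Mapping):
        children = example_inputs.values()
    elif isinstance(example_inputs, Sequence) and not isinstance(example_inputs, (str, bytes)):
        children = example_inputs
    else:
        return None
    for child in children:
        found = first_tensor_like(child)
        if found is not None:
            return found
    return None


def normalize_pruning_dim(pruning_dim, example_inputs):
    """Convert a non-negative pruning_dim into its negative equivalent.

    Assumption: the rank of the first tensor found is representative of
    the inputs. Negative values are returned unchanged.
    """
    if pruning_dim < 0:
        return pruning_dim
    tensor = first_tensor_like(example_inputs)
    if tensor is None:
        raise TypeError(
            "pruning with a non-negative dimension is not supported for "
            "input of type {}".format(type(example_inputs))
        )
    return pruning_dim - len(tensor.size())
```

With real inputs this would be called as normalize_pruning_dim(1, [torch.randn(1, 3, 224, 224)]), which gives 1 - 4 = -3. Whether taking the first tensor's rank is right for models whose inputs have different ranks is an open question for someone who knows DependencyGraph better.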

hovavalon, Jan 18 '22 07:01