Torch-Pruning
Default pruning_dim is not supported for non-tensor example_inputs
From build_dependency:
if pruning_dim >= 0:
    pruning_dim = pruning_dim - len(example_inputs.size())
pruning_dim is 1 by default. Passing a list or a dictionary as example_inputs is otherwise supported, but here an exception is raised: example_inputs is a list, which has no size() method, so the call fails with an AttributeError.
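To make the arithmetic concrete, here is a small dependency-free sketch of what those two lines compute. `to_negative_dim` is a hypothetical helper (not part of Torch-Pruning) that takes a plain shape tuple instead of a tensor. For a single NCHW input, the default pruning_dim of 1 becomes -3, which addresses the same channel axis counted from the end. The last lines show the reported failure: a plain list has no size() method.

```python
def to_negative_dim(pruning_dim, input_shape):
    # Hypothetical helper mirroring the build_dependency logic: a non-negative
    # dim is re-expressed as an offset from the last axis of the input.
    if pruning_dim >= 0:
        pruning_dim = pruning_dim - len(input_shape)
    return pruning_dim


shape = (1, 3, 224, 224)  # one NCHW image: batch, channels, height, width
dim = to_negative_dim(1, shape)
print(dim)  # -3
# -3 and 1 name the same (channel) axis of a rank-4 input
assert dim % len(shape) == 1

# A plain list has no size() method, which is what triggers the exception.
example_inputs = [shape]
print(hasattr(example_inputs, "size"))  # False
```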
I think it could be fixed this way:
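if isinstance(example_inputs, torch.Tensor):
    pruning_dim = pruning_dim - len(example_inputs.size())
elif isinstance(example_inputs, (tuple, list)):
    # use the rank of the first input as the reference
    pruning_dim = pruning_dim - len(example_inputs[0].size())
elif isinstance(example_inputs, dict):
    # dictionaries are also accepted as example_inputs; use the first value
    pruning_dim = pruning_dim - len(next(iter(example_inputs.values())).size())
else:
    raise TypeError("pruning with a non-negative dimension is not supported for inputs of type {}".format(type(example_inputs)))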
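A possible generalization of this idea is to search example_inputs recursively for the first tensor-like object, so nested lists, tuples and dicts are all handled the same way. This is only a sketch under stated assumptions: `reference_rank` and `normalize_pruning_dim` are hypothetical names, not Torch-Pruning API, and "tensor-like" is duck-typed as "has a callable size() method" so the sketch runs without importing torch. Like the proposal above, it assumes the first tensor found has the rank that matters for the pruned layers.

```python
from collections.abc import Mapping, Sequence


def reference_rank(example_inputs):
    """Return the rank of the first tensor-like object in example_inputs, or None.

    Hypothetical helper. Anything with a callable .size() counts as a tensor
    (torch.Tensor.size() returns a torch.Size, which is a tuple subclass).
    """
    size = getattr(example_inputs, "size", None)
    if callable(size):
        return len(size())
    if isinstance(example_inputs, Mapping):
        candidates = example_inputs.values()
    elif isinstance(example_inputs, Sequence) and not isinstance(example_inputs, (str, bytes)):
        candidates = example_inputs
    else:
        candidates = ()
    # Depth-first search: return the rank of the first tensor encountered.
    for item in candidates:
        rank = reference_rank(item)
        if rank is not None:
            return rank
    return None


def normalize_pruning_dim(pruning_dim, example_inputs):
    """Hypothetical helper: convert a non-negative pruning_dim to a negative index."""
    if pruning_dim < 0:
        return pruning_dim  # already relative to the last axis
    rank = reference_rank(example_inputs)
    if rank is None:
        raise TypeError(
            "pruning with a non-negative dimension is not supported for inputs of type {}".format(
                type(example_inputs)
            )
        )
    return pruning_dim - rank
```

With PyTorch installed, normalize_pruning_dim(1, [torch.randn(1, 3, 224, 224)]) would return -3, the same value the tensor-only code path produces for a single rank-4 input.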
If anyone familiar with the DependencyGraph code has a better idea, I would be glad to hear it.