ChannelDependency does not handle concat properly
Describe the issue:
While computing channel dependencies, `reshape_break_channel_dependency` runs the following code to ensure that the number of input channels equals the number of output channels:
```python
in_shape = op_node.auxiliary['in_shape']
out_shape = op_node.auxiliary['out_shape']
in_channel = in_shape[1]
out_channel = out_shape[1]
return in_channel != out_channel
```
This is correct for most reshape operations, as long as they accept only one argument. In the case of concatenation, however, `in_shape` is a list of the concatenated shapes, so `in_channel` is assigned a full shape (e.g. `[1, 20, 32, 32]`) instead of a single integer (e.g. `20`).
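To make the misfire concrete, here is a minimal plain-Python illustration; the list values mirror what `auxiliary['in_shape']` would hold in each case (the exact shapes are just examples):

```python
# Single-input reshape op: in_shape is one shape, so index 1 is the
# channel count, as the check intends.
single_in_shape = [1, 20, 32, 32]
print(single_in_shape[1])  # 20

# Cat op: in_shape is a LIST of the concatenated input shapes, so
# index 1 picks out a whole shape instead of a channel count.
cat_in_shape = [[1, 20, 32, 32], [1, 20, 32, 32]]
print(cat_in_shape[1])  # [1, 20, 32, 32]
```

Comparing `[1, 20, 32, 32] != out_channel` is then always true, so the function always reports the dependency as broken for cat nodes.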
This effectively prevents the creation of channel dependencies caused by concatenation (although concatenating feature maps along a non-channel dimension is admittedly rare).
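One possible shape of a patch, as a standalone sketch: it operates on plain shape lists rather than the real `op_node`, so the signature and the normalization step are assumptions for illustration, not the actual NNI API.

```python
def reshape_break_channel_dependency(in_shape, out_shape):
    # Cat nodes store a list of input shapes; single-input ops store one
    # shape. Normalize both cases to a list of shapes.
    if not isinstance(in_shape[0], (list, tuple)):
        in_shape = [in_shape]
    out_channel = out_shape[1]
    # The dependency survives only if every input keeps the output's
    # channel count (true for cat along a non-channel dimension).
    return any(shape[1] != out_channel for shape in in_shape)

# cat along dim 2: every input has 20 channels, dependency is NOT broken
print(reshape_break_channel_dependency(
    [[1, 20, 6, 6], [1, 20, 6, 6]], [1, 20, 12, 6]))  # False
# cat along dim 1: channels sum to 40, dependency IS broken
print(reshape_break_channel_dependency(
    [[1, 20, 6, 6], [1, 20, 6, 6]], [1, 40, 6, 6]))   # True
```

Single-argument reshape ops behave exactly as before, since they are wrapped into a one-element list.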
Environment:
- NNI version: 2.5
- Training service: local
- Client OS: linux
- Server OS (for remote mode only):
- Python version: 3.8.6
- PyTorch/TensorFlow version: 1.10.0
- Is conda/virtualenv/venv used?: yes
- Is running in Docker?: no
Configuration: N/A
Log message: N/A
How to reproduce it?: Model code to quickly replicate the problem:
```python
import torch
import torch.nn.functional as F


class NaiveModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(1, 20, 5, 1)
        self.fc1 = torch.nn.Linear(6 * 6 * 40, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(4, 4)
        self.max_pool2 = torch.nn.MaxPool2d(4, 4)

    def forward(self, x):
        x1 = self.relu1(self.conv1(x))
        x1 = self.max_pool1(x1)
        x2 = self.relu2(self.conv2(x))
        x2 = self.max_pool2(x2)
        # concatenation along a non-channel dimension (dim 2)
        x = torch.cat([x1, x2], 2)
        x = x.view(-1, x.size()[1:].numel())
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

device = torch.device('cpu')
model = NaiveModel().to(device)
dummy_input = torch.ones([1, 1, 28, 28]).to(device)
config_list = [{"sparsity": 0.5, "op_types": ["Conv2d"]}]
pruner = L1FilterPruner(
    model, config_list, dependency_aware=True, dummy_input=dummy_input
)
# just step into the code to see that the dependencies are not parsed correctly
```
@pkubik, thanks for your finding! We didn't take this situation into account; it seems we need a patch for this function. If it is convenient for you, you could contribute a PR to fix it, or we will fix it in the next release.
@J-shang Can I take this up?
Of course, any contributions are welcome~