RuntimeError: offset output dims: (160, 160) - computed output dims: (80, 80)
🐛 Describe the bug
I built PyTorch (torch==1.10.0) and torchvision (torchvision==0.12.0) from source. Training code snippet:
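```python
import torch.nn as nn
from torchvision.ops import DeformConv2d as ModulatedDeformConv

# Excerpt from the residual block's __init__ (planes, stride and
# deformable_groups are arguments of that block)
conv_op = ModulatedDeformConv
offset_channels = 27  # 18 offset channels + 9 mask channels for a 3x3 kernel
self.conv2_offset = nn.Conv2d(
    planes, deformable_groups * offset_channels,
    kernel_size=3,
    padding=1)
self.conv2 = conv_op(
    planes, planes, kernel_size=3, padding=1, stride=stride,
    groups=deformable_groups, bias=False)

# Excerpt from the block's forward()
offset_mask = self.conv2_offset(out)
offset = offset_mask[:, :18, :, :]          # first 18 channels: offsets
mask = offset_mask[:, -9:, :, :].sigmoid()  # last 9 channels: modulation mask
out = self.conv2(out, offset, mask)
```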
Tensor shapes:
```
(Pdb) offset_mask.shape
torch.Size([16, 27, 160, 160])
(Pdb) offset.shape
torch.Size([16, 18, 160, 160])
(Pdb) mask.shape
torch.Size([16, 9, 160, 160])
(Pdb) self.conv2.__doc__
'\n See :func:`deform_conv2d`.\n '
(Pdb) out.shape
torch.Size([16, 128, 160, 160])
```
During training, the line `out = self.conv2(out, offset, mask)` raises the error below. Please help, thanks.
```
Traceback (most recent call last):
  File "/DB/train.py", line 89, in <module>
    main()
  File "/DB/train.py", line 86, in main
    trainer.train()
  File "/DB/trainer.py", line 115, in train
    epoch=epoch, step=self.steps)
  File "/DB/trainer.py", line 138, in train_step
    results = model.forward(batch, training=True)
  File "/DB/structure/model.py", line 56, in forward
    pred = self.model(data, training=self.training)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/DB/structure/model.py", line 19, in forward
    return self.decoder(self.backbone(data), *args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/DB/backbones/resnet.py", line 242, in forward
    x3 = self.layer2(x2)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/DB/backbones/resnet.py", line 161, in forward
    out = self.conv2(out, offset, mask)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torchvision-0.12.0a0+031e129-py3.7-linux-x86_64.egg/torchvision/ops/deform_conv.py", line 177, in forward
    mask=mask,
  File "/opt/conda/lib/python3.7/site-packages/torchvision-0.12.0a0+031e129-py3.7-linux-x86_64.egg/torchvision/ops/deform_conv.py", line 106, in deform_conv2d
    use_mask,
RuntimeError: offset output dims: (160, 160) - computed output dims: (80, 80)
```
Thanks!
Versions
```
Collecting environment information...
PyTorch version: 1.10.0a0+git36449ea
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.21.4
Libc version: glibc-2.10

Python version: 3.7.6 (default, Jan 8 2020, 19:59:22) [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-86-generic-x86_64-with-debian-buster-sid
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==1.10.0+cpu
[pip3] numpy==1.21.2
[pip3] torch==1.10.0a0+git36449ea
[pip3] torchvision==0.12.0a0+031e129
[conda] blas 1.0 mkl
[conda] intel-extension-for-pytorch 1.10.0+cpu pypi_0 pypi
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-include 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py37h7f8727e_0
[conda] mkl_fft 1.3.1 py37hd3c417c_0
[conda] mkl_random 1.2.2 py37h51133e4_0
[conda] numpy 1.21.2 py37h20f2e39_0
[conda] numpy-base 1.21.2 py37h79a1101_0
[conda] torch 1.10.0a0+git36449ea pypi_0 pypi
[conda] torchvision 0.12.0a0+031e129 pypi_0 pypi
```
Is anyone here?
@smartkiwi @kashif @ezyang @djsutherland Can you help me?
@wduo Thanks for raising this.
Unfortunately the information included on the bug report is not enough to reproduce the problem. If you submit a snippet that fully reproduces the problem, it would be easier for us to help.
Having said that, it seems to me that either the mask or the offset dimensions don't match the expected ones. Unfortunately, it's not possible to tell which one fails because the error message is the same for both (I'll send a PR in a bit to fix this): https://github.com/pytorch/vision/blob/ae87c1e46df6fb404654935c82e15013d56b7aa8/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp#L950-L975
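Because the same message is raised for both tensors, a small pure-Python check can tell offset and mask problems apart before the layer is called. The helper `diagnose_deform_inputs` below is hypothetical (it is not part of torchvision). It is a sketch that assumes the shape rules enforced by the linked CPU kernel: for a kh × kw kernel with G offset groups, `offset` needs 2·G·kh·kw channels, `mask` needs G·kh·kw channels, and both must have the same spatial size as the convolution output.

```python
from typing import List, Optional, Sequence, Tuple


def diagnose_deform_inputs(
    offset_shape: Sequence[int],
    mask_shape: Optional[Sequence[int]],
    out_hw: Tuple[int, int],
    kernel_hw: Tuple[int, int] = (3, 3),
    offset_groups: int = 1,
) -> List[str]:
    """Hypothetical helper: report every offset/mask shape mismatch separately.

    Shapes are (N, C, H, W); out_hw is the (H, W) the convolution will produce.
    """
    kh, kw = kernel_hw
    problems = []
    expected_offset_c = 2 * offset_groups * kh * kw  # one 2-D displacement per kernel tap
    expected_mask_c = offset_groups * kh * kw        # one scalar weight per kernel tap
    if offset_shape[1] != expected_offset_c:
        problems.append(f"offset channels: got {offset_shape[1]}, expected {expected_offset_c}")
    if tuple(offset_shape[2:]) != tuple(out_hw):
        problems.append(f"offset spatial dims: got {tuple(offset_shape[2:])}, expected {tuple(out_hw)}")
    if mask_shape is not None:
        if mask_shape[1] != expected_mask_c:
            problems.append(f"mask channels: got {mask_shape[1]}, expected {expected_mask_c}")
        if tuple(mask_shape[2:]) != tuple(out_hw):
            problems.append(f"mask spatial dims: got {tuple(mask_shape[2:])}, expected {tuple(out_hw)}")
    return problems


# Shapes from the report; the 80x80 output size is taken from the error message.
for problem in diagnose_deform_inputs((16, 18, 160, 160), (16, 9, 160, 160), (80, 80)):
    print(problem)
```

With the reported shapes, the channel counts (18 and 9) pass, and both tensors fail only on their spatial dims.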
Your snippet doesn't include all the parameters (stride, deformable_groups, etc.), so it's hard to pinpoint what's wrong, but my guess is that you use a stride of 2, which means the expected output dims differ from the ones you provide. The computation of the output dimensions can be seen here:
https://github.com/pytorch/vision/blob/ae87c1e46df6fb404654935c82e15013d56b7aa8/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp#L907-L910
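For reference, here is the standard convolution output-size formula applied to each spatial dimension, which is what the linked lines compute. It shows how a stride of 2 turns the 160×160 input into an 80×80 output. The function name `conv_output_size` is illustrative and does not come from torchvision.

```python
def conv_output_size(in_size: int, kernel_size: int, stride: int = 1,
                     padding: int = 0, dilation: int = 1) -> int:
    """Output size of a convolution along one spatial dimension."""
    effective_kernel = dilation * (kernel_size - 1) + 1  # extent of the dilated kernel
    return (in_size + 2 * padding - effective_kernel) // stride + 1


# 3x3 kernel with padding=1, as in the reported layers
print(conv_output_size(160, 3, stride=1, padding=1))  # 160: offset conv without stride
print(conv_output_size(160, 3, stride=2, padding=1))  # 80: deformable conv with stride=2
```

The stride of 2 is only a guess from the numbers in the error. With stride=1 both sizes would be 160 and the shapes would match.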
@datumbox During training I call DB/backbones/resnet.py L298, and I replaced L55 and L126 with `from torchvision.ops import DeformConv2d as ModulatedDeformConv`.
Please take a look. If you need any other information, feel free to ping me at any time. Thanks!
The original DB code uses the CUDA version of DCN, and I only have a CPU, so I want to use the DCN operator from torchvision instead.
@wduo Unfortunately, it's hard to help without a minimal snippet that has no external dependencies and reproduces the problem. If you can't provide one, I recommend making sure that the offset and mask dimensions you pass to the layer match the values the kernel expects. I linked the code that computes them for the CPU kernel above; the GPU kernel performs the same checks.
@wduo I had the same problem when using DB++ with the same replacement (L55 and L126 replaced with `from torchvision.ops import DeformConv2d as ModulatedDeformConv`) on Ubuntu 20.04, PyTorch 1.8.2, CUDA 11.1.
To solve it, I modified L58 and L129 to add `stride=stride` to `self.conv2_offset = nn.Conv2d(...)`, changing `self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, kernel_size=3, padding=1)` to `self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, kernel_size=3, padding=1, stride=stride)`.
After that it works fine: the pretrained model loads and the predictions are correct. However, I don't know exactly why this works, or whether the modification is completely correct.