🐛 [Bug] interpolation doesn't support some argument settings
Bug Description
Depending on how the arguments are passed, interpolation fails to compile with Torch-TensorRT.
# return torch.nn.functional.interpolate(img, scale_factor=2.0) #ok
# return torch.nn.functional.interpolate(img, scale_factor=(2.0, 2.0)) #compilation error
# return torch.nn.functional.interpolate(img, size=200) #ok
# return torch.nn.functional.interpolate(img, size=(200, 200)) #compilation error
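For context (this check is not part of the original report), all four argument forms execute in eager PyTorch and produce a 200x200 output, so the failure appears to be specific to the Torch-TensorRT compilation step:

import torch

img = torch.rand(1, 3, 100, 100)

# All four variants run in eager mode; only the torch_tensorrt compile step
# is reported to fail for the tuple forms.
print(torch.nn.functional.interpolate(img, scale_factor=2.0).shape)         # torch.Size([1, 3, 200, 200])
print(torch.nn.functional.interpolate(img, scale_factor=(2.0, 2.0)).shape)  # torch.Size([1, 3, 200, 200])
print(torch.nn.functional.interpolate(img, size=200).shape)                 # torch.Size([1, 3, 200, 200])
print(torch.nn.functional.interpolate(img, size=(200, 200)).shape)          # torch.Size([1, 3, 200, 200])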
To Reproduce
$ python3 test_torchrt.py
Traceback (most recent call last):
File "test_torchrt.py", line 30, in
import torch
import torch_tensorrt

class TestModule(torch.nn.Module):
    def __init__(self):
        super(TestModule, self).__init__()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # return torch.nn.functional.interpolate(img, scale_factor=2.0)         # ok
        # return torch.nn.functional.interpolate(img, scale_factor=(2.0, 2.0))  # compilation error
        # return torch.nn.functional.interpolate(img, size=200)                 # ok
        return torch.nn.functional.interpolate(img, size=(200, 200))            # compilation error

height = 100
width = 100
shape = [1, 3, height, width]

compile_settings = {
    "inputs": [torch_tensorrt.Input(
        shape=shape,
        dtype=torch.float,
    )],
    "enabled_precisions": {torch.float},
}

img = torch.rand(shape).to("cuda")
jit_module = torch.jit.script(TestModule())
jit_module = torch_tensorrt.ts.compile(jit_module, **compile_settings)
img = jit_module(img)
Environment
- Torch-TensorRT Version: 1.1.0
- PyTorch Version: 1.11.0
- OS: Linux
- How you installed PyTorch: pip
- Python version: 3.8
- CUDA version: 11.3
- GPU models and configuration: RTX2070
I am facing the exact same issue. Is there a workaround, by any chance?
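A possible workaround sketch, not verified and not confirmed by the maintainers: since the scalar argument forms above are reported to compile, a single int size (or a single float scale_factor) can stand in whenever the target height and width are equal, as in the 200x200 case in this report. The module name below is illustrative only:

import torch

class SquareResize(torch.nn.Module):
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # size=200 is broadcast to (200, 200) for a 4D input, and the scalar
        # form is the variant reported above as compiling successfully.
        return torch.nn.functional.interpolate(img, size=200)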
Same issue here.
This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
Hello - I have tested out the sample script with the latest release of Torch-TRT (1.3.0), and both of the previously-failing compilations are now succeeding. Please let me know if the new release resolves this for you as well.
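A minimal self-contained check of the fix (a sketch assuming Torch-TensorRT 1.3.0 and a CUDA-capable GPU; it simply re-runs the previously failing tuple form):

import torch
import torch_tensorrt

class TestModule(torch.nn.Module):
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.interpolate(img, size=(200, 200))

# With Torch-TensorRT 1.3.0 this compile step is reported to succeed.
trt_module = torch_tensorrt.ts.compile(
    torch.jit.script(TestModule()),
    inputs=[torch_tensorrt.Input(shape=[1, 3, 100, 100], dtype=torch.float)],
    enabled_precisions={torch.float},
)
print(trt_module(torch.rand(1, 3, 100, 100).to("cuda")).shape)  # expected: torch.Size([1, 3, 200, 200])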
Thanks for fixing it!