
🐛 [Bug] nn.PixelShuffle crashes torch_tensorrt.compile

Open styler00dollar opened this issue 3 years ago • 2 comments

Bug Description

nn.PixelShuffle should be supported, but compiling a traced model that contains it fails with a TensorRT build error.

To Reproduce

import torch
import torch_tensorrt
from torch import nn


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        # PixelShuffle(4): (N, 16, H, W) -> (N, 1, 4H, 4W)
        self.upsampler = nn.PixelShuffle(4)

    def forward(self, x):
        x = self.conv(x)
        x = self.upsampler(x)
        return x


model = MyNet().eval().cuda()

example_data = torch.rand(1, 3, 64, 64).cuda()
out = model(example_data)
print(out.shape)  # torch.Size([1, 1, 256, 256])

model = torch.jit.trace(model, [example_data])
model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 3, 24, 24),
            opt_shape=(1, 3, 256, 256),
            max_shape=(1, 3, 512, 512),
            dtype=torch.float32,
        )
    ],
    enabled_precisions={torch.float},
    truncate_long_and_double=True,
)
out = model(example_data)
print(out.shape)
The first print returns torch.Size([1, 1, 256, 256]); the call to torch_tensorrt.compile then fails with:

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-12-6ec2adee4600> in <module>()
     21 
     22 model = torch.jit.trace(model, [example_data])
---> 23 model = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input(                 min_shape=(1, 3, 24, 24),                 opt_shape=(1, 3, 256, 256),                 max_shape=(1, 3, 512, 512),                 dtype=torch.float32)],                 enabled_precisions={torch.float}, truncate_long_and_double=True)
     24 out = model(example_data)
     25 print(out.shape)

1 frames

/usr/local/lib/python3.7/dist-packages/torch_tensorrt/_compile.py in compile(module, ir, inputs, enabled_precisions, **kwargs)
     95             )
     96             ts_mod = torch.jit.script(module)
---> 97         return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
     98     elif target_ir == _IRType.fx:
     99         raise RuntimeError("fx is currently not supported")

/usr/local/lib/python3.7/dist-packages/torch_tensorrt/ts/_compiler.py in compile(module, inputs, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, strict_types, capability, num_min_timing_iters, num_avg_timing_iters, workspace_size, max_batch_size, calibrator, truncate_long_and_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules)
    117     }
    118 
--> 119     compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
    120     compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
    121     return compiled_module

RuntimeError: [Error thrown at core/conversion/conversionctx/ConversionCtx.cpp:162] Building serialized network failed in TensorRT
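Until a converter for the pixel-shuffle op is available, one possible stopgap (a sketch only, not verified on this version) is to keep that single op in PyTorch and let Torch-TensorRT compile the rest of the graph via the torch_executed_ops argument visible in the compile signature above; the TorchScript op name "aten::pixel_shuffle" is an assumption here.

import torch
import torch_tensorrt

# Partial-compilation sketch: run pixel_shuffle in PyTorch, everything else in TensorRT.
# Assumes the op lowers to "aten::pixel_shuffle" in the traced graph.
trt_model = torch_tensorrt.compile(
    model,  # the traced module from the repro above
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 3, 24, 24),
            opt_shape=(1, 3, 256, 256),
            max_shape=(1, 3, 512, 512),
            dtype=torch.float32,
        )
    ],
    enabled_precisions={torch.float},
    truncate_long_and_double=True,
    torch_executed_ops=["aten::pixel_shuffle"],
)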

Expected behavior

No crash.

Environment

  • Torch-TensorRT Version: 1.0
  • PyTorch Version: torch==1.10.2+cu113
  • Python version: 3.7.12
  • CUDA version: 11.1
  • TensorRT: 8.2.3-1+cuda11.4
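Another possible workaround, again only a sketch: replace nn.PixelShuffle with an equivalent view/permute/reshape sequence, which the existing converters for those ops may already handle (not verified against this version). The ManualPixelShuffle module below is a hypothetical drop-in for nn.PixelShuffle(4) in the repro.

import torch
from torch import nn


class ManualPixelShuffle(nn.Module):
    """Drop-in equivalent of nn.PixelShuffle built from view/permute/reshape."""

    def __init__(self, upscale_factor: int):
        super().__init__()
        self.r = upscale_factor

    def forward(self, x):
        n, c, h, w = x.shape  # note: concrete shapes get baked in under torch.jit.trace
        r = self.r
        oc = c // (r * r)
        x = x.view(n, oc, r, r, h, w)           # (N, C*r*r, H, W) -> (N, C, r, r, H, W)
        x = x.permute(0, 1, 4, 2, 5, 3)         # -> (N, C, H, r, W, r)
        return x.reshape(n, oc, h * r, w * r)   # -> (N, C, H*r, W*r)


# Sanity check against the built-in op:
x = torch.rand(1, 16, 64, 64)
assert torch.equal(ManualPixelShuffle(4)(x), nn.PixelShuffle(4)(x))

Note that tracing this module fixes the shapes seen at trace time, so it may interact poorly with the dynamic min/opt/max shape range used above.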

styler00dollar commented on Feb 01, 2022

This issue has not seen activity for 90 days. Remove the stale label or add a comment, or it will be closed in 10 days.

github-actions[bot] commented on May 03, 2022
