
✨[Feature] Introduce a detection mechanism for prim::If to fall back and avoid a runtime error when a sub-block is data-dependent

Open bowang007 opened this issue 3 years ago • 0 comments

Is your feature request related to a problem? Please describe.

We have this graph:

INFO: [Torch-TensorRT] - Segment Block @171:
    Target: Torch

    Graph: graph(%1 : bool,
      %prev_features.1 : Tensor,
      %scale_factors0.1 : float[],
      %7 : int):
  %4 : NoneType = prim::Constant() # :0:0
  %8 : int = prim::Constant[value=4]()
  %top_down_features0 : Tensor = prim::If(%1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3913:4
    block0():
      %2 : Tensor = aten::upsample_nearest1d(%prev_features.1, %4, %scale_factors0.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3914:15
      -> (%2)
    block1():
      %6 : bool = aten::eq(%7, %8) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3915:7
      %9 : Tensor = prim::If(%6) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3915:4
        block0():
          %10 : Tensor = aten::upsample_nearest2d(%prev_features.1, %4, %scale_factors0.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3916:15
          -> (%10)
        block1():
          %11 : Tensor = aten::upsample_nearest3d(%prev_features.1, %4, %scale_factors0.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3918:15
          -> (%11)
      -> (%9)
  return (%top_down_features0)

In the shape analysis phase, we run all the blocks of a prim::If node. In this case, however, %prev_features is a Tensor with a specific size, so shape analysis fails: Torch-TensorRT tries to run aten::upsample_nearest1d with %prev_features's size as 2, which triggers the Torch dimension check at this line: https://github.com/pytorch/pytorch/blob/658f958bc4bb314d9c6030eeaf3e1784792b5d15/aten/src/ATen/native/UpSampleNearest1d.cpp#L11 In other words, Torch-TensorRT runs aten::upsample_nearest1d, aten::upsample_nearest2d, and aten::upsample_nearest3d during shape analysis. Which of these sub-blocks is valid depends on the input data, so running all 3 of them triggers a runtime error. Future prim::If support in Torch-TensorRT needs to handle this case.
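To make the failure mode concrete, here is a toy model in plain Python. It does not call PyTorch or Torch-TensorRT; every name in it is hypothetical, and shapes are plain tuples. Each leaf sub-block of the nested prim::If is modelled as a function that, like aten::upsample_nearest{1,2,3}d, only accepts inputs of one rank (3D, 4D, 5D). The branch conditions are assumed to be "rank == 3" for %1 and "%7 is the input rank", which is what the functional.py line references suggest but is not stated in the issue.

```python
# Toy model only: these helpers are hypothetical, NOT PyTorch or Torch-TensorRT APIs.


def make_upsample(spatial_dims):
    """Mimic the rank check of aten::upsample_nearest{1,2,3}d: an N-d upsample
    expects batch and channel dims plus N spatial dims."""
    expected_rank = spatial_dims + 2

    def upsample(shape, scale):
        if len(shape) != expected_rank:
            raise RuntimeError(
                f"upsample_nearest{spatial_dims}d expects a {expected_rank}D input, "
                f"got a {len(shape)}D input"
            )
        # Batch and channel are unchanged; every spatial dim is scaled.
        return shape[:2] + tuple(d * scale for d in shape[2:])

    return upsample


# The three leaf sub-blocks of the nested prim::If in the graph above.
UPSAMPLE_1D = make_upsample(1)
UPSAMPLE_2D = make_upsample(2)
UPSAMPLE_3D = make_upsample(3)


def run_selected_branch(shape, scale):
    """Runtime behaviour: only the branch whose condition holds is executed
    (assumes %1 means rank == 3 and %7 is the input rank)."""
    if len(shape) == 3:
        return UPSAMPLE_1D(shape, scale)
    if len(shape) == 4:
        return UPSAMPLE_2D(shape, scale)
    return UPSAMPLE_3D(shape, scale)


def run_all_branches(shape, scale):
    """Shape-analysis behaviour described in the issue: every sub-block runs."""
    return [UPSAMPLE_1D(shape, scale), UPSAMPLE_2D(shape, scale), UPSAMPLE_3D(shape, scale)]


# Example 4D input (values chosen for illustration).
print(run_selected_branch((1, 8, 16, 16), 2))  # (1, 8, 32, 32)
```

With the same input, `run_all_branches((1, 8, 16, 16), 2)` raises `RuntimeError` from the 1d branch before the valid 2d branch is ever reached, which mirrors the dimension-check failure reported above.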

Describe the solution you'd like

Maybe we should introduce a detection mechanism and then fall back the entire prim::If node to Torch when we encounter this kind of case.
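One possible shape for such a detection mechanism is sketched below. It is a toy model under stated assumptions, not Torch-TensorRT's partitioner: `IfNode`, `is_data_dependent`, and `assign_targets` are hypothetical names, the nested prim::If is flattened into a single node with three sub-blocks, and "fall back" is modelled as setting the node's target to `"Torch"` so that it stays one Torch segment and only the branch chosen at runtime executes. Detection here simply tries each sub-block on the example input shape and flags the node if any of them fails.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]


@dataclass
class IfNode:
    """Hypothetical stand-in for a prim::If node: one callable per sub-block."""
    name: str
    blocks: List[Callable[[Shape], Shape]]
    target: str = "TensorRT"  # segment target the toy partitioner would pick


def make_upsample(spatial_dims: int, scale: int = 2) -> Callable[[Shape], Shape]:
    """Toy sub-block with the same rank check as aten::upsample_nearest{N}d."""
    expected_rank = spatial_dims + 2

    def upsample(shape: Shape) -> Shape:
        if len(shape) != expected_rank:
            raise RuntimeError(f"upsample_nearest{spatial_dims}d: expected {expected_rank}D input")
        return shape[:2] + tuple(d * scale for d in shape[2:])

    return upsample


def is_data_dependent(node: IfNode, example_shape: Shape) -> bool:
    """Return True if any sub-block cannot run on the example input.

    In that case only the branch selected at runtime is guaranteed to be valid,
    so running every sub-block during shape analysis is unsafe.
    """
    for block in node.blocks:
        try:
            block(example_shape)
        except RuntimeError:
            return True
    return False


def assign_targets(nodes: List[IfNode], example_shape: Shape) -> None:
    """Fall back the entire prim::If to Torch when detection fires."""
    for node in nodes:
        if is_data_dependent(node, example_shape):
            node.target = "Torch"


interp = IfNode("top_down_features0", [make_upsample(1), make_upsample(2), make_upsample(3)])
assign_targets([interp], (1, 8, 16, 16))
print(interp.target)  # Torch
```

Trial execution is used only because it is the simplest detector to show; a static check that compares each sub-block's required input rank against the known input rank would flag the same node without running any kernels.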

bowang007, Sep 14 '22 21:09