✨[Feature] Introduce a detection mechanism for prim::If to fall back and avoid runtime errors when a sub-block is data-dependent
**Is your feature request related to a problem? Please describe.**

We have this graph:
```
INFO: [Torch-TensorRT] - Segment Block @171:
    Target: Torch

    Graph: graph(%1 : bool,
          %prev_features.1 : Tensor,
          %scale_factors0.1 : float[],
          %7 : int):
      %4 : NoneType = prim::Constant() # :0:0
      %8 : int = prim::Constant[value=4]()
      %top_down_features0 : Tensor = prim::If(%1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3913:4
        block0():
          %2 : Tensor = aten::upsample_nearest1d(%prev_features.1, %4, %scale_factors0.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3914:15
          -> (%2)
        block1():
          %6 : bool = aten::eq(%7, %8) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3915:7
          %9 : Tensor = prim::If(%6) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3915:4
            block0():
              %10 : Tensor = aten::upsample_nearest2d(%prev_features.1, %4, %scale_factors0.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3916:15
              -> (%10)
            block1():
              %11 : Tensor = aten::upsample_nearest3d(%prev_features.1, %4, %scale_factors0.1) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3918:15
              -> (%11)
          -> (%9)
      return (%top_down_features0)
```
During the shape analysis phase, Torch-TensorRT runs all of the blocks in a prim::If node. In this case, %prev_features.1 is a Tensor with a concrete shape, so the analysis fails: when Torch-TensorRT tries to run aten::upsample_nearest1d on it, the op's dimension check rejects the input at this line: https://github.com/pytorch/pytorch/blob/658f958bc4bb314d9c6030eeaf3e1784792b5d15/aten/src/ATen/native/UpSampleNearest1d.cpp#L11
In other words, during shape analysis Torch-TensorRT runs aten::upsample_nearest1d, aten::upsample_nearest2d, and aten::upsample_nearest3d, but which of these branches is valid is data-dependent, so executing all three sub-blocks on the same input triggers a runtime error.
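To illustrate why executing every branch must fail: each upsample variant accepts exactly one input rank (nearest1d takes a 3-D (N, C, W) tensor, nearest2d a 4-D tensor, nearest3d a 5-D tensor), so for any concrete input at most one of the three branches can pass PyTorch's dimension check. A minimal sketch in plain Python (no Torch-TensorRT internals):

```python
# Expected input rank for each upsample variant, matching PyTorch's
# shape checks: nearest1d -> (N, C, W), nearest2d -> (N, C, H, W),
# nearest3d -> (N, C, D, H, W).
EXPECTED_RANK = {
    "aten::upsample_nearest1d": 3,
    "aten::upsample_nearest2d": 4,
    "aten::upsample_nearest3d": 5,
}

def runnable_branches(input_rank):
    """Return the upsample ops whose dimension check would pass."""
    return [op for op, rank in EXPECTED_RANK.items() if rank == input_rank]

# For a 4-D input (e.g. an NCHW feature map), only the nearest2d branch
# can run, so executing all three sub-blocks during shape analysis must
# raise for the other two.
print(runnable_branches(4))  # -> ['aten::upsample_nearest2d']
```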
We should take care of this in future prim::If support for Torch-TensorRT.
**Describe the solution you'd like**

A possible solution is to introduce a detection mechanism that identifies this kind of data-dependent case and falls back the entire prim::If node to Torch.
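One possible shape for the detection pass, sketched in Python against stand-in Node/Block classes (the real implementation would walk torch::jit::Node and Block objects in the partitioner; the op list and helper names here are hypothetical):

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for TorchScript IR structures.
@dataclass
class Node:
    kind: str
    blocks: list = field(default_factory=list)

@dataclass
class Block:
    nodes: list = field(default_factory=list)

# Ops whose success depends on the runtime rank of their input; running
# them unconditionally during shape analysis can raise. This set is an
# illustrative assumption, not an exhaustive list.
RANK_DEPENDENT_OPS = {
    "aten::upsample_nearest1d",
    "aten::upsample_nearest2d",
    "aten::upsample_nearest3d",
}

def _block_has_rank_dependent_op(block):
    """Recursively scan a block (including nested prim::If nodes)."""
    for node in block.nodes:
        if node.kind in RANK_DEPENDENT_OPS:
            return True
        if any(_block_has_rank_dependent_op(b) for b in node.blocks):
            return True
    return False

def should_fallback(if_node):
    """Fall back the whole prim::If when any sub-block contains an op
    whose executability is data-dependent."""
    assert if_node.kind == "prim::If"
    return any(_block_has_rank_dependent_op(b) for b in if_node.blocks)

# Mirror the nested prim::If structure from the graph above.
inner = Node("prim::If", [Block([Node("aten::upsample_nearest2d")]),
                          Block([Node("aten::upsample_nearest3d")])])
outer = Node("prim::If", [Block([Node("aten::upsample_nearest1d")]),
                          Block([Node("aten::eq"), inner])])
print(should_fallback(outer))  # -> True
```

With a check like this in the partitioning phase, the whole prim::If segment would be assigned to the Torch target instead of being shape-analyzed block by block.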