Gather Implementation
Fixes #2202
I had a doubt on this PR. Suppose we feed an int64 ITensor into the test case directly, as below: the TRTInterpreter would complain that int64 inputs are not allowed. Meanwhile, `torch.ops.aten.gather` requires int64 indices, so how do we deal with those cases? Does there need to be explicit casting? The present `test_gather_aten.py` is failing. E.g., a case like this:
```python
import torch
import torch.nn as nn

# Assumes the DispatchTestCase harness from the Dynamo converter tests
from harness import DispatchTestCase


class TestGatherConverter(DispatchTestCase):
    def test_gather_zero_two_dim(self):
        class TestModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, indices):
                out = torch.ops.aten.gather.default(x, 0, indices)
                return out

        index0 = torch.randint(0, 1, (1, 1), dtype=torch.int32)
        input = [torch.randn(2, 2), index0]
        self.run_test(
            TestModule(),
            input,
        )
```
Both Dynamo paths use a truncation mechanism for int64 inputs, which takes effect prior to the TRTInterpreter, to avoid that issue. Does int32 not work for the test cases?
@gs-olive int32 in the test case leads to this error: `RuntimeError: gather(): Expected dtype int64 for index`
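For reference, a minimal eager-mode repro of that dtype check (no TRT involved):

```python
import torch

x = torch.randn(2, 2)
idx32 = torch.zeros(1, 1, dtype=torch.int32)

# PyTorch's gather only accepts int64 indices, so this raises:
# RuntimeError: gather(): Expected dtype int64 for index
torch.ops.aten.gather.default(x, 0, idx32)
```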
OK, got it. Could the test then mimic the scheme used here:
https://github.com/pytorch/TensorRT/blob/16c031349c6a1af5a8408a817f2ef8542aa6f176/tests/py/dynamo/backend/test_backend_compiler.py#L193
This uses `torch.compile` directly, which allows the 64-bit repair code to run.
Hi @gs-olive, in the test case below:
```python
import torch
import torch.nn as nn
import torch_tensorrt

# Assumes the DispatchTestCase harness from the Dynamo converter tests
from harness import DispatchTestCase


class TestGatherConverter(DispatchTestCase):
    def test_gather_zero_two_dim(self):
        class TestModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, indices):
                out = torch.ops.aten.gather.default(x, 0, indices)
                return out

        index0 = torch.randint(0, 1, (1, 1), dtype=torch.int64)
        inputs = [torch.randn(2, 2), index0]

        fx_graph = torch.fx.symbolic_trace(TestModule())
        torch._dynamo.reset()
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
            truncate_long_and_double=True,
            debug=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            4,
            "TRT outputs don't match with the original model.",
        )
```
The two cases:
- If `truncate_long_and_double` is kept as `True`, the error is `gather(): Expected dtype int64 for index, but got torch.int32`. That is because the truncation converts the int64 indices to int32, so they are int32 by the time the input reaches `optimized_model_results = optimized_model(*inputs).detach().cpu()`.
- I don't think `truncate_long_and_double=False` would be a way out either, since TensorRT would not accept the int64 input.

Am I missing something in the workaround? Thanks.
2 Test Cases [ITensor + constant input types]
- Indices as input [Differing semantics between PyTorch and TensorRT --> PyTorch gets Int64, TRT gets Int32]
  - Known operator to assemble indices (i.e. `torch.arange`, or another constant-producing layer); see the sketch below
- Inlined constant indices test case [Int64 Input Type]
"""
graph(x, indices):
--> inserted cast int64 to int32
# If running in PyTorch cast back to int64
out = aten.gather(x, 0, indices)
return out
"""
For a later PR/issue:
- Avoid running PyTorch graphs with invalid casts
- Refactor `repair_trunca...` to consume the output of type inference
- Reorder truncation to come as late as possible in the compilation process
I removed the gather test here, since the above task was to expose the gather layer.
The TensorRT output for `torch.ops.aten.gather` and `ctx.net.add_gather` differs. Example:
```python
import torch


class gather(torch.nn.Module):
    def __init__(self, indices):
        super().__init__()
        self.indices = indices

    def forward(self, x):
        out = torch.ops.aten.gather.default(x, 0, self.indices)
        return out


index0 = torch.randint(0, 1, (1, 1), dtype=torch.int64).cuda()
inputs = torch.randn(2, 2).cuda()
gatherOut = gather(index0)
outTest = gatherOut(inputs)
```
With `index0 = [[0]]` (the `(1, 1)` index tensor above) and `dim = 0`, `torch.ops.aten.gather` returns `inputs[0, 0]`, whereas the gather layer would return `[inputs[0, 0], inputs[0, 1]]` (the zeroth row along the zeroth dimension).
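A sketch of the difference in eager PyTorch, assuming TRT's default gather mode selects whole slices along the axis the way `torch.index_select` does (stand-in code, not the converter itself):

```python
import torch

inputs = torch.randn(2, 2)
index0 = torch.zeros(1, 1, dtype=torch.int64)

# aten.gather: output has the shape of the index tensor;
# for dim=0, out[i][j] = inputs[index0[i][j]][j]
torch_out = torch.ops.aten.gather.default(inputs, 0, index0)

# TRT's default gather mode selects whole rows along the axis,
# like index_select: row 0 here, i.e. [inputs[0, 0], inputs[0, 1]]
trt_like_out = torch.index_select(inputs, 0, torch.tensor([0]))

print(torch_out.shape)     # torch.Size([1, 1]) -> inputs[0, 0]
print(trt_like_out.shape)  # torch.Size([1, 2]) -> the zeroth row
```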
The exposed gather layer would be tested by the `aten_index` and `aten_select` tests.
Implementation of the gather converter would be a separate task, based on https://pytorch.org/docs/stable/generated/torch.gather.html
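For reference, the semantics from those docs that the converter would need to reproduce, as a small worked example:

```python
import torch

x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
idx = torch.tensor([[0, 1],
                    [1, 0]])

# For dim=0: out[i][j] = x[idx[i][j]][j]
out = torch.gather(x, 0, idx)
print(out)
# tensor([[1., 4.],
#         [3., 2.]])
```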