TensorRT

Gather Implementation

Open apbose opened this issue 2 years ago • 7 comments

Fixes #2202

apbose, Nov 13 '23 20:11

I have a question about this PR. If we feed an int64 ITensor directly into the test case, as below, TRTInterpreter complains that int64 inputs are not allowed. Meanwhile, torch.ops.aten.gather requires int64 indices. How should we handle such cases? Does an explicit cast need to be inserted there? The current test_gather_aten.py is failing on cases like this:

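import torch
import torch.nn as nn

# DispatchTestCase comes from the Torch-TensorRT conversion test harness.
from .harness import DispatchTestCase


class TestGatherConverter(DispatchTestCase):
    def test_gather_zero_two_dim(self):
        class TestModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, indices):
                out = torch.ops.aten.gather.default(x, 0, indices)
                return out

        # int32 indices are acceptable to TensorRT, but PyTorch's aten.gather
        # rejects them when the reference output is computed.
        index0 = torch.randint(0, 1, (1, 1), dtype=torch.int32)
        inputs = [torch.randn(2, 2), index0]
        self.run_test(
            TestModule(),
            inputs,
        )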

apbose, Dec 19 '23 22:12

Both Dynamo paths use a truncation mechanism for int64 inputs that takes effect before the TRTInterpreter, which avoids that issue. Does int32 not work for the test cases?
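To illustrate what such a truncation step does to the inputs, here is a minimal sketch, assuming it simply downcasts int64 to int32 and float64 to float32 with an int32 range check. The helper `truncate_long_and_double_inputs` is hypothetical and is not Torch-TensorRT's actual implementation; it only shows why an index tensor that went through truncation arrives at `aten.gather` as int32 rather than int64.

```python
import torch


def truncate_long_and_double_inputs(inputs):
    """Hypothetical helper: downcast 64-bit inputs before they reach TensorRT."""
    truncated = []
    for t in inputs:
        if t.dtype == torch.int64:
            info = torch.iinfo(torch.int32)
            # Refuse to silently wrap values that do not fit in int32.
            if t.numel() > 0 and (t.min() < info.min or t.max() > info.max):
                raise ValueError("int64 value out of int32 range")
            truncated.append(t.to(torch.int32))
        elif t.dtype == torch.float64:
            truncated.append(t.to(torch.float32))
        else:
            truncated.append(t)
    return truncated


x = torch.randn(2, 2, dtype=torch.float64)
indices = torch.zeros(1, 1, dtype=torch.int64)
x32, idx32 = truncate_long_and_double_inputs([x, indices])
print(x32.dtype, idx32.dtype)  # torch.float32 torch.int32
```

The resulting `idx32` is the kind of int32 index that then trips the `Expected dtype int64 for index` check when the same graph is run in PyTorch.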

gs-olive, Jan 05 '24 21:01

@gs-olive Using int32 in the test case leads to this error: RuntimeError: gather(): Expected dtype int64 for index

apbose, Jan 08 '24 20:01

OK, got it. Could the test then mimic the scheme used here? https://github.com/pytorch/TensorRT/blob/16c031349c6a1af5a8408a817f2ef8542aa6f176/tests/py/dynamo/backend/test_backend_compiler.py#L193 That test calls torch.compile directly, which lets the 64-bit repair code run.

gs-olive, Jan 08 '24 23:01

Hi @gs-olive, in the test case below:

import torch
import torch.nn as nn
import torch_tensorrt

# DispatchTestCase comes from the Torch-TensorRT conversion test harness.
from .harness import DispatchTestCase


class TestGatherConverter(DispatchTestCase):
    def test_gather_zero_two_dim(self):
        class TestModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, indices):
                out = torch.ops.aten.gather.default(x, 0, indices)
                return out

        index0 = torch.randint(0, 1, (1, 1), dtype=torch.int64)
        inputs = [torch.randn(2, 2), index0]
        fx_graph = torch.fx.symbolic_trace(TestModule())
        torch._dynamo.reset()

        # Compile through the torch.compile path so the 64-bit repair code runs.
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
            truncate_long_and_double=True,
            debug=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()
        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            4,
            "TRT outputs don't match with the original model.",
        )

There are two cases:

  1. If truncate_long_and_double is set to True, the error is gather(): Expected dtype int64 for index, but got torch.int32. That is because the truncation here converts the int64 indices to int32, and those int32 indices are what reach optimized_model_results = optimized_model(*inputs).detach().cpu().
  2. Setting truncate_long_and_double=False does not seem to be an option either, since TensorRT does not accept int64 inputs.

Am I missing something in the workaround? Thanks.

apbose, Jan 12 '24 15:01

Two test cases [ITensor + constant input types]:

  • Indices as input [differing semantics between PyTorch and TensorRT: PyTorch gets int64, TRT gets int32]
    • Known operator to assemble the indices (e.g. torch.arange, or another constant-producing layer)
  • Inlined constant indices [int64 input type]
"""
graph(x, indices):
    --> inserted cast: int64 to int32
    # If running in PyTorch, cast back to int64
    out = aten.gather(x, 0, indices)

    return out
"""

For a later PR/issue:

  • Avoid running PyTorch graphs with invalid casts
  • Refactor repair_trunca... to consume the output of type inference
  • Reorder truncation to happen as late as possible in the compilation process

gs-olive, Jan 12 '24 19:01

I removed the gather test here, since the task above was only to expose the gather layer. The outputs of torch.ops.aten.gather and of the TensorRT gather layer (ctx.net.add_gather) differ. For example:

import torch


class GatherModule(torch.nn.Module):
    def __init__(self, indices):
        super().__init__()  # initialize nn.Module before assigning attributes
        self.indices = indices

    def forward(self, x):
        out = torch.ops.aten.gather.default(x, 0, self.indices)
        return out


index0 = torch.randint(0, 1, (1, 1), dtype=torch.int64).cuda()  # always [[0]]
inputs = torch.randn(2, 2).cuda()
gather_module = GatherModule(index0)
out_test = gather_module(inputs)

With index0 = [[0]] and dim = 0, torch.ops.aten.gather returns [[inputs[0,0]]], whereas the TensorRT gather layer returns [inputs[0,0], inputs[0,1]] (the zeroth slice along the zeroth dimension). The exposed gather layer will be tested through the aten_index and aten_select tests. Implementing the gather converter will be a separate task, based on https://pytorch.org/docs/stable/generated/torch.gather.html
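The semantic difference can be reproduced without TensorRT. The sketch below contrasts `torch.gather` (element-wise: the output has the shape of the index) with a `numpy.take`/ONNX-style gather that selects whole slices along `dim`, which is assumed here to match the TensorRT gather layer's default mode (output shape `x.shape[:dim] + indices.shape + x.shape[dim+1:]`). `take_style_gather` is a hypothetical helper written only for this comparison.

```python
import torch


def take_style_gather(x, dim, indices):
    """Hypothetical helper: numpy.take-style gather that selects whole slices along `dim`."""
    flat = torch.index_select(x, dim, indices.reshape(-1).long())
    # Split the selected axis back into the index shape.
    return flat.reshape(x.shape[:dim] + indices.shape + x.shape[dim + 1:])


x = torch.arange(4.0).reshape(2, 2)            # [[0., 1.], [2., 3.]]
index0 = torch.zeros(1, 1, dtype=torch.int64)  # [[0]]

elementwise = torch.gather(x, 0, index0)       # [[0.]]: only x[0, 0]
slices = take_style_gather(x, 0, index0)       # [[[0., 1.]]]: the whole row x[0]
```

Matching `aten.gather` in a converter therefore needs element-wise semantics rather than slice gathering; TensorRT's gather layer also exposes an element-wise mode (`trt.GatherMode.ELEMENT` in the Python API), which may be a starting point for the separate converter task.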

apbose, Feb 20 '24 20:02