skip dummy inference and run_shape_analysis
Description
There are two changes introduced in this PR:
- During the compile stage: skip the dummy inference and use graph inspection instead to get `output_node.meta['val']`.
- During the save stage: skip `run_shape_analysis` and use graph inspection instead to get `output_node.meta['val']`.
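A minimal sketch of the graph-inspection idea, assuming a `torch.fx.GraphModule` whose output node already carries FakeTensor metadata in `meta['val']` (as populated during Dynamo/torch.export tracing); the helper name is hypothetical:

```python
from torch.fx import GraphModule


def get_output_fake_tensors(gm: GraphModule) -> list:
    # Locate the graph's single output node.
    output_node = [node for node in gm.graph.nodes if node.op == "output"][0]
    # meta["val"] is assumed to hold the FakeTensor(s) recorded at trace
    # time; it may be a single tensor or a tuple/list of tensors.
    metadata = output_node.meta["val"]
    if isinstance(metadata, (tuple, list)):
        return list(metadata)
    return [metadata]
```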
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update
Checklist:
- [ ] My code follows the style guidelines of this project (You can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that relevant reviewers are notified
Here's what I think could be a simpler way of doing this:
- We probably don't have to store output_shapes in the
TorchTensorRTModule class. Once the compilation is finished, verify that the nodes of the TRT graph modules have metadata in them (if not, we can update it via node.meta["val"] = original metadata). Reference: https://github.com/pytorch/TensorRT/blob/3eb48d786d403b12bd3700004c60e08c5c002f7b/py/torch_tensorrt/dynamo/_compiler.py#L496-L499
Here the node corresponding to _run_on_acc0 can be queried as:

```python
trt_module_node = [node for node in gm.graph.nodes if node.name == "_run_on_acc0"][0]
```

`trt_module_node.meta["val"]` should already hold the fake tensors that need to be used in the exporter.
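A sketch of that verify-and-update step, assuming we kept a name-to-metadata mapping from before partitioning (`original_meta` below is hypothetical):

```python
from torch.fx import GraphModule


def restore_trt_node_metadata(gm: GraphModule, original_meta: dict) -> None:
    # After compilation, ensure each TRT submodule call node still carries
    # FakeTensor metadata; if it was lost during partitioning, restore it
    # from the (hypothetical) mapping captured beforehand.
    for node in gm.graph.nodes:
        if node.op == "call_module" and "val" not in node.meta:
            node.meta["val"] = original_meta[node.name]
```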
- exporter: We have the TRT module node here: https://github.com/pytorch/TensorRT/blob/3eb48d786d403b12bd3700004c60e08c5c002f7b/py/torch_tensorrt/dynamo/_exporter.py#L364. We could directly set (ensuring trt_module_node.meta["val"] always exists):

```python
trt_node.meta["val"] = trt_module_node.meta["val"]
```
- infer_module_types: We can replace the dummy inference with graph inspection by reading the output metadata. The output of this function could be a list of FakeTensors, and we can extract the dtypes from it to pass to the TRTInterpreter (see the sketch below).
Replacing the dummy inference will also need changes to our converter test suite.
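A minimal sketch of what that replacement could look like, assuming the output node's `meta['val']` holds FakeTensors; the function name is illustrative, not the actual Torch-TensorRT API:

```python
from torch.fx import GraphModule


def infer_output_dtypes_from_graph(gm: GraphModule) -> list:
    # Read output FakeTensors from graph metadata instead of executing a
    # dummy forward pass, then extract the dtypes for the TRTInterpreter.
    output_node = [node for node in gm.graph.nodes if node.op == "output"][0]
    outputs = output_node.meta["val"]
    if not isinstance(outputs, (tuple, list)):
        outputs = (outputs,)
    return [out.dtype for out in outputs]
```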