add runtime_platform_compatibility flag
## Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
## Type of change
Please delete options that are not relevant and/or add your own.
- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update
## Checklist:
- [ ] My code follows the style guidelines of this project (You can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that relevant reviewers are notified
When I tested it locally, I got the following error at https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py#L150:
```
torch_tensorrt [TensorRT Conversion Context]:logging.py:24 IRuntime::deserializeCudaEngine: Error Code 1: Serialization (Serialization assertion plan->header.pad == expectedPlatformTag failed.Platform specific tag mismatch detected. TensorRT plan files are only supported on the target runtime platform they were created on.)
```
To avoid this, make sure to add the cross_compile flag as one of the engine-invariant settings, so that the engine cache doesn't pull in cross-compiled plans by accident.
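As a rough illustration of why the flag needs to be part of the cache key (the function name, setting names, and hashing scheme below are hypothetical, not the actual Torch-TensorRT engine cache code):

```python
import hashlib
import pickle

def engine_cache_key(graph_bytes: bytes, settings: dict) -> str:
    # Hypothetical sketch: every setting that changes the produced engine has
    # to be part of the cache key. With the cross-compile flag included, a plan
    # serialized for another runtime platform can never be returned to a
    # compilation targeting the local platform.
    invariant_settings = {
        "precision": settings.get("precision"),
        "workspace_size": settings.get("workspace_size"),
        # New flag from this PR (name assumed here for illustration):
        "cross_compile": settings.get("cross_compile", False),
    }
    payload = pickle.dumps((graph_bytes, sorted(invariant_settings.items())))
    return hashlib.sha256(payload).hexdigest()
```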
- `torch_tensorrt.dynamo.cross_compile(model, inputs, *args, **kwargs)` -> save the exported program to disk by calling `torch_tensorrt.save()`
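A minimal sketch of how that workflow could look from the user side, assuming the `cross_compile` entry point proposed above (its final name, module path, and signature may differ):

```python
import torch
import torch_tensorrt

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

model = Net().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# Proposed entry point from this PR: compile for the target runtime platform
# instead of the local one (signature as sketched above, assumed here).
trt_ep = torch_tensorrt.dynamo.cross_compile(model, inputs)

# Serialize the exported program to disk; the embedded TensorRT engine should
# only be deserialized on the target platform, never on the build machine.
torch_tensorrt.save(trt_ep, "trt_model.ep")
```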
Options:
- `setup_engine` -> `execute_engine` (the issue here is that the type is not supported in Torch custom ops; for C++ ops it is working), so a POC with custom C++ ops should unblock this workflow.
- `TRTEngine` (guard setup with the `cross_compatibility` flag; see the sketch after this list)
- `execute_serialized_engine()` (downside: we can't afford to deserialize the engine on every inference)
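To make the trade-off concrete, here is a hypothetical Python sketch (not the actual `TRTEngine` / `_TorchTensorRTModule` implementation) of the second option: skip deserialization at construction time when the engine was cross-compiled, and deserialize lazily once on the target platform instead of on every inference:

```python
import tensorrt as trt

class LazyTRTEngine:
    """Hypothetical sketch, not the real TRTEngine: hold on to the serialized
    plan and defer deserialization when the engine was cross-compiled for a
    different runtime platform."""

    def __init__(self, serialized_engine: bytes, cross_compatibility: bool = False):
        self.serialized_engine = serialized_engine
        self.cross_compatibility = cross_compatibility
        self.engine = None
        if not self.cross_compatibility:
            # Current behaviour: deserialize eagerly during setup.
            self._setup()

    def _setup(self) -> None:
        runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
        # This is the call that raises the platform-tag mismatch error above
        # when the plan was built for a different runtime platform.
        self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)

    def execute(self, *inputs):
        # Deserialize once, lazily, on the target platform (avoids paying the
        # deserialization cost on every inference, the downside of option 3).
        if self.engine is None:
            self._setup()
        # ... enqueue inference with self.engine (elided in this sketch) ...
```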