🐛 [Bug] Encountered bug for nn.LSTM with half dtype
Bug Description
When running an LSTM model compiled with Torch-TensorRT in half precision, I get this error:
RuntimeError: Input and parameter tensors are not the same dtype, found input tensor with Float and parameter tensor with Half
Here is the test code:
import torch
import torch.nn as nn
import torch_tensorrt

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(800, 800)

    def forward(self, input_x):
        x = self.relu(input_x)
        x = self.lstm(x)[0]
        return x

model = Model().eval().half().cuda()
x = torch.randn(50, 50, 800).half().cuda()
script_model = torch.jit.trace(model, x)
trt_ts_model = torch_tensorrt.compile(
    script_model,
    ir="torchscript",
    inputs=[x],
    enabled_precisions=[torch.half],
    truncate_long_and_double=True,
)
res = trt_ts_model(x)
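For what it's worth, the error text looks like it comes from the LSTM's own input/parameter dtype check, which suggests (my assumption, not confirmed in this report) that the TensorRT-compiled segment hands a Float tensor to the half-precision LSTM running in PyTorch fallback. A plain-PyTorch snippet that should raise a similar message:

import torch
import torch.nn as nn

# Half-precision LSTM parameters fed a float32 input: the dtype check inside
# the LSTM raises a "not the same dtype" RuntimeError like the one above.
lstm = nn.LSTM(800, 800).half().cuda()
x = torch.randn(50, 50, 800).cuda()  # float32, not half
out = lstm(x)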
Expected behavior
The compiled module should run the half-precision LSTM without a dtype mismatch error.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): v1.4.0
- PyTorch Version (e.g. 1.0): 2.0
- CPU Architecture:
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version: 11.8
- GPU models and configuration: A100
- Any other relevant information:
Additional context
This PR might be helpful for type issues: https://github.com/pytorch/TensorRT/pull/2469. However, I would recommend using the Dynamo path, since it is being actively supported right now. Thanks!
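For reference, a minimal sketch of the suggested Dynamo path, reusing the Model class from the repro above (my assumption: torch-tensorrt 2.x; the exact keyword arguments may differ between versions):

import torch
import torch_tensorrt

model = Model().eval().half().cuda()
x = torch.randn(50, 50, 800).half().cuda()

# Compile the eager module directly through the Dynamo frontend,
# instead of tracing to TorchScript first.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[x],
    enabled_precisions={torch.half},
)
res = trt_model(x)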
Thanks for your reply! It seems that this PR does not fix the issue and the bug is still present.