# 🐛 [Bug] Unable to freeze tensor of type Int64/Float64 into constant layer
```
Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled
```

When I try to test the Transformer attention layer with TensorRT, I get the error above. I checked both the sample input tensor and the inputs to `torch_tensorrt.compile`, and there are no double tensors.
## To Reproduce
Steps to reproduce the behavior:
- Just try with the following test code:
```python
import torch
from torch import nn
import torch_tensorrt
from diffusers.models.attention import Attention


class AttnModule(nn.Module):
    def __init__(self):
        super().__init__()
        num_attention_heads = 16
        attention_head_dim = 8
        dim = num_attention_heads * attention_head_dim
        dropout = 0.0
        attention_bias = False
        upcast_attention = False
        attention_out_bias = True
        self.attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

    def forward(self, sample: torch.Tensor):
        return self.attn(sample)


model = AttnModule().to(device='cuda').eval()  # torch module needs to be in eval (not training) mode
model = model.half()
a = torch.randn((1, 128, 128), device='cuda').half()
traced_model = torch.jit.trace(model, a).half().cuda()
print('traced_model', traced_model.graph)

enabled_precisions = {torch.half}  # Run with fp16
with torch_tensorrt.logging.debug():
    trt_ts_module = torch_tensorrt.compile(
        # traced_model, debug=True, inputs=[torch_tensorrt.Input((1, 128, 128), dtype=torch.half, name="sample")], enabled_precisions=enabled_precisions, truncate_long_and_double=True
        traced_model, debug=True, ir="torchscript", inputs=[torch_tensorrt.Input([1, 128, 128], dtype=torch.half, name="sample")], enabled_precisions=enabled_precisions
    )

c = torch.randn((1, 128, 128), device='cuda').half()
# warm up
model(c)
traced_model(c)
trt_ts_module(c)

import time

start = time.time()
for i in range(100):
    with torch.no_grad():
        result = model(c)
torch.cuda.synchronize()
print('cost0:', time.time() - start)

start = time.time()
for i in range(100):
    with torch.no_grad():
        result = traced_model(c)
torch.cuda.synchronize()
print('cost1:', time.time() - start)

start = time.time()
for i in range(100):
    with torch.no_grad():
        result = trt_ts_module(c)
torch.cuda.synchronize()
print('cost2:', time.time() - start)
```
## Expected behavior

The code runs correctly.
## Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version: 2.2.0
- PyTorch Version: 2.2.1
- CPU Architecture: x86
- OS (e.g., Linux): CentOS
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): pip
- Build command you used (if compiling from source): installed from https://download.pytorch.org/whl/cu121
- Are you using local sources or building from archives: No
- Python version: 3.10.4
- CUDA version: 12.1
- GPU models and configuration: NVIDIA A10 24G
- Any other relevant information:
## Additional context
When looking at your reproducer, I noticed that you had `truncate_long_and_double` enabled earlier but have it commented out? When I run it through the torchscript frontend with that feature enabled on main, it seems to work fine. Also, if you are tracing to work around TorchScript limitations, you might want to use the dynamo frontend; if you still need TorchScript at the end, you can `torch.jit.trace` the output of Torch-TensorRT, and you will have access to all the latest features we have been adding.
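Roughly, that path would look like the following (a minimal sketch, assuming the same input shape and fp16 settings as the reproducer above):

```python
# Sketch: compile the eager module through the dynamo frontend
# (no torch.jit.trace needed up front).
trt_gm = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[torch_tensorrt.Input((1, 128, 128), dtype=torch.half, name="sample")],
    enabled_precisions={torch.half},
)

# If a TorchScript artifact is still needed downstream, trace the
# already-compiled module afterwards.
trt_ts = torch.jit.trace(trt_gm, a)
```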
@narendasan Hi, thanks so much for your reply. If I enable `truncate_long_and_double`, there is another dtype mismatch (float with half) error. What confuses me is that there is no double dtype in any of the tensor calculations. Also, it costs much more time than eager or `torch.jit.trace` mode when I use the dynamo frontend.
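One way to double-check for stray 64-bit tensors is a scan like this (a minimal sketch; it only inspects parameters and buffers, not tensors created inside `forward`):

```python
# Sketch: scan the module for float64/int64 parameters and buffers.
# Note this cannot see int64 tensors created inside forward() (e.g. indices),
# which can still trigger the truncate_long_and_double requirement.
for name, t in list(model.named_parameters()) + list(model.named_buffers()):
    if t.dtype in (torch.float64, torch.int64):
        print(name, t.dtype)
```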
There may be int64 types in your code (including things like indices) which require the use of that setting.
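For reference, enabling it in the reproducer would look like this (a sketch mirroring the commented-out call above):

```python
# Sketch: the same torchscript-frontend compile call with
# truncate_long_and_double enabled, so int64/float64 constants are
# downcast to int32/float32 during conversion.
trt_ts_module = torch_tensorrt.compile(
    traced_model,
    debug=True,
    ir="torchscript",
    inputs=[torch_tensorrt.Input([1, 128, 128], dtype=torch.half, name="sample")],
    enabled_precisions={torch.half},
    truncate_long_and_double=True,
)
```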