🐛 [Bug] Can't load UNet on H100 after compiling ExportedProgram with torch_tensorrt.dynamo.compile and saving
Bug Description
I am trying to use torch_tensorrt.dynamo.compile() to AOT compile the UNet portion of a StableDiffusionPipeline from the diffusers library (version 0.30.2). I am able to export the UNet with torch.export.export(), compile it with torch_tensorrt.dynamo.compile() and save it with torch_tensorrt.save(). However, I am encountering a runtime error when attempting to load the saved compiled UNet with torch.export.load().
To Reproduce
Run the code below
import functools
import torch
import torch_tensorrt
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
def generate_sd_unet_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand([], device="cuda", dtype=torch.float32) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    return sample, timestep, encoder_hidden_states

with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    unet_model = pipe.unet.eval()
    unet_model.forward = functools.partial(unet_model.forward, return_dict=False)

    arg_inputs_unet = generate_sd_unet_inputs()
    expected_outputs_unet = unet_model(*arg_inputs_unet)

    unet_exported_program = torch.export.export(unet_model, arg_inputs_unet)

    with torch_tensorrt.logging.errors():
        compiled_unet = torch_tensorrt.dynamo.compile(
            unet_exported_program,
            inputs=arg_inputs_unet,
            enabled_precisions={torch.float16, torch.float32},
            truncate_double=True,
        )

    torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
    loaded_unet = torch.export.load("sd_unet_compiled.ep").module()
Error message
...
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:370: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
engine_node = gm.graph.get_attr(engine_name)
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_2_engine target _run_on_acc_2_engine _run_on_acc_2_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_4_engine target _run_on_acc_4_engine _run_on_acc_4_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_6_engine target _run_on_acc_6_engine _run_on_acc_6_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_8_engine target _run_on_acc_8_engine _run_on_acc_8_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1593: UserWarning: Additional 16 warnings suppressed about get_attr references
warnings.warn(
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 48
40 compiled_unet = torch_tensorrt.dynamo.compile(
41 unet_exported_program,
42 inputs=arg_inputs_unet,
43 enabled_precisions={torch.float16, torch.float32},
44 truncate_double=True,
45 )
47 torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
---> 48 loaded_unet = torch.export.load("sd_unet_compiled.ep")
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/__init__.py:476, in load(f, extra_files, expected_opset_version)
468 artifact: SerializedArtifact = SerializedArtifact(
469 serialized_exported_program,
470 serialized_state_dict,
471 serialized_constants,
472 serialized_example_inputs,
473 )
475 # Deserialize ExportedProgram
--> 476 ep = deserialize(artifact, expected_opset_version)
478 return ep
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2437, in deserialize(artifact, expected_opset_version)
2433 exported_program_dict = json.loads(exported_program_str)
2434 serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict)
2435 return (
2436 ExportedProgramDeserializer(expected_opset_version)
-> 2437 .deserialize(
2438 serialized_exported_program,
2439 artifact.state_dict,
2440 artifact.constants,
2441 artifact.example_inputs,
2442 )
2443 )
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2329, in ExportedProgramDeserializer.deserialize(self, exported_program, state_dict, constants, example_inputs)
2314 res = (
2315 GraphModuleDeserializer()
2316 .deserialize(
(...)
2322 )
2323 )
2324 range_constraints = self.deserialize_range_constraints(
2325 symbol_name_to_range,
2326 res.names_to_symbols,
2327 )
-> 2329 return ep.ExportedProgram(
2330 root=res.graph_module,
2331 graph=res.graph_module.graph,
2332 graph_signature=res.signature,
2333 state_dict=res.state_dict, # type: ignore[arg-type]
2334 range_constraints=range_constraints,
2335 module_call_graph=res.module_call_graph,
2336 example_inputs=res.example_inputs,
2337 constants=res.constants,
2338 verifiers=[load_verifier(v) for v in exported_program.verifiers],
2339 )
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:700, in ExportedProgram.__init__(self, root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, constants, verifiers)
698 self._verifiers = verifiers
699 # Validate should be always the last step of the constructor.
--> 700 self.validate()
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1117, in ExportedProgram.validate(self)
1115 @compatibility(is_backward_compatible=False)
1116 def validate(self):
-> 1117 self._validate()
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1126, in ExportedProgram._validate(self)
1122 assert (
1123 len(self.verifiers) > 0
1124 ), "ExportedProgram must have at least one verifier."
1125 for v in self.verifiers:
-> 1126 v().check(self)
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:155, in Verifier.check(self, ep)
153 @final
154 def check(self, ep: "ExportedProgram") -> None:
--> 155 self._check_graph_module(ep.graph_module)
156 _verify_exported_program_module_call_graph(ep)
157 _verify_exported_program_signature(ep)
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:214, in Verifier._check_graph_module(self, gm)
211 if not isinstance(mod, torch.fx.GraphModule):
212 continue
--> 214 mod.graph.lint()
215 for node in mod.graph.nodes:
216 # TODO(T140410192): should have fake tensor for all dialects
217 if node.op in {"call_module", "call_method"}:
File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1549, in Graph.lint(self)
1546 seen_values.add(node)
1548 if node.name in seen_names:
-> 1549 raise RuntimeError(f'Node redefined name {node.name}!')
1550 seen_names.add(node.name)
1552 # Check targets are legit
RuntimeError: Node redefined name getitem_130!
Expected behavior
The code should load the saved compiled model without erroring out.
Environment
- Torch-TensorRT Version (e.g. 1.0.0): 2.5.0.dev20240912+cu124
- PyTorch Version (e.g. 1.0): 2.5.0.dev20240912+cu124
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu 22.04.4 LTS (x86_64)
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.11.10
- CUDA version: 12.4
- GPU models and configuration: 1/2 of an H100 (Configured with MIG)
- Any other relevant information: Using diffusers version 0.30.2
Additional context
I have to use functools.partial() in the code above because the default output of the pipeline's forward method is the UNet2DConditionOutput dataclass. I tried to get rid of functools.partial() by instead using torch.export.register_dataclass() but was met with the same runtime error mentioned above.
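For reference, the register_dataclass variant I tried looked roughly like the sketch below (the UNet2DConditionOutput import path is assumed for diffusers 0.30.x); it hit the same RuntimeError on load.

import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

# Register the output dataclass as a pytree node so torch.export can handle the
# UNet's default (return_dict=True) output without the functools.partial patch.
torch.export.register_dataclass(
    UNet2DConditionOutput,
    serialized_type_name="UNet2DConditionOutput",
)

# Export the unmodified forward; compilation, saving, and loading then proceed
# exactly as in the reproduction code above.
unet_exported_program = torch.export.export(unet_model, arg_inputs_unet)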
Additionally, saving and loading the ExportedProgram (without Torch-TensorRT compilation) works fine.
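The plain torch.export round trip I used for that sanity check was roughly the following (the file name is only illustrative):

torch.export.save(unet_exported_program, "sd_unet_uncompiled.ep")
reloaded_unet = torch.export.load("sd_unet_uncompiled.ep").module()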
@readleyj I tried in my environment with today's latest main branch on an RTX 4080, and I don't get the error you pasted. I can successfully load the UNet.
@lanluo-nvidia Thank you for the reply. That is very strange. I will try with today's nightly and report back. Also, I am running this on an H100; could that possibly be the source of the issue?
I tried again with today's nightly (torch_tensorrt==2.5.0.dev20240918+cu124, torch==dev20240912+cu124) and I am encountering the same runtime error. Additionally, the results for the compiled UNet match the original UNet. At this point, I am not sure if the issue is with Torch-TensorRT or torch.export.
I also tried with release 2.4. There, I can successfully save and load the model but the compiled model outputs are full of nans. In general, Stable Diffusion with Torch-TensorRT seems very problematic.
@readleyj Yes, there are bugs in release 2.4 that have been fixed in the current main branch. Could you paste the code showing how you generate the image after loading the UNet? I will give it a try as well.
@lanluo-nvidia After loading the UNet, I first check if the results match (expected_outputs_unet is defined in the previous code block)
with torch.inference_mode():
    tensorrt_outputs_unet = loaded_unet(*arg_inputs_unet)

    for expected_output, tensorrt_output in zip(expected_outputs_unet, tensorrt_outputs_unet):
        assert torch.allclose(
            expected_output, tensorrt_output, 1e-2, 1e-2
        ), "UNet results do not match"

    print("UNet results match for Torch-TensorRT and Diffusers")
To generate an image, I plug the loaded UNet into a StableDiffusionPipeline as follows (this code block assumes loaded_unet is already defined):
import torch
import torch_tensorrt
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
PROMPT = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    class LoadedUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.in_channels = pipe.unet.config.in_channels
            setattr(self, "config", pipe.unet.config)
            self.device = pipe.unet.device

        def forward(self, latent_model_input, t, encoder_hidden_states, **kwargs):
            sample = loaded_unet(latent_model_input, t, encoder_hidden_states)
            return sample

    pipe.unet = LoadedUNet()

    image = pipe(
        PROMPT,
        num_inference_steps=50,
        height=512,
        width=512,
    ).images[0]
Note that you may receive a "Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and / or seed." warning from diffusers. This happens when the generated image is all black or gibberish.
@readleyj I have tried with the release/2.5 branch (this is our upcoming release branch and it is more stable than main, since main is getting all the latest changes from both PyTorch and Torch-TensorRT).
Test 1) Tested locally on my RTX 4080 with the release/2.5 branch:
python -m pip install --pre --editable . --extra-index-url https://download.pytorch.org/whl/test/cu124
python -m pip install --pre "torchvision>=0.20.0,<0.21.0" --index-url https://download.pytorch.org/whl/test/cu124
python /home/lanl/git/script/python/stable_diffusion/test_issue3163.py
I can see it does throw the "UNet results do not match" error (the differences are actually very close to rtol=1e-2, atol=1e-2), but the outputs are not full of NaNs. Also, if I change rtol, atol to 1e-2, 5e-2, it is able to generate the image as expected (the adjusted check is shown after the log below):
lan added expected_outputs_unet=(tensor([[[[-3.0304e-02, 4.7119e-01, 1.3281e-01, ..., 3.9185e-01,
-1.1200e-01, 1.3062e-01],
[-2.1191e-01, 2.7026e-01, -2.9688e-01, ..., -3.6548e-01,
9.2712e-02, -6.3867e-01],
[-5.2441e-01, 1.8555e-02, 9.9414e-01, ..., -1.0977e+00,
3.9429e-02, -6.8164e-01],
...,
[-6.3574e-01, -2.2583e-01, 3.3234e-02, ..., 2.8491e-01,
-4.9316e-01, 9.1455e-01],
[ 5.8594e-01, -7.3779e-01, 6.9695e-03, ..., 7.1094e-01,
-3.3569e-01, 1.2830e-01],
[-2.4805e-01, -1.2152e-01, 8.3643e-01, ..., -2.8641e-02,
1.9739e-01, 1.3367e-01]],
[[ 2.7319e-01, 4.0063e-01, -4.8682e-01, ..., -6.2939e-01,
2.1790e-02, 4.9634e-01],
[-9.3323e-02, -7.7393e-01, 2.1399e-01, ..., 7.6953e-01,
4.5410e-01, -3.1909e-01],
[ 1.8079e-01, 4.2017e-01, 1.1699e+00, ..., 1.3843e-01,
-3.2898e-02, -1.3953e-01],
...,
[-1.4839e-02, 1.2131e-02, 1.7859e-01, ..., 5.9717e-01,
8.0762e-01, -7.5684e-01],
[ 6.3232e-01, -6.1035e-01, 1.9214e-01, ..., -3.3496e-01,
4.9048e-01, 7.0166e-01],
[ 2.3340e-01, -6.1279e-01, 7.8271e-01, ..., -1.9067e-01,
-6.3965e-01, -1.7529e-01]],
[[-8.2324e-01, -4.7180e-02, -8.0383e-02, ..., -6.2109e-01,
4.5319e-02, 1.5930e-01],
[ 1.0908e+00, -7.1143e-01, 9.6484e-01, ..., 4.6777e-01,
-2.4548e-01, -5.6445e-01],
[ 2.2278e-01, 1.2256e+00, 3.4302e-01, ..., -3.1372e-01,
3.3203e-01, 1.1426e-01],
...,
[ 1.7578e-01, -2.4002e-02, 3.9581e-02, ..., 1.4160e-01,
2.4902e-01, -2.7515e-01],
[ 6.4893e-01, -1.7891e+00, 3.4570e-01, ..., 3.9868e-01,
5.0977e-01, 5.0146e-01],
[ 2.1948e-01, -2.0020e-01, 3.3862e-01, ..., -2.5488e-01,
7.9346e-02, -3.8794e-01]],
[[-3.6206e-01, -5.7080e-01, -7.8369e-02, ..., 4.7388e-01,
4.5093e-01, -2.6636e-01],
[ 4.5630e-01, 4.8340e-01, 5.4053e-01, ..., -2.9175e-01,
2.3331e-02, -5.2979e-01],
[ 4.5728e-01, -3.1177e-01, -1.5879e+00, ..., -1.6748e-01,
1.8408e-01, -3.1592e-01],
...,
[-2.6074e-01, 1.6028e-01, -5.9766e-01, ..., 2.4963e-01,
2.9688e-01, -1.1699e+00],
[-2.1367e+00, 5.9619e-01, 6.1133e-01, ..., -3.5962e-01,
-4.8193e-01, 1.5167e-02],
[ 9.3018e-01, 5.7471e-01, -4.0332e-01, ..., 5.8691e-01,
-1.6826e+00, -4.2450e-02]]],
[[[-1.4170e+00, 3.9453e-01, -4.8438e-01, ..., 2.2180e-01,
4.1724e-01, 9.6252e-02],
[ 1.5015e-01, 4.6851e-01, 3.3643e-01, ..., 5.3467e-02,
-1.9666e-01, -9.2773e-02],
[ 1.0840e+00, 5.0244e-01, 8.7695e-01, ..., 3.6957e-02,
-1.0840e+00, -7.1436e-01],
...,
[-7.7100e-01, 2.4207e-01, -3.6084e-01, ..., -6.8298e-02,
-2.1643e-01, 1.4391e-03],
[-3.7964e-01, -2.0032e-01, 4.6173e-02, ..., -2.1252e-01,
2.0972e-01, -6.0608e-02],
[ 3.5840e-01, -1.3125e+00, 4.1528e-01, ..., -6.7871e-01,
9.4434e-01, -3.8055e-02]],
[[-1.2225e-01, 1.2488e-01, -3.2935e-01, ..., -3.2690e-01,
-4.9219e-01, 4.6460e-01],
[ 1.8616e-01, 1.6821e-01, -7.6675e-03, ..., -3.6224e-02,
4.6509e-01, -4.9976e-01],
[ 5.1758e-01, 5.4883e-01, -7.9004e-01, ..., 3.2275e-01,
-1.1780e-01, -9.6191e-01],
...,
[-6.4331e-02, -4.7754e-01, -8.2031e-01, ..., 1.2024e-01,
-1.6125e-01, -1.5442e-01],
[-3.6938e-01, -2.1045e-01, -5.3857e-01, ..., 1.2512e-01,
-1.1646e-01, 1.6172e+00],
[ 7.7515e-02, 4.2578e-01, 3.3789e-01, ..., 3.6377e-01,
-9.4189e-01, -8.0176e-01]],
[[-5.4736e-01, 1.9482e-01, -1.4111e+00, ..., -1.4087e-01,
7.7576e-02, -6.6833e-02],
[ 3.4082e-01, -1.1267e-01, 3.2129e-01, ..., 5.9473e-01,
-9.3896e-01, -3.3350e-01],
[ 8.4277e-01, 1.0020e+00, -8.1055e-01, ..., -2.3669e-01,
-5.0049e-01, -4.0503e-01],
...,
[-1.1273e-01, 4.9194e-02, -2.6172e-01, ..., 2.6880e-01,
3.7744e-01, -2.0447e-02],
[ 1.2832e+00, -9.6985e-02, -2.9150e-01, ..., 1.1292e-01,
-1.9116e-01, -2.1643e-01],
[ 1.8347e-01, -4.4531e-01, -3.4180e-01, ..., 1.2793e-01,
3.6011e-01, 9.5215e-01]],
[[ 1.1777e+00, -3.1174e-02, 2.1133e+00, ..., -1.7981e-01,
1.1401e-01, 2.7466e-01],
[-4.1113e-01, 1.4771e-01, -4.5264e-01, ..., -5.7080e-01,
-6.2354e-01, -2.0126e-02],
[-1.0283e+00, -7.2070e-01, 1.6321e-01, ..., -1.0547e-01,
1.8105e+00, -6.9824e-01],
...,
[-6.0352e-01, 3.2440e-02, -2.5537e-01, ..., 2.0691e-01,
-5.2277e-02, -4.4482e-01],
[-8.5840e-01, -4.8291e-01, -2.7051e-01, ..., -1.1688e-01,
-4.1113e-01, 5.1562e-01],
[ 3.0469e-01, 5.5273e-01, 1.2769e-01, ..., -3.8086e-02,
-2.3511e-01, -1.5625e-01]]]], device='cuda:0', dtype=torch.float16),)
lan added successfully saved compiled model
lan added successfully loaded compiled model
lan added tensorrt_outputs_unet=(tensor([[[[-0.0281, 0.4717, 0.1340, ..., 0.3911, -0.1110, 0.1306],
[-0.2119, 0.2698, -0.2964, ..., -0.3638, 0.0912, -0.6377],
[-0.5249, 0.0165, 0.9927, ..., -1.1055, 0.0362, -0.6802],
...,
[-0.6377, -0.2261, 0.0383, ..., 0.2847, -0.4937, 0.9146],
[ 0.5815, -0.7383, 0.0041, ..., 0.7085, -0.3379, 0.1245],
[-0.2472, -0.1250, 0.8354, ..., -0.0276, 0.1979, 0.1342]],
[[ 0.2739, 0.3992, -0.4868, ..., -0.6265, 0.0222, 0.4968],
[-0.0953, -0.7720, 0.2144, ..., 0.7710, 0.4548, -0.3210],
[ 0.1858, 0.4216, 1.1719, ..., 0.1354, -0.0334, -0.1392],
...,
[-0.0163, 0.0100, 0.1794, ..., 0.5981, 0.8042, -0.7524],
[ 0.6294, -0.6099, 0.1896, ..., -0.3352, 0.4866, 0.7021],
[ 0.2328, -0.6152, 0.7822, ..., -0.1882, -0.6387, -0.1779]],
[[-0.8218, -0.0516, -0.0839, ..., -0.6216, 0.0450, 0.1566],
[ 1.0908, -0.7075, 0.9653, ..., 0.4673, -0.2465, -0.5654],
[ 0.2251, 1.2188, 0.3413, ..., -0.3125, 0.3306, 0.1157],
...,
[ 0.1783, -0.0231, 0.0443, ..., 0.1445, 0.2466, -0.2778],
[ 0.6450, -1.7891, 0.3435, ..., 0.3984, 0.5098, 0.5015],
[ 0.2174, -0.2021, 0.3389, ..., -0.2559, 0.0767, -0.3879]],
[[-0.3618, -0.5703, -0.0786, ..., 0.4707, 0.4492, -0.2673],
[ 0.4551, 0.4832, 0.5396, ..., -0.2891, 0.0298, -0.5327],
[ 0.4585, -0.3105, -1.5898, ..., -0.1676, 0.1854, -0.3171],
...,
[-0.2598, 0.1622, -0.6060, ..., 0.2484, 0.2986, -1.1680],
[-2.1426, 0.5972, 0.6147, ..., -0.3577, -0.4790, 0.0156],
[ 0.9312, 0.5732, -0.4019, ..., 0.5850, -1.6826, -0.0422]]],
[[[-1.4150, 0.3909, -0.4822, ..., 0.2200, 0.4146, 0.0955],
[ 0.1503, 0.4656, 0.3364, ..., 0.0545, -0.1993, -0.0927],
[ 1.0791, 0.5010, 0.8813, ..., 0.0355, -1.0830, -0.7158],
...,
[-0.7690, 0.2378, -0.3633, ..., -0.0714, -0.2169, 0.0047],
[-0.3752, -0.2000, 0.0457, ..., -0.2123, 0.2108, -0.0576],
[ 0.3579, -1.3125, 0.4180, ..., -0.6763, 0.9458, -0.0367]],
[[-0.1224, 0.1234, -0.3323, ..., -0.3257, -0.4893, 0.4646],
[ 0.1893, 0.1653, -0.0038, ..., -0.0334, 0.4651, -0.5015],
[ 0.5176, 0.5493, -0.7900, ..., 0.3220, -0.1155, -0.9575],
...,
[-0.0630, -0.4766, -0.8208, ..., 0.1180, -0.1615, -0.1575],
[-0.3704, -0.2101, -0.5396, ..., 0.1234, -0.1164, 1.6182],
[ 0.0759, 0.4243, 0.3369, ..., 0.3630, -0.9458, -0.8022]],
[[-0.5449, 0.1912, -1.4131, ..., -0.1411, 0.0784, -0.0674],
[ 0.3401, -0.1118, 0.3210, ..., 0.5952, -0.9404, -0.3323],
[ 0.8384, 1.0010, -0.8042, ..., -0.2351, -0.5015, -0.4026],
...,
[-0.1104, 0.0457, -0.2615, ..., 0.2690, 0.3806, -0.0194],
[ 1.2861, -0.0950, -0.2893, ..., 0.1160, -0.1880, -0.2148],
[ 0.1829, -0.4490, -0.3406, ..., 0.1271, 0.3596, 0.9507]],
[[ 1.1777, -0.0316, 2.1074, ..., -0.1764, 0.1111, 0.2766],
[-0.4094, 0.1512, -0.4502, ..., -0.5728, -0.6245, -0.0194],
[-1.0283, -0.7241, 0.1641, ..., -0.1035, 1.8174, -0.6987],
...,
[-0.6021, 0.0345, -0.2515, ..., 0.2064, -0.0580, -0.4395],
[-0.8604, -0.4844, -0.2668, ..., -0.1143, -0.4133, 0.5127],
[ 0.3044, 0.5537, 0.1259, ..., -0.0396, -0.2357, -0.1576]]]],
device='cuda:0', dtype=torch.float16),)
Traceback (most recent call last):
File "/home/lanl/git/script/python/stable_diffusion/test_issue3163.py", line 47, in <module>
assert torch.allclose(
AssertionError: UNet results do not match
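For reference, the adjusted check mentioned above is the original assertion with the tolerances relaxed to rtol=1e-2, atol=5e-2, otherwise unchanged:

with torch.inference_mode():
    tensorrt_outputs_unet = loaded_unet(*arg_inputs_unet)

    for expected_output, tensorrt_output in zip(expected_outputs_unet, tensorrt_outputs_unet):
        # Relaxed tolerances compared to the original check (which used 1e-2, 1e-2).
        assert torch.allclose(
            expected_output, tensorrt_output, rtol=1e-2, atol=5e-2
        ), "UNet results do not match"

    print("UNet results match for Torch-TensorRT and Diffusers")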
Test 2) Tested locally on my RTX 4080 with the release/2.5 branch: I found that if I do not save and load the model, and instead use the Torch-TensorRT compiled model directly for inference, the UNet results do match and it can also generate the images as expected, same as Test 1).
Test 3) Tested on an H100 using the release/2.5 docker image:
docker run --gpus all --ipc=host --rm -it ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5 bash
In the docker container:
apt-get install -y vim
python -m pip install diffusers transformers accelerate
python test_issue3163.py
It throws the following error:
lan added successfully saved compiled model
Traceback (most recent call last):
File "/opt/torch_tensorrt/test_issue3163.py", line 42, in <module>
loaded_unet = torch.export.load("sd_unet_compiled.ep").module()
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/__init__.py", line 473, in load
ep = deserialize(artifact, expected_opset_version)
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/serde/serialize.py", line 2437, in deserialize
.deserialize(
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/serde/serialize.py", line 2329, in deserialize
return ep.ExportedProgram(
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/exported_program.py", line 700, in __init__
self.validate()
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/exported_program.py", line 1117, in validate
self._validate()
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/exported_program.py", line 1126, in _validate
v().check(self)
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/verifier.py", line 155, in check
self._check_graph_module(ep.graph_module)
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/verifier.py", line 214, in _check_graph_module
mod.graph.lint()
File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/fx/graph.py", line 1549, in lint
raise RuntimeError(f'Node redefined name {node.name}!')
RuntimeError: Node redefined name getitem_130!
root@s4124-0059:/opt/torch_tensorrt# nvidia-smi
Sat Sep 21 18:38:15 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14 Driver Version: 550.54.14 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:45:00.0 Off | 0 |
| N/A 26C P0 61W / 700W | 0MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
Yes, the error in Test 3) is exactly what I'm getting on my H100. I thought the problem might be with torch.export, so I already created an issue on the PyTorch repo (pytorch/pytorch#136317).
@readleyj It seems like it only happens on H100. I did the exact same test on an RTX 4080 using the same image and the same test code you provided, and it works.
Test 4) Tested on an RTX 4080 using the release/2.5 docker image:
docker run --gpus all --ipc=host --rm -it ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5 bash
In the docker container:
apt-get install -y vim
python -m pip install diffusers transformers accelerate
python test_issue3163.py
The UNet result matches (rtol, atol: 1e-2, 1e-2) and it is able to generate the image as expected.
@lanluo-nvidia Also, on my H100 tests, the model successfully compiles, the UNet results match (using compiled_unet directly) and I can generate an image (if I use compiled_unet in place of loaded_unet). But it's saving and loading the compiled model that breaks. To me this seems like a torch.export issue but I'm not sure.
@lanluo-nvidia Any updates on this? Should I expect this issue to be resolved soon or will this be on the backlog for a while? Unfortunately, I only have H100s at my disposal and this is blocking progress for me.
@readleyj Have you tried saving with torchscript format instead of exported_program?
Simply change the two lines from
torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
loaded_unet = torch.export.load("sd_unet_compiled.ep").module()
to
torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ts", output_format="torchscript", inputs=arg_inputs_unet)
loaded_unet = torch.jit.load("sd_unet_compiled.ts").eval()
@HolyWu Thanks for the suggestion. I hadn't looked at this for a while, but I just tried my original code (torch_tensorrt.save, torch.export.load) on torch 2.5.1, torch_tensorrt 2.5.0, and everything seems to be working. I can compile, save, and load successfully. The problem seems resolved, so I'll close the issue. Thanks a lot.