🐛 [Bug] Can't load UNet on H100 after compiling ExportedProgram with torch_tensorrt.dynamo.compile and saving

Open readleyj opened this issue 1 year ago • 13 comments

Bug Description

I am trying to use torch_tensorrt.dynamo.compile() to AOT compile the UNet portion of a StableDiffusionPipeline from the diffusers library (version 0.30.2). I am able to export the UNet with torch.export.export(), compile it with torch_tensorrt.dynamo.compile() and save it with torch_tensorrt.save(). However, I am encountering a runtime error when attempting to load the saved compiled UNet with torch.export.load().

To Reproduce

Run the code below

import functools

import torch
import torch_tensorrt

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

def generate_sd_unet_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand([], device="cuda", dtype=torch.float32) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    
    return sample, timestep, encoder_hidden_states

with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    unet_model = pipe.unet.eval()
    unet_model.forward = functools.partial(unet_model.forward, return_dict=False)
    
    arg_inputs_unet = generate_sd_unet_inputs()
    expected_outputs_unet = unet_model(*arg_inputs_unet)
    
    unet_exported_program = torch.export.export(unet_model, arg_inputs_unet)
        
    with torch_tensorrt.logging.errors():
        compiled_unet = torch_tensorrt.dynamo.compile(
            unet_exported_program,
            inputs=arg_inputs_unet,
            enabled_precisions={torch.float16, torch.float32},
            truncate_double=True,
        )
    
    torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
    loaded_unet = torch.export.load("sd_unet_compiled.ep").module()

Error message

...
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:370: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  engine_node = gm.graph.get_attr(engine_name)

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_2_engine target _run_on_acc_2_engine _run_on_acc_2_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_4_engine target _run_on_acc_4_engine _run_on_acc_4_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_6_engine target _run_on_acc_6_engine _run_on_acc_6_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_8_engine target _run_on_acc_8_engine _run_on_acc_8_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

WARNING:py.warnings:/home/ismayilismayilov/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1593: UserWarning: Additional 16 warnings suppressed about get_attr references
  warnings.warn(

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 48
     40     compiled_unet = torch_tensorrt.dynamo.compile(
     41         unet_exported_program,
     42         inputs=arg_inputs_unet,
     43         enabled_precisions={torch.float16, torch.float32},
     44         truncate_double=True,
     45     )
     47 torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
---> 48 loaded_unet = torch.export.load("sd_unet_compiled.ep")

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/__init__.py:476, in load(f, extra_files, expected_opset_version)
    468 artifact: SerializedArtifact = SerializedArtifact(
    469     serialized_exported_program,
    470     serialized_state_dict,
    471     serialized_constants,
    472     serialized_example_inputs,
    473 )
    475 # Deserialize ExportedProgram
--> 476 ep = deserialize(artifact, expected_opset_version)
    478 return ep

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2437, in deserialize(artifact, expected_opset_version)
   2433 exported_program_dict = json.loads(exported_program_str)
   2434 serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict)
   2435 return (
   2436     ExportedProgramDeserializer(expected_opset_version)
-> 2437     .deserialize(
   2438         serialized_exported_program,
   2439         artifact.state_dict,
   2440         artifact.constants,
   2441         artifact.example_inputs,
   2442     )
   2443 )

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/serde/serialize.py:2329, in ExportedProgramDeserializer.deserialize(self, exported_program, state_dict, constants, example_inputs)
   2314 res = (
   2315     GraphModuleDeserializer()
   2316     .deserialize(
   (...)
   2322     )
   2323 )
   2324 range_constraints = self.deserialize_range_constraints(
   2325     symbol_name_to_range,
   2326     res.names_to_symbols,
   2327 )
-> 2329 return ep.ExportedProgram(
   2330     root=res.graph_module,
   2331     graph=res.graph_module.graph,
   2332     graph_signature=res.signature,
   2333     state_dict=res.state_dict,  # type: ignore[arg-type]
   2334     range_constraints=range_constraints,
   2335     module_call_graph=res.module_call_graph,
   2336     example_inputs=res.example_inputs,
   2337     constants=res.constants,
   2338     verifiers=[load_verifier(v) for v in exported_program.verifiers],
   2339 )

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:700, in ExportedProgram.__init__(self, root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, constants, verifiers)
    698 self._verifiers = verifiers
    699 # Validate should be always the last step of the constructor.
--> 700 self.validate()

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1117, in ExportedProgram.validate(self)
   1115 @compatibility(is_backward_compatible=False)
   1116 def validate(self):
-> 1117     self._validate()

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/export/exported_program.py:1126, in ExportedProgram._validate(self)
   1122 assert (
   1123     len(self.verifiers) > 0
   1124 ), "ExportedProgram must have at least one verifier."
   1125 for v in self.verifiers:
-> 1126     v().check(self)

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:155, in Verifier.check(self, ep)
    153 @final
    154 def check(self, ep: "ExportedProgram") -> None:
--> 155     self._check_graph_module(ep.graph_module)
    156     _verify_exported_program_module_call_graph(ep)
    157     _verify_exported_program_signature(ep)

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/_export/verifier.py:214, in Verifier._check_graph_module(self, gm)
    211 if not isinstance(mod, torch.fx.GraphModule):
    212     continue
--> 214 mod.graph.lint()
    215 for node in mod.graph.nodes:
    216     # TODO(T140410192): should have fake tensor for all dialects
    217     if node.op in {"call_module", "call_method"}:

File ~/.conda/envs/torch-tensorrt-nightly/lib/python3.11/site-packages/torch/fx/graph.py:1549, in Graph.lint(self)
   1546     seen_values.add(node)
   1548     if node.name in seen_names:
-> 1549         raise RuntimeError(f'Node redefined name {node.name}!')
   1550     seen_names.add(node.name)
   1552 # Check targets are legit

RuntimeError: Node redefined name getitem_130!
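
The traceback shows that Graph.lint() rejects the graph because a node name is seen twice. As a rough diagnostic, one could look for repeated names before saving. The helper below is only a sketch: duplicate_node_names is a hypothetical name, and it works on any iterable of objects with a .name attribute. It is an assumption, not something confirmed in this thread, that passing compiled_unet.graph.nodes would already reveal the duplicate before torch_tensorrt.save() runs.

```python
from collections import Counter
from types import SimpleNamespace


def duplicate_node_names(nodes):
    """Return a sorted list of names that occur more than once in `nodes`.

    `nodes` is any iterable of objects with a `.name` attribute, for example
    the `.graph.nodes` of a torch.fx.GraphModule (assumed use, not verified
    against the failing model in this issue).
    """
    counts = Counter(node.name for node in nodes)
    return sorted(name for name, count in counts.items() if count > 1)


# Stand-in nodes so the sketch runs without torch installed.
demo_nodes = [
    SimpleNamespace(name=n)
    for n in ("sample", "getitem_130", "conv", "getitem_130")
]
print(duplicate_node_names(demo_nodes))  # prints ['getitem_130']
```

An empty result on the compiled module would suggest the duplicate name is introduced later, during serialization or deserialization.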

Expected behavior

The code should load the saved compiled model without erroring out.

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 2.5.0.dev20240912+cu124
  • PyTorch Version (e.g. 1.0): 2.5.0.dev20240912+cu124
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 22.04.4 LTS (x86_64)
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.11.10
  • CUDA version: 12.4
  • GPU models and configuration: 1/2 of an H100 (Configured with MIG)
  • Any other relevant information: Using diffusers version 0.30.2
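
Since the behavior in this report depends on the exact package versions, it may help to collect them programmatically when reproducing. This is a minimal sketch using only the standard library. The helper name installed_versions is hypothetical, and the distribution names passed to it (torch, torch-tensorrt, diffusers) are assumed to match how the packages were installed.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed distribution names; adjust them if your install differs.
    report = installed_versions(["torch", "torch-tensorrt", "diffusers"])
    for name, version in report.items():
        print(f"{name}: {version or 'not installed'}")
```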

Additional context

I have to use functools.partial() in the code above because the default output of the pipeline's forward method is the UNet2DConditionOutput dataclass. I tried to get rid of functools.partial() by instead using torch.export.register_dataclass() but was met with the same runtime error mentioned above.

Additionally, saving and loading the ExportedProgram (without Torch-TensorRT compilation) works fine.

readleyj avatar Sep 16 '24 12:09 readleyj

@readleyj I tried this in my environment with today's main branch on an RTX 4080 and I don't get the error you pasted. I can load the UNet successfully.

lanluo-nvidia avatar Sep 18 '24 17:09 lanluo-nvidia

@lanluo-nvidia Thank you for the reply. That is very strange. I will try with today's nightly and report back. Also, I am running this on an H100, could that possibly be the source of the issue?

readleyj avatar Sep 18 '24 17:09 readleyj

I tried again with today's nightly (torch_tensorrt==2.5.0.dev20240918+cu124, torch==dev20240912+cu124) and I am encountering the same runtime error. Additionally, the results for the compiled UNet match the original UNet. At this point, I am not sure if the issue is with Torch-TensorRT or torch.export.

readleyj avatar Sep 18 '24 18:09 readleyj

I also tried release 2.4. There, I can save and load the model successfully, but the compiled model's outputs are full of NaNs. In general, Stable Diffusion with Torch-TensorRT seems very problematic.

readleyj avatar Sep 20 '24 08:09 readleyj

@readleyj Yes, release 2.4 has bugs that have been fixed on the current main branch. Could you paste the code you use to generate the image after loading the UNet? I will give it a try as well.

lanluo-nvidia avatar Sep 20 '24 23:09 lanluo-nvidia

@lanluo-nvidia After loading the UNet, I first check if the results match (expected_outputs_unet is defined in the previous code block)

with torch.inference_mode():    
    tensorrt_outputs_unet = loaded_unet(*arg_inputs_unet)
    for expected_output, tensorrt_output in zip(expected_outputs_unet, tensorrt_outputs_unet):
        assert torch.allclose(
            expected_output, tensorrt_output, 1e-2, 1e-2
        ), "UNet results do not match"
    
    print("UNet results match for Torch-TensorRT and Diffusers")

To generate an image, I plug the loaded UNet into a StableDiffusion pipeline as follows (code block assumes loaded_unet is already defined):

import torch
import torch_tensorrt
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

PROMPT = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True
    ).to("cuda")
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    
    class LoadedUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.in_channels = pipe.unet.config.in_channels
            setattr(self, "config", pipe.unet.config)
            self.device = pipe.unet.device
    
        def forward(self, latent_model_input, t, encoder_hidden_states, **kwargs):
            sample = loaded_unet(latent_model_input, t, encoder_hidden_states)
            return sample
    
    pipe.unet = LoadedUNet()
    
    image = pipe(PROMPT,
                 num_inference_steps=50,
                 height=512,
                 width=512,
            ).images[0]

Note that you may receive a "Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and / or seed." warning from diffusers. This happens when the image is all black or gibberish.

readleyj avatar Sep 21 '24 00:09 readleyj

@readleyj I have tried with the release/2.5 branch. (This is our upcoming release branch and it is more stable than main, since main receives all the latest changes from both PyTorch and Torch-TensorRT.)

Test 1) Tested locally on my RTX 4080 with the release/2.5 branch:

python -m pip install --pre --editable . --extra-index-url https://download.pytorch.org/whl/test/cu124
python -m pip install --pre "torchvision>=0.20.0,<0.21.0" --index-url https://download.pytorch.org/whl/test/cu124
python /home/lanl/git/script/python/stable_diffusion/test_issue3163.py

The script does raise the "UNet results do not match" error, although the outputs are very close to satisfying rtol=1e-2, atol=1e-2. I do not see outputs full of NaNs. Also, if I change rtol, atol to 1e-2, 5e-2, the check passes and the image is generated as expected:

lan added expected_outputs_unet=(tensor([[[[-3.0304e-02,  4.7119e-01,  1.3281e-01,  ...,  3.9185e-01,
           -1.1200e-01,  1.3062e-01],
          [-2.1191e-01,  2.7026e-01, -2.9688e-01,  ..., -3.6548e-01,
            9.2712e-02, -6.3867e-01],
          [-5.2441e-01,  1.8555e-02,  9.9414e-01,  ..., -1.0977e+00,
            3.9429e-02, -6.8164e-01],
          ...,
          [-6.3574e-01, -2.2583e-01,  3.3234e-02,  ...,  2.8491e-01,
           -4.9316e-01,  9.1455e-01],
          [ 5.8594e-01, -7.3779e-01,  6.9695e-03,  ...,  7.1094e-01,
           -3.3569e-01,  1.2830e-01],
          [-2.4805e-01, -1.2152e-01,  8.3643e-01,  ..., -2.8641e-02,
            1.9739e-01,  1.3367e-01]],

         [[ 2.7319e-01,  4.0063e-01, -4.8682e-01,  ..., -6.2939e-01,
            2.1790e-02,  4.9634e-01],
          [-9.3323e-02, -7.7393e-01,  2.1399e-01,  ...,  7.6953e-01,
            4.5410e-01, -3.1909e-01],
          [ 1.8079e-01,  4.2017e-01,  1.1699e+00,  ...,  1.3843e-01,
           -3.2898e-02, -1.3953e-01],
          ...,
          [-1.4839e-02,  1.2131e-02,  1.7859e-01,  ...,  5.9717e-01,
            8.0762e-01, -7.5684e-01],
          [ 6.3232e-01, -6.1035e-01,  1.9214e-01,  ..., -3.3496e-01,
            4.9048e-01,  7.0166e-01],
          [ 2.3340e-01, -6.1279e-01,  7.8271e-01,  ..., -1.9067e-01,
           -6.3965e-01, -1.7529e-01]],

         [[-8.2324e-01, -4.7180e-02, -8.0383e-02,  ..., -6.2109e-01,
            4.5319e-02,  1.5930e-01],
          [ 1.0908e+00, -7.1143e-01,  9.6484e-01,  ...,  4.6777e-01,
           -2.4548e-01, -5.6445e-01],
          [ 2.2278e-01,  1.2256e+00,  3.4302e-01,  ..., -3.1372e-01,
            3.3203e-01,  1.1426e-01],
          ...,
          [ 1.7578e-01, -2.4002e-02,  3.9581e-02,  ...,  1.4160e-01,
            2.4902e-01, -2.7515e-01],
          [ 6.4893e-01, -1.7891e+00,  3.4570e-01,  ...,  3.9868e-01,
            5.0977e-01,  5.0146e-01],
          [ 2.1948e-01, -2.0020e-01,  3.3862e-01,  ..., -2.5488e-01,
            7.9346e-02, -3.8794e-01]],

         [[-3.6206e-01, -5.7080e-01, -7.8369e-02,  ...,  4.7388e-01,
            4.5093e-01, -2.6636e-01],
          [ 4.5630e-01,  4.8340e-01,  5.4053e-01,  ..., -2.9175e-01,
            2.3331e-02, -5.2979e-01],
          [ 4.5728e-01, -3.1177e-01, -1.5879e+00,  ..., -1.6748e-01,
            1.8408e-01, -3.1592e-01],
          ...,
          [-2.6074e-01,  1.6028e-01, -5.9766e-01,  ...,  2.4963e-01,
            2.9688e-01, -1.1699e+00],
          [-2.1367e+00,  5.9619e-01,  6.1133e-01,  ..., -3.5962e-01,
           -4.8193e-01,  1.5167e-02],
          [ 9.3018e-01,  5.7471e-01, -4.0332e-01,  ...,  5.8691e-01,
           -1.6826e+00, -4.2450e-02]]],


        [[[-1.4170e+00,  3.9453e-01, -4.8438e-01,  ...,  2.2180e-01,
            4.1724e-01,  9.6252e-02],
          [ 1.5015e-01,  4.6851e-01,  3.3643e-01,  ...,  5.3467e-02,
           -1.9666e-01, -9.2773e-02],
          [ 1.0840e+00,  5.0244e-01,  8.7695e-01,  ...,  3.6957e-02,
           -1.0840e+00, -7.1436e-01],
          ...,
          [-7.7100e-01,  2.4207e-01, -3.6084e-01,  ..., -6.8298e-02,
           -2.1643e-01,  1.4391e-03],
          [-3.7964e-01, -2.0032e-01,  4.6173e-02,  ..., -2.1252e-01,
            2.0972e-01, -6.0608e-02],
          [ 3.5840e-01, -1.3125e+00,  4.1528e-01,  ..., -6.7871e-01,
            9.4434e-01, -3.8055e-02]],

         [[-1.2225e-01,  1.2488e-01, -3.2935e-01,  ..., -3.2690e-01,
           -4.9219e-01,  4.6460e-01],
          [ 1.8616e-01,  1.6821e-01, -7.6675e-03,  ..., -3.6224e-02,
            4.6509e-01, -4.9976e-01],
          [ 5.1758e-01,  5.4883e-01, -7.9004e-01,  ...,  3.2275e-01,
           -1.1780e-01, -9.6191e-01],
          ...,
          [-6.4331e-02, -4.7754e-01, -8.2031e-01,  ...,  1.2024e-01,
           -1.6125e-01, -1.5442e-01],
          [-3.6938e-01, -2.1045e-01, -5.3857e-01,  ...,  1.2512e-01,
           -1.1646e-01,  1.6172e+00],
          [ 7.7515e-02,  4.2578e-01,  3.3789e-01,  ...,  3.6377e-01,
           -9.4189e-01, -8.0176e-01]],

         [[-5.4736e-01,  1.9482e-01, -1.4111e+00,  ..., -1.4087e-01,
            7.7576e-02, -6.6833e-02],
          [ 3.4082e-01, -1.1267e-01,  3.2129e-01,  ...,  5.9473e-01,
           -9.3896e-01, -3.3350e-01],
          [ 8.4277e-01,  1.0020e+00, -8.1055e-01,  ..., -2.3669e-01,
           -5.0049e-01, -4.0503e-01],
          ...,
          [-1.1273e-01,  4.9194e-02, -2.6172e-01,  ...,  2.6880e-01,
            3.7744e-01, -2.0447e-02],
          [ 1.2832e+00, -9.6985e-02, -2.9150e-01,  ...,  1.1292e-01,
           -1.9116e-01, -2.1643e-01],
          [ 1.8347e-01, -4.4531e-01, -3.4180e-01,  ...,  1.2793e-01,
            3.6011e-01,  9.5215e-01]],

         [[ 1.1777e+00, -3.1174e-02,  2.1133e+00,  ..., -1.7981e-01,
            1.1401e-01,  2.7466e-01],
          [-4.1113e-01,  1.4771e-01, -4.5264e-01,  ..., -5.7080e-01,
           -6.2354e-01, -2.0126e-02],
          [-1.0283e+00, -7.2070e-01,  1.6321e-01,  ..., -1.0547e-01,
            1.8105e+00, -6.9824e-01],
          ...,
          [-6.0352e-01,  3.2440e-02, -2.5537e-01,  ...,  2.0691e-01,
           -5.2277e-02, -4.4482e-01],
          [-8.5840e-01, -4.8291e-01, -2.7051e-01,  ..., -1.1688e-01,
           -4.1113e-01,  5.1562e-01],
          [ 3.0469e-01,  5.5273e-01,  1.2769e-01,  ..., -3.8086e-02,
           -2.3511e-01, -1.5625e-01]]]], device='cuda:0', dtype=torch.float16),)
lan added successfully saved compiled model
lan added successfully loaded compiled model
lan added tensorrt_outputs_unet=(tensor([[[[-0.0281,  0.4717,  0.1340,  ...,  0.3911, -0.1110,  0.1306],
          [-0.2119,  0.2698, -0.2964,  ..., -0.3638,  0.0912, -0.6377],
          [-0.5249,  0.0165,  0.9927,  ..., -1.1055,  0.0362, -0.6802],
          ...,
          [-0.6377, -0.2261,  0.0383,  ...,  0.2847, -0.4937,  0.9146],
          [ 0.5815, -0.7383,  0.0041,  ...,  0.7085, -0.3379,  0.1245],
          [-0.2472, -0.1250,  0.8354,  ..., -0.0276,  0.1979,  0.1342]],

         [[ 0.2739,  0.3992, -0.4868,  ..., -0.6265,  0.0222,  0.4968],
          [-0.0953, -0.7720,  0.2144,  ...,  0.7710,  0.4548, -0.3210],
          [ 0.1858,  0.4216,  1.1719,  ...,  0.1354, -0.0334, -0.1392],
          ...,
          [-0.0163,  0.0100,  0.1794,  ...,  0.5981,  0.8042, -0.7524],
          [ 0.6294, -0.6099,  0.1896,  ..., -0.3352,  0.4866,  0.7021],
          [ 0.2328, -0.6152,  0.7822,  ..., -0.1882, -0.6387, -0.1779]],

         [[-0.8218, -0.0516, -0.0839,  ..., -0.6216,  0.0450,  0.1566],
          [ 1.0908, -0.7075,  0.9653,  ...,  0.4673, -0.2465, -0.5654],
          [ 0.2251,  1.2188,  0.3413,  ..., -0.3125,  0.3306,  0.1157],
          ...,
          [ 0.1783, -0.0231,  0.0443,  ...,  0.1445,  0.2466, -0.2778],
          [ 0.6450, -1.7891,  0.3435,  ...,  0.3984,  0.5098,  0.5015],
          [ 0.2174, -0.2021,  0.3389,  ..., -0.2559,  0.0767, -0.3879]],

         [[-0.3618, -0.5703, -0.0786,  ...,  0.4707,  0.4492, -0.2673],
          [ 0.4551,  0.4832,  0.5396,  ..., -0.2891,  0.0298, -0.5327],
          [ 0.4585, -0.3105, -1.5898,  ..., -0.1676,  0.1854, -0.3171],
          ...,
          [-0.2598,  0.1622, -0.6060,  ...,  0.2484,  0.2986, -1.1680],
          [-2.1426,  0.5972,  0.6147,  ..., -0.3577, -0.4790,  0.0156],
          [ 0.9312,  0.5732, -0.4019,  ...,  0.5850, -1.6826, -0.0422]]],


        [[[-1.4150,  0.3909, -0.4822,  ...,  0.2200,  0.4146,  0.0955],
          [ 0.1503,  0.4656,  0.3364,  ...,  0.0545, -0.1993, -0.0927],
          [ 1.0791,  0.5010,  0.8813,  ...,  0.0355, -1.0830, -0.7158],
          ...,
          [-0.7690,  0.2378, -0.3633,  ..., -0.0714, -0.2169,  0.0047],
          [-0.3752, -0.2000,  0.0457,  ..., -0.2123,  0.2108, -0.0576],
          [ 0.3579, -1.3125,  0.4180,  ..., -0.6763,  0.9458, -0.0367]],

         [[-0.1224,  0.1234, -0.3323,  ..., -0.3257, -0.4893,  0.4646],
          [ 0.1893,  0.1653, -0.0038,  ..., -0.0334,  0.4651, -0.5015],
          [ 0.5176,  0.5493, -0.7900,  ...,  0.3220, -0.1155, -0.9575],
          ...,
          [-0.0630, -0.4766, -0.8208,  ...,  0.1180, -0.1615, -0.1575],
          [-0.3704, -0.2101, -0.5396,  ...,  0.1234, -0.1164,  1.6182],
          [ 0.0759,  0.4243,  0.3369,  ...,  0.3630, -0.9458, -0.8022]],

         [[-0.5449,  0.1912, -1.4131,  ..., -0.1411,  0.0784, -0.0674],
          [ 0.3401, -0.1118,  0.3210,  ...,  0.5952, -0.9404, -0.3323],
          [ 0.8384,  1.0010, -0.8042,  ..., -0.2351, -0.5015, -0.4026],
          ...,
          [-0.1104,  0.0457, -0.2615,  ...,  0.2690,  0.3806, -0.0194],
          [ 1.2861, -0.0950, -0.2893,  ...,  0.1160, -0.1880, -0.2148],
          [ 0.1829, -0.4490, -0.3406,  ...,  0.1271,  0.3596,  0.9507]],

         [[ 1.1777, -0.0316,  2.1074,  ..., -0.1764,  0.1111,  0.2766],
          [-0.4094,  0.1512, -0.4502,  ..., -0.5728, -0.6245, -0.0194],
          [-1.0283, -0.7241,  0.1641,  ..., -0.1035,  1.8174, -0.6987],
          ...,
          [-0.6021,  0.0345, -0.2515,  ...,  0.2064, -0.0580, -0.4395],
          [-0.8604, -0.4844, -0.2668,  ..., -0.1143, -0.4133,  0.5127],
          [ 0.3044,  0.5537,  0.1259,  ..., -0.0396, -0.2357, -0.1576]]]],
       device='cuda:0', dtype=torch.float16),)
Traceback (most recent call last):
  File "/home/lanl/git/script/python/stable_diffusion/test_issue3163.py", line 47, in <module>
    assert torch.allclose(
AssertionError: UNet results do not match
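
For context on why loosening atol changes the outcome: torch.allclose passes only if every element satisfies |input - other| <= atol + rtol * |other|. The sketch below applies that rule to single numbers in plain Python. The helper name within_tolerance is hypothetical. The first pair of values is taken from the printout above (third row of the first channel), while the second pair is made up purely to illustrate a case that fails at atol=1e-2 and passes at atol=5e-2.

```python
def within_tolerance(expected, actual, rtol, atol):
    """Scalar version of the check torch.allclose applies element-wise.

    `expected` plays the role of `input` and `actual` the role of `other`.
    """
    return abs(expected - actual) <= atol + rtol * abs(actual)


# Values visible in the printout: difference 0.0078, allowed about 0.0211.
print(within_tolerance(-1.0977, -1.1055, rtol=1e-2, atol=1e-2))  # True

# Hypothetical near-zero element: difference 0.03, allowed about 0.0103.
print(within_tolerance(0.0, 0.03, rtol=1e-2, atol=1e-2))  # False

# Same element with the looser absolute tolerance: allowed about 0.0503.
print(within_tolerance(0.0, 0.03, rtol=1e-2, atol=5e-2))  # True
```

The elements shown in the printout appear to satisfy the 1e-2, 1e-2 bound, so the failing elements are presumably among the elided ones.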

lanluo-nvidia avatar Sep 21 '24 18:09 lanluo-nvidia

Test 2) Tested locally on my RTX 4080 with the release/2.5 branch. I found that if I skip saving and loading and run inference directly with the Torch-TensorRT compiled model, the UNet results do match and the images are generated as expected, as in Test 1).

lanluo-nvidia avatar Sep 21 '24 18:09 lanluo-nvidia

Test 3) Tested on an H100 using the release/2.5 Docker image:

docker run --gpus all --ipc=host --rm -it ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5 bash

In the Docker container:

apt-get install -y vim
python -m pip install diffusers transformers accelerate
python test_issue3163.py

This throws the following error:

lan added successfully saved compiled model
Traceback (most recent call last):
  File "/opt/torch_tensorrt/test_issue3163.py", line 42, in <module>
    loaded_unet = torch.export.load("sd_unet_compiled.ep").module()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/__init__.py", line 473, in load
    ep = deserialize(artifact, expected_opset_version)
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/serde/serialize.py", line 2437, in deserialize
    .deserialize(
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/serde/serialize.py", line 2329, in deserialize
    return ep.ExportedProgram(
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/exported_program.py", line 700, in __init__
    self.validate()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/exported_program.py", line 1117, in validate
    self._validate()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/export/exported_program.py", line 1126, in _validate
    v().check(self)
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/verifier.py", line 155, in check
    self._check_graph_module(ep.graph_module)
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/_export/verifier.py", line 214, in _check_graph_module
    mod.graph.lint()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/fx/graph.py", line 1549, in lint
    raise RuntimeError(f'Node redefined name {node.name}!')
RuntimeError: Node redefined name getitem_130!
root@s4124-0059:/opt/torch_tensorrt# nvidia-smi
Sat Sep 21 18:38:15 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:45:00.0 Off |                    0 |
| N/A   26C    P0             61W /  700W |       0MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |

lanluo-nvidia avatar Sep 21 '24 18:09 lanluo-nvidia

Yes, the error in Test 3) is exactly what I'm getting on my H100. I thought the problem might be with torch.export so I already created an issue on the PyTorch repo (pytorch/pytorch#136317)

readleyj avatar Sep 21 '24 19:09 readleyj

@readleyj It seems this only happens on the H100. I ran exactly the same test on an RTX 4080 using the same image and the same test code you provided, and it works.

Test 4) Tested on an RTX 4080 using the release/2.5 Docker image:

docker run --gpus all --ipc=host --rm -it ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.5 bash

In the Docker container:

apt-get install -y vim
python -m pip install diffusers transformers accelerate
python test_issue3163.py

The UNet results match (rtol, atol: 1e-2, 1e-2) and the image is generated as expected.

lanluo-nvidia avatar Sep 21 '24 19:09 lanluo-nvidia

@lanluo-nvidia Also, on my H100 tests, the model successfully compiles, the UNet results match (using compiled_unet directly) and I can generate an image (if I use compiled_unet in place of loaded_unet). But it's saving and loading the compiled model that breaks. To me this seems like a torch.export issue but I'm not sure.

readleyj avatar Sep 21 '24 20:09 readleyj

@lanluo-nvidia Any updates on this? Should I expect this issue to be resolved soon or will this be on the backlog for a while? Unfortunately, I only have H100s at my disposal and this is blocking progress for me.

readleyj avatar Sep 26 '24 07:09 readleyj

@readleyj Have you tried saving in TorchScript format instead of the ExportedProgram format?

Simply change the two lines from

    torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
    loaded_unet = torch.export.load("sd_unet_compiled.ep").module()

to

    torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ts", output_format="torchscript", inputs=arg_inputs_unet)
    loaded_unet = torch.jit.load("sd_unet_compiled.ts").eval()
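
If both formats are saved, the two load paths could be tried in order so that a failure in one falls back to the other. The following is only a sketch of that pattern in plain Python. load_first_working is a hypothetical helper, and the torch calls appear only in the usage comment, mirroring the lines above.

```python
def load_first_working(loaders):
    """Try each (label, zero-argument callable) pair in order.

    Returns (label, loaded_object) for the first loader that succeeds and
    raises RuntimeError listing every failure if none does.
    """
    failures = []
    for label, loader in loaders:
        try:
            return label, loader()
        except Exception as exc:  # broad on purpose: any load failure triggers the fallback
            failures.append(f"{label}: {exc!r}")
    raise RuntimeError("all loaders failed: " + "; ".join(failures))


# Intended use (assumes both files were written by torch_tensorrt.save):
# label, loaded_unet = load_first_working([
#     ("exported_program", lambda: torch.export.load("sd_unet_compiled.ep").module()),
#     ("torchscript", lambda: torch.jit.load("sd_unet_compiled.ts").eval()),
# ])
```

The returned label makes it easy to log which format was actually loaded.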

HolyWu avatar Nov 24 '24 06:11 HolyWu

@HolyWu Thanks for the suggestion. I hadn't looked at this for a while. I just tried my original code (torch_tensorrt.save, torch.export.load) on torch 2.5.1 and torch_tensorrt 2.5.0, and everything seems to be working: I can compile, save and load successfully. The problem seems resolved, so I'll close the issue. Thanks a lot.

readleyj avatar Nov 26 '24 19:11 readleyj