[torch.compile] Make HiDream torch.compile ready
What does this PR do?
Part of https://github.com/huggingface/diffusers/issues/11430
Trying to make the HiDream model fully compatible with `torch.compile()`, but it currently fails with:
https://pastebin.com/EbCFqBvw
To reproduce, run the following on a GPU machine:

```shell
RUN_COMPILE=1 RUN_SLOW=1 pytest tests/models/transformers/test_models_transformer_hidream.py -k "test_torch_compile_recompilation_and_graph_break"
```
I am on the following env:
- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.7.0+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.51.3
- Accelerate version: 1.6.0.dev0
- PEFT version: 0.15.2.dev0
- Bitsandbytes version: 0.45.3
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
@anijain2305 @StrongerXi would you have any pointers?
The graph break seems to be induced by @torch.no_grad: https://github.com/huggingface/diffusers/blob/d0c02398b986f2876b2b79f3a137ed00a7edde35/src/diffusers/models/transformers/transformer_hidream_image.py#L388
@anijain2305 is this known?
Even if we remove the decorator, it still fails with the same error.
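For anyone wanting to poke at the decorator pattern in isolation, here is a minimal, hypothetical sketch (the module, shapes, and values are made up; only the placement of a `@torch.no_grad`-decorated helper called from `forward()` mirrors the linked line):

```python
import torch

class ToyModel(torch.nn.Module):
    # hypothetical module: the only relevant part is the
    # @torch.no_grad-decorated helper invoked from forward()
    @torch.no_grad()
    def helper(self, x):
        return x * 2

    def forward(self, x):
        return self.helper(x) + 1

model = ToyModel()
# backend="eager" exercises Dynamo tracing without requiring an
# Inductor toolchain; fullgraph=False falls back to eager on breaks
compiled = torch.compile(model, backend="eager", fullgraph=False)
out = compiled(torch.ones(4))
```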
Thanks! Appreciate it.
On Thu, 8 May 2025, Animesh Jain wrote:

> Approved this pull request. LGTM
Not a very useful update, but there seems to be a dynamic-shapes graph break here, coming from the `moe_infer` function.
cc @laithsakka
@anijain2305
Okay I think I know why this is happening. The line that primarily causes this shape change is: https://github.com/huggingface/diffusers/blob/f4fa3beee7f49b80ce7a58f9c8002f43299175c9/src/diffusers/models/transformers/transformer_hidream_image.py#L897
This is why the `moe_infer()` function, when called with `single_stream_blocks`, complains about the shape changes.
So, I tried with `dynamic=True` along with:

```python
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch.fx.experimental._config.use_duck_shape = False
```
It then complains:
msg = 'dynamic shape operator: aten.bincount.default; Operator does not have a meta kernel that supports dynamic output shapes, please report an issue to PyTorch'
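The `bincount` limitation is easy to see in isolation: its output length depends on the *values* in the input, not just its shape, so there is no static shape a meta kernel could report. A tiny demonstration:

```python
import torch

# bincount's output length is max(input) + 1, i.e. it depends on the
# data in the tensor -- a data-dependent ("dynamic") output shape
a = torch.tensor([0, 1, 1, 3])
b = torch.tensor([0, 1, 1, 7])  # same shape as `a`, different values
shape_a = torch.bincount(a).shape  # 4 bins: counts for values 0..3
shape_b = torch.bincount(b).shape  # 8 bins: counts for values 0..7
```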
Keeping this open maybe for better tracking.
Cc: @StrongerXi for the above observation too.
On it.
Okay, I spent some time digging into the MoE stuff; here's what I learned:

- HiDream has 2 branches in the MoE FFN layer, and it looks like the `moe_infer` branch is meant to speed up inference, as it explicitly skips the experts that receive no tokens. However, that's really bad for `torch.compile` because (a) it creates a hard-to-resolve data dependency (the branching depends on the output of `torch.bincount`, which depends on the data of `flat_expert_indices`), and (b) even if we solve (a), we'd face lots of recompilations, because `torch.compile` would compile each possible execution path (e.g., experts 1 & 3 firing, or experts 1, 2, 4 firing, etc.): https://github.com/huggingface/diffusers/blob/dacae33c1b56c6f5f9292d44b7bafb95db96566b/src/diffusers/models/transformers/transformer_hidream_image.py#L396-L399
- Then I ran some benchmarks, and it turns out `moe_infer` isn't faster than the "training branch": they produce identical output images, and `torch.compile` achieves much lower end-to-end latency with the "training branch": https://github.com/huggingface/diffusers/blob/dacae33c1b56c6f5f9292d44b7bafb95db96566b/src/diffusers/models/transformers/transformer_hidream_image.py#L375-L382
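To make the trade-off concrete, here is a schematic top-1 routing example (not the diffusers code; all names and sizes are made up) contrasting a shape-stable dense branch with a `moe_infer`-style branch that skips empty experts via `torch.bincount`:

```python
import torch

torch.manual_seed(0)
E, T, D = 4, 16, 8  # experts, tokens, hidden dim (arbitrary)
experts = torch.nn.ModuleList(torch.nn.Linear(D, D) for _ in range(E))
experts.requires_grad_(False)  # pure inference sketch
x = torch.randn(T, D)
idx = torch.randint(0, E, (T,))  # top-1 expert index per token
onehot = torch.nn.functional.one_hot(idx, E).to(x.dtype)  # (T, E)

# "training" branch: every expert runs on every token, gated by the
# routing weights -- all shapes are static, so torch.compile is happy
dense = torch.zeros_like(x)
for e in range(E):
    dense = dense + onehot[:, e:e + 1] * experts[e](x)

# moe_infer-style branch: count tokens per expert and skip the empty
# ones -- the `continue` makes control flow depend on tensor data
sparse = torch.zeros_like(x)
counts = torch.bincount(idx, minlength=E)
for e in range(E):
    if counts[e] == 0:
        continue
    mask = idx == e
    sparse[mask] = experts[e](x[mask])
```

Both branches compute the same result; only the second one conditions Python control flow on tensor values.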
Then I just had to fix a small graph break here, where `img_sizes` is supposed to be a `List[Tuple[int, int]]` but was computed as a tensor:
https://github.com/huggingface/diffusers/blob/dacae33c1b56c6f5f9292d44b7bafb95db96566b/src/diffusers/models/transformers/transformer_hidream_image.py#L686 https://github.com/huggingface/diffusers/blob/dacae33c1b56c6f5f9292d44b7bafb95db96566b/src/diffusers/models/transformers/transformer_hidream_image.py#L718-L720
The fix is simple:

```python
# create img_sizes
#img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
#img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
img_sizes = [[patch_height, patch_width]] * batch_size
```
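As a quick sanity check that the two representations carry the same information (placeholder sizes, just for illustration):

```python
import torch

# placeholder values, only to compare the two representations
batch_size, patch_height, patch_width = 2, 64, 64

# old: a (batch_size, 2) int64 tensor -- reading H/W back out later
# means tensor -> Python-int conversions, a common graph-break source
img_sizes_t = torch.tensor([patch_height, patch_width], dtype=torch.int64)
img_sizes_t = img_sizes_t.unsqueeze(0).repeat(batch_size, 1)

# new: a plain list of [H, W] pairs, which torch.compile can treat
# as compile-time constants
img_sizes = [[patch_height, patch_width]] * batch_size
```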
Here are the e2e pipeline benchmark results using the hidream demo script, and compiling the transformer:
On PyTorch 2.7.1:

- original: eager 26.6s, compiled 24.8s (`fullgraph=False`)
- train-branch: eager 25.9s, compiled 19.5s (`fullgraph=True`)
I also saw that ComfyUI uses the training branch too. So maybe we should just use the training branch in eager as well? Or we could add a torch.compiler.is_compiling() to use the training branch under compile only. What do you think @sayakpaul?
Wow, this is terrific KT. Thanks, Ryan!
> Or we could add a `torch.compiler.is_compiling()` to use the training branch under compile only.
This is a good approach and is worth adding. @yiyixuxu what are your thoughts?