
[torch.compile] Make HiDream torch.compile ready

Open · sayakpaul opened this pull request 9 months ago · 10 comments

What does this PR do?

Part of https://github.com/huggingface/diffusers/issues/11430

This PR tries to make the HiDream model fully compatible with torch.compile(), but it currently fails with: https://pastebin.com/EbCFqBvw

To reproduce, run the following on a GPU machine:

RUN_COMPILE=1 RUN_SLOW=1 pytest tests/models/transformers/test_models_transformer_hidream.py -k "test_torch_compile_recompilation_and_graph_break"

I am on the following environment:

- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.7.0+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.51.3
- Accelerate version: 1.6.0.dev0
- PEFT version: 0.15.2.dev0
- Bitsandbytes version: 0.45.3
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

@anijain2305 @StrongerXi would you have any pointers?

sayakpaul avatar May 01 '25 13:05 sayakpaul


The graph break seems to be induced by @torch.no_grad: https://github.com/huggingface/diffusers/blob/d0c02398b986f2876b2b79f3a137ed00a7edde35/src/diffusers/models/transformers/transformer_hidream_image.py#L388

@anijain2305 is this known?

StrongerXi avatar May 01 '25 17:05 StrongerXi

> The graph break seems to be induced by @torch.no_grad:
>
> https://github.com/huggingface/diffusers/blob/d0c02398b986f2876b2b79f3a137ed00a7edde35/src/diffusers/models/transformers/transformer_hidream_image.py#L388
>
> @anijain2305 is this known?

Even if we remove the decorator, it still fails with the same error.

sayakpaul avatar May 03 '25 05:05 sayakpaul

Thanks! Appreciate it.

(Replying by email, Thu, 8 May 2025, to Animesh Jain's review approving this pull request: "LGTM".)

sayakpaul avatar May 08 '25 13:05 sayakpaul

Not a conclusive update, but there seems to be a dynamic-shapes graph break here, coming from the moe_infer function.

cc @laithsakka

anijain2305 avatar May 13 '25 22:05 anijain2305

@anijain2305

Okay I think I know why this is happening. The line that primarily causes this shape change is: https://github.com/huggingface/diffusers/blob/f4fa3beee7f49b80ce7a58f9c8002f43299175c9/src/diffusers/models/transformers/transformer_hidream_image.py#L897

This is why the moe_infer() function, when called with single_stream_blocks, complains about the shape changes.

So, I tried with dynamic=True along with

torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch.fx.experimental._config.use_duck_shape = False

It then complains:

msg = 'dynamic shape operator: aten.bincount.default; Operator does not have a meta kernel that supports dynamic output shapes, please report an issue to PyTorch'
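The error is easy to illustrate: the length of `torch.bincount`'s output depends on the *values* in the input tensor, not on its shape, so there is no way to predict the output shape from input metadata alone:

```python
# Minimal illustration of why aten.bincount trips dynamic-shape tracing:
# two inputs with identical shapes produce outputs of different lengths,
# because the result length is max(input) + 1.
import torch

a = torch.tensor([0, 1, 1, 2])
b = torch.tensor([0, 1, 1, 9])
print(torch.bincount(a).shape)  # torch.Size([3])
print(torch.bincount(b).shape)  # torch.Size([10])
```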

Keeping this open for better tracking.

sayakpaul avatar May 14 '25 15:05 sayakpaul

Cc: @StrongerXi for the above observation too.

sayakpaul avatar May 17 '25 08:05 sayakpaul

On it.

StrongerXi avatar Jun 05 '25 22:06 StrongerXi

Okay I spent some time digging into the MOE stuff, here's what I learned:

  1. HiDream has two branches in the MoE FFN layer, and it looks like the moe_infer branch is meant to speed up inference by explicitly skipping experts that received no tokens. However, that's really bad for torch.compile because (a) it creates a hard-to-resolve data dependency (the branching depends on the output of torch.bincount, which depends on the data of flat_expert_indices), and (b) even if we solved (a), we'd face lots of recompilations, because torch.compile would compile each possible execution path (e.g., experts 1 and 3 firing, or experts 1, 2, and 4 firing, etc.). https://github.com/huggingface/diffusers/blob/dacae33c1b56c6f5f9292d44b7bafb95db96566b/src/diffusers/models/transformers/transformer_hidream_image.py#L396-L399

  2. Then I ran some benchmarks, and it turns out moe_infer isn't faster than the "training branch"; they produce identical output images, and torch.compile yields much lower end-to-end latency with the "training branch": https://github.com/huggingface/diffusers/blob/dacae33c1b56c6f5f9292d44b7bafb95db96566b/src/diffusers/models/transformers/transformer_hidream_image.py#L375-L382
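The dispatch pattern described in point 1 looks roughly like this (a toy sketch with my own names, not the diffusers implementation):

```python
# Toy sketch of token-skipping MoE dispatch. The Python-level branch on
# per-expert token counts is the pattern that defeats torch.compile.
import torch


def toy_moe_infer(x, flat_expert_indices, experts):
    # Output length/values of bincount depend on tensor data, not shape.
    counts = torch.bincount(flat_expert_indices, minlength=len(experts))
    out = torch.zeros_like(x)
    for i, expert in enumerate(experts):
        if counts[i] == 0:
            # Branching on tensor data: each distinct set of "live" experts
            # would yield a different trace, hence graph breaks/recompiles.
            continue
        mask = flat_expert_indices == i
        out[mask] = expert(x[mask])
    return out


experts = [torch.nn.Linear(4, 4) for _ in range(3)]
x = torch.randn(6, 4)
idx = torch.tensor([0, 0, 2, 2, 2, 0])  # expert 1 gets no tokens, so it is skipped
y = toy_moe_infer(x, idx, experts)
```

The "training branch", by contrast, runs every expert unconditionally, so the trace is the same regardless of the routing data.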

Then I just had to fix a small graph break here, where img_sizes is supposed to be a List[Tuple[int, int]] but was computed as a tensor: https://github.com/huggingface/diffusers/blob/dacae33c1b56c6f5f9292d44b7bafb95db96566b/src/diffusers/models/transformers/transformer_hidream_image.py#L686 https://github.com/huggingface/diffusers/blob/dacae33c1b56c6f5f9292d44b7bafb95db96566b/src/diffusers/models/transformers/transformer_hidream_image.py#L718-L720

The fix is simple:

        # create img_sizes
        #img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
        #img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
        img_sizes = [[patch_height, patch_width]] * batch_size

Here are the e2e pipeline benchmark results using the hidream demo script, and compiling the transformer:

# pytorch 2.7.1
#
# original eager:     26.6s, compiled 24.8s (fullgraph=False)
# train-branch eager: 25.9s, compiled 19.5s (fullgraph=True)

I also saw that ComfyUI uses the training branch too. So maybe we should just use the training branch in eager as well? Or we could add a torch.compiler.is_compiling() check so the training branch is used only under compile. What do you think @sayakpaul?

StrongerXi avatar Jun 16 '25 21:06 StrongerXi

Wow, this is terrific KT. Thanks, Ryan!

> Or we could add a torch.compiler.is_compiling() to use the training branch under compile only.

This is a good approach and is worth adding. @yiyixuxu what are your thoughts?

sayakpaul avatar Jun 17 '25 04:06 sayakpaul