
Add Unified Sequence Parallel attention

Bissmella opened this pull request 3 months ago · 19 comments

What does this PR do?

This is a draft implementation of the Unified SP attention approach.

  • Implements _all_to_all_dim_exchange with custom scatter and gather indices
  • Implements TemplatedUnifiedAttention

Core implementation complete, needs:

  • [x] Testing
  • [x] Validation
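As a rough mental model of what an all-to-all dim exchange does, the sequence-to-head exchange in Ulysses-style attention can be simulated on a single process. All names, shapes, and the chunking scheme below are illustrative assumptions, not the PR's `_all_to_all_dim_exchange`:

```python
import torch

# Single-process simulation of the sequence<->head all-to-all used by
# Ulysses-style attention: P ranks each start with a sequence shard
# (B, S/P, H, D) and end with a head shard (B, S, H/P, D).
P, B, S, H, D = 4, 2, 16, 8, 32
full = torch.randn(B, S, H, D)
seq_shards = list(full.chunk(P, dim=1))  # what each "rank" starts with

def all_to_all_seq_to_head(shards, P):
    # each rank splits its shard along the head dim and sends piece j to rank j
    pieces = [list(s.chunk(P, dim=2)) for s in shards]
    # rank j concatenates the pieces it receives along the sequence dim
    return [torch.cat([pieces[i][j] for i in range(P)], dim=1) for j in range(P)]

head_shards = all_to_all_seq_to_head(seq_shards, P)
# every rank now sees the full sequence for its subset of heads
assert head_shards[0].shape == (B, S, H // P, D)
assert torch.equal(torch.cat(head_shards, dim=2), full)
```

In the real implementation this exchange is a collective (e.g. `dist.all_to_all_single`) with the scatter and gather dims as parameters; the reshape/concat logic above only illustrates which data ends up on which rank.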

Bissmella avatar Nov 21 '25 08:11 Bissmella

It would be nice to get a testing script so that we can quickly check things.

sayakpaul avatar Nov 21 '25 09:11 sayakpaul

I added a basic test script with a simple forward and backward op. Would it be better to have a test script that exercises flash_attention forward and backward?
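For reference, a minimal single-process forward/backward smoke test for an attention op might look like the following; the shapes and the use of PyTorch's built-in SDPA are assumptions made for this sketch, not the PR's actual test script:

```python
import torch

# Forward + backward smoke test: run attention, backpropagate a scalar,
# and check that finite gradients reach all three inputs.
q = torch.randn(1, 4, 8, 16, dtype=torch.float64, requires_grad=True)  # (B, H, S, D)
k = torch.randn(1, 4, 8, 16, dtype=torch.float64, requires_grad=True)
v = torch.randn(1, 4, 8, 16, dtype=torch.float64, requires_grad=True)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # forward
out.sum().backward()  # backward: gradients should flow to q, k, and v

for t in (q, k, v):
    assert t.grad is not None and torch.isfinite(t.grad).all()
```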

KarthikSundar2002 avatar Nov 21 '25 11:11 KarthikSundar2002

Let us know if this is ready for a review!

sayakpaul avatar Nov 29 '25 11:11 sayakpaul

Yep, ready for review! I tested it with a 4-process setup (a 2×2 mesh, on CPU) and everything checks out: shapes look good and gradients flow correctly. Looking forward to feedback, and happy to address any issues.

Bissmella avatar Nov 29 '25 12:11 Bissmella

I am trying with the following code:

import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device

device = setup_distributed()
    
# Need to add parallel support for this.
# pipeline.transformer.set_attention_backend("flash_hub")
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]

if dist.get_rank() == 0:
    image.save("output_ua.png")
if dist.is_initialized():
    dist.destroy_process_group()

Run the above with torchrun --nproc-per-node 4 check_unified_attention.py.

And it leads to: https://pastebin.com/A7KkvXH2

sayakpaul avatar Dec 06 '25 03:12 sayakpaul

> I am trying with the following code: […] Run the above with torchrun --nproc-per-node 4 check_unified_attention.py. And it leads to: https://pastebin.com/A7KkvXH2

I spent quite some time investigating this issue but wasn’t able to find the cause. I tried to reproduce it, but the model is too large for the small GPUs I have access to, and the _native_cudnn attention backend also does not run on them. Does this error occur with TemplatedRingAttention alone? The problem seems to arise with out, prev_out, lse, and prev_lse in the second iteration of the for loop, but none of those tensors originates directly from TemplatedUnifiedAttention. I will keep digging into this and see if I can identify the issue.

Bissmella avatar Dec 08 '25 10:12 Bissmella

> I am trying with the following code: […] Run the above with torchrun --nproc-per-node 4 check_unified_attention.py. And it leads to: https://pastebin.com/A7KkvXH2

> I spent quite some time investigating this issue but wasn’t able to find the cause. […] I will continue digging more into this and see if I can identify the issue.

Oooh, finally tracked it down and could reproduce it on CPU! The bug is in the TemplatedRingAttention forward function, in these lines:

            if _parallel_config.context_parallel_config.convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)

            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse

        out = out.to(query.dtype)
        lse = lse.squeeze(-1)

That lse = lse.unsqueeze(-1) is unnecessary and causes the issue, because the unsqueeze is already applied inside the torch.ops.aten._scaled_dot_product_cudnn_attention op used by _cudnn_attention_forward_op. See https://github.com/pytorch/pytorch/blob/7a38744ffa3775ace1df4df1d613bb520eb6e456/torch/_meta_registrations.py#L5733 for the meta registration of torch.ops.aten._scaled_dot_product_cudnn_attention. Should I commit and push the fix, i.e. just remove that one line?
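For anyone verifying the fix, the running merge in that loop can be checked numerically on a single process. The sketch below uses a naive softmax attention with arbitrary shapes (none of this is the diffusers code); here the unsqueeze is genuinely needed, because this toy `attn_with_lse` returns an lse of shape (B, H, Sq) without the trailing dim:

```python
import torch

def attn_with_lse(q, k, v):
    # naive attention that also returns the log-sum-exp of the scores
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)        # (B, H, Sq)
    return torch.softmax(scores, dim=-1) @ v, lse

torch.manual_seed(0)
q = torch.randn(1, 2, 8, 16)   # (B, H, Sq, D)
k = torch.randn(1, 2, 32, 16)
v = torch.randn(1, 2, 32, 16)

ref, _ = attn_with_lse(q, k, v)  # reference: attention over the full K/V

# ring-style: visit K/V in chunks, merging the running (out, lse) pair
prev_out = prev_lse = None
for k_c, v_c in zip(k.chunk(2, dim=2), v.chunk(2, dim=2)):
    out, lse = attn_with_lse(q, k_c, v_c)
    lse = lse.unsqueeze(-1)  # needed here: this lse has no trailing dim yet
    if prev_out is not None:
        out = prev_out - torch.sigmoid(lse - prev_lse) * (prev_out - out)
        lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
    prev_out, prev_lse = out, lse

# the chunked-and-merged result matches full attention
assert torch.allclose(prev_out, ref, atol=1e-5)
```

If the attention op already returns an lse with the trailing singleton dim (as the cudnn op apparently does), the unsqueeze turns the merge into a broadcast over the wrong axes, which matches the observed failure in the second loop iteration.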

Bissmella avatar Dec 08 '25 15:12 Bissmella

Thanks a lot for this investigation. Indeed, that seems to be an issue in PyTorch 2.9. WDYT about the following diff?

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index aaa45c757..0efeb2868 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -44,6 +44,7 @@ from ..utils import (
     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils import is_torch_version
 
 
 if TYPE_CHECKING:
@@ -1186,7 +1187,10 @@ class TemplatedRingAttention(torch.autograd.Function):
                 out = out.to(torch.float32)
                 lse = lse.to(torch.float32)
 
-            lse = lse.unsqueeze(-1)
+            # Refer to:
+            # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+            if is_torch_version("<", "2.9.0"):
+                lse = lse.unsqueeze(-1)
             if prev_out is not None:
                 out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                 lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1400,7 +1404,10 @@ def TemplatedUnifiedAttention(
     if return_lse:
         # not sure if this is correct: Assuming (based on forward ops in ringAttention) 
         # the lse is of shape (B, S, H_LOCAL)
-        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
         lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1)
         lse = lse.squeeze(-1)
         return (output, lse)

I also coded up a simple script to compare different backends:

import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified"],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()

    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
    ).to(device)
    # Always using it because `ring` doesn't support default. This helps ensure a fair comparison.
    pipeline.transformer.set_attention_backend("_native_cudnn")

    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    else:
        cp_config = ContextParallelConfig(ulysses_degree=world_size)

    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save(f"output_{args.cp_backend}.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

When I ran the above with torchrun --nproc-per-node 2 check_unified_attention.py --cp-backend {ring,ulysses,unified} (I am on a node of 2 GPUs), I got:

[Output images: Ring | Ulysses | Unified]

I also changed to cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2) on a node of 4 GPUs, and ran the code with torchrun --nproc-per-node 4 check_unified_attention.py --cp-backend. I got identical output.

sayakpaul avatar Dec 09 '25 04:12 sayakpaul

I think that is perfect; I didn't know about that torch 2.9 specificity. I will apply the diff. Thanks a lot for sharing your script and those amazing photos. Should I convert your script into a test and add it to tests? I think that would be good. Or should it replace the existing one? I can put some more time into cleaning it up and adding a standard test.

I will just do a final test on the lse in TemplatedUnifiedAttention and correct anything wrong. There is a similar issue to this earlier comment in the backward of TemplatedUnifiedAttention: it is missing one None in the output. Should I add it?
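For context on the missing None: PyTorch requires a custom autograd.Function's backward to return exactly one value per forward input, with None for non-differentiable arguments (such as a process group or a scatter index). A minimal illustration, unrelated to the diffusers code:

```python
import torch

# `tag` stands in for a non-tensor argument like a process group; backward
# must still return a matching None for it.
class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, tag):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # grad for x, then None for `factor` and `tag`
        return grad_out * ctx.factor, None, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0, "unused").sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))
```

Dropping one of the trailing `None`s makes autograd raise an error about the number of gradients not matching the number of forward inputs, which is the same class of bug as the one described above.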

Bissmella avatar Dec 09 '25 09:12 Bissmella

> Should I convert your script to a test and add it in tests? I think that would be good. Or replace the existing one? I can put some more time on cleaning and adding standard test.

We need to add dedicated testing for CP x attention backends anyway, so we can skip it for now. Good documentation should suffice.

> There is a similar issue to this https://github.com/huggingface/diffusers/pull/12693#discussion_r2594286403 in the backward of TemplatedUnifiedAttention and misses one None in the output. Should I add it?

Sounds good!

sayakpaul avatar Dec 09 '25 10:12 sayakpaul

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bot /style

sayakpaul avatar Dec 11 '25 08:12 sayakpaul

Style bot fixed some files and pushed the changes.

github-actions[bot] avatar Dec 11 '25 08:12 github-actions[bot]

Okay I will add the docs and then remove the test file.

Bissmella avatar Dec 11 '25 08:12 Bissmella

Oops! So sorry for the force push; I just resolved a conflict in distributed_inference.md in the docs. I added the docs and removed the test file.

Bissmella avatar Dec 11 '25 09:12 Bissmella

@bot /style

sayakpaul avatar Dec 13 '25 03:12 sayakpaul

Style bot fixed some files and pushed the changes.

github-actions[bot] avatar Dec 13 '25 03:12 github-actions[bot]

Yes sure, I can do the benchmarking for the three methods.

Bissmella avatar Dec 13 '25 20:12 Bissmella

Cool, I will be curious for the results!

sayakpaul avatar Dec 15 '25 04:12 sayakpaul