Add Unified Sequence Parallel attention
What does this PR do?
This is a draft implementation of the Unified SP attention approach.
- Implements `_all_to_all_dim_exchange` with custom scatter and gather indices (see the sketch below)
- Implements `TemplatedUnifiedAttention`
Core implementation complete, needs:
- [x] Testing
- [x] Validation
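For readers unfamiliar with the Ulysses-style exchange, the core operation behind `_all_to_all_dim_exchange` is roughly the following (a hand-written sketch with a hypothetical helper name and simplified shape handling, not the PR's exact code):

```python
# Sketch of an all-to-all dimension exchange (Ulysses-style): split one dimension
# into one chunk per rank, exchange the chunks, and concatenate the received chunks
# along another dimension. Assumes both dimensions divide evenly by the group size.
import torch
import torch.distributed as dist


def all_to_all_dim_exchange_sketch(x, group, scatter_dim, gather_dim):
    world_size = dist.get_world_size(group)
    # One chunk per destination rank along the scatter dimension.
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(t) for t in inputs]
    # Rank i sends its j-th chunk to rank j and receives rank j's i-th chunk.
    dist.all_to_all(outputs, inputs, group=group)
    # Stitch the received chunks back together along the gather dimension.
    return torch.cat(outputs, dim=gather_dim)


# e.g. query of shape (B, S_local, H, D) -> (B, S, H_local, D):
# scatter the head dim (2) across ranks and gather the sequence dim (1).
```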
It would be nice to get a testing script so that we can quickly check things.
I added a basic test script with a simple forward and backward op. Would it be better to have a test script that exercises flash_attention forward and backward?
Let us know if this is ready for a review!
Yep, ready for review! I tested it with a 4-process setup (2×2 mesh, on CPU) and everything checks out: shapes look good and gradients flow correctly. Looking forward to feedback, and happy to address any issues.
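In case it helps, the kind of smoke test I mean looks roughly like this (a minimal sketch that uses plain SDPA as a stand-in for the unified attention path, 4 gloo processes on CPU; the helper names are my own, not the PR's test):

```python
# Minimal CPU smoke test sketch: spawn 4 gloo processes, run a forward + backward
# through an attention call, and check output shape and gradient flow. Plain SDPA
# stands in for the unified attention path here (hypothetical, not the PR's test).
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    torch.manual_seed(rank)
    q = torch.randn(2, 8, 64, 16, requires_grad=True)  # (B, H, S_local, D)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    assert out.shape == q.shape
    out.sum().backward()
    assert q.grad is not None and torch.isfinite(q.grad).all()

    dist.all_reduce(q.grad)  # exercise a collective in the same setup
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(4,), nprocs=4)
```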
I am trying with the following code:
```python
import torch
from torch import distributed as dist

from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device


device = setup_distributed()

# Need to add parallel support for this.
# pipeline.transformer.set_attention_backend("flash_hub")
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]

if dist.get_rank() == 0:
    image.save("output_ua.png")

if dist.is_initialized():
    dist.destroy_process_group()
```
Run the above with `torchrun --nproc-per-node 4 check_unified_attention.py`.
And it leads to: https://pastebin.com/A7KkvXH2
I spent quite some time investigating this issue but wasn’t able to find the cause. I tried to reproduce it, but the model is too large for the small GPUs I have access to, and `native_cudnn` attention also does not work on simpler GPUs.
Does this error occur with `TemplatedRingAttention` alone? It seems the problem arises with `out`, `prev_out`, `lse`, and `prev_lse` in the second iteration of the `for` loop, but none of these tensors originates directly from `TemplatedUnifiedAttention`. I will continue digging into this and see if I can identify the issue.
Oooh, finally tracked it down and could reproduce it on CPU! The bug is in the `TemplatedRingAttention` forward function, in these lines:
```python
if _parallel_config.context_parallel_config.convert_to_fp32:
    out = out.to(torch.float32)
    lse = lse.to(torch.float32)

lse = lse.unsqueeze(-1)
if prev_out is not None:
    out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
    lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
prev_out = out
prev_lse = lse

out = out.to(query.dtype)
lse = lse.squeeze(-1)
```
That `lse = lse.unsqueeze(-1)` is unnecessary and causes the issue, because it is already done inside the `torch.ops.aten._scaled_dot_product_cudnn_attention` used by `_cudnn_attention_forward_op`. See https://github.com/pytorch/pytorch/blob/7a38744ffa3775ace1df4df1d613bb520eb6e456/torch/_meta_registrations.py#L5733 for the meta registration of `torch.ops.aten._scaled_dot_product_cudnn_attention`.
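For anyone following along, the merge those lines perform is the standard log-sum-exp combination of partial attention results over disjoint KV blocks. A quick standalone check of that identity (my own, not part of the PR):

```python
# Check that the sigmoid/logsigmoid merge equals full softmax attention over the
# concatenation of two key/value blocks (shapes and helper are illustrative only).
import torch

torch.manual_seed(0)
q = torch.randn(1, 4, 2, 8)          # (B, S, H, D)
k1, v1 = torch.randn(2, 1, 4, 2, 8)  # first KV block
k2, v2 = torch.randn(2, 1, 4, 2, 8)  # second KV block


def partial_attn(q, k, v):
    scores = torch.einsum("bshd,bthd->bhst", q, k) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)                          # (B, H, S)
    out = torch.einsum("bhst,bthd->bshd", scores.softmax(-1), v)   # (B, S, H, D)
    return out, lse.permute(0, 2, 1)                               # lse -> (B, S, H)


out1, lse1 = partial_attn(q, k1, v1)
out2, lse2 = partial_attn(q, k2, v2)

# Merge as in the TemplatedRingAttention loop (lse broadcast over the last dim).
w = torch.sigmoid(lse2 - lse1).unsqueeze(-1)
merged_out = out1 - w * (out1 - out2)
merged_lse = lse1 - torch.nn.functional.logsigmoid(lse1 - lse2)

ref_out, ref_lse = partial_attn(q, torch.cat([k1, k2], dim=1), torch.cat([v1, v2], dim=1))
assert torch.allclose(merged_out, ref_out, atol=1e-5)
assert torch.allclose(merged_lse, ref_lse, atol=1e-5)
```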
So should I commit and push the fix, just removing that one line?
Thanks a lot for this investigation. Indeed, that seems to be an issue in PyTorch 2.9. WDYT about the following diff?
```diff
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index aaa45c757..0efeb2868 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -44,6 +44,7 @@ from ..utils import (
     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils import is_torch_version
 
 
 if TYPE_CHECKING:
@@ -1186,7 +1187,10 @@ class TemplatedRingAttention(torch.autograd.Function):
                 out = out.to(torch.float32)
                 lse = lse.to(torch.float32)
 
-            lse = lse.unsqueeze(-1)
+            # Refer to:
+            # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+            if is_torch_version("<", "2.9.0"):
+                lse = lse.unsqueeze(-1)
             if prev_out is not None:
                 out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                 lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1400,7 +1404,10 @@ def TemplatedUnifiedAttention(
     if return_lse:
         # not sure if this is correct: Assuming (based on forward ops in ringAttention)
         # the lse is of shape (B, S, H_LOCAL)
-        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
         lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1)
         lse = lse.squeeze(-1)
         return (output, lse)
```
I also coded up a simple script to compare different backends:
```python
import argparse

import torch
from torch import distributed as dist

from diffusers import DiffusionPipeline, ContextParallelConfig


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified"],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()
    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
    ).to(device)
    # Always using it because `ring` doesn't support default. This helps ensure a fair comparison.
    pipeline.transformer.set_attention_backend("_native_cudnn")

    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    else:
        cp_config = ContextParallelConfig(ulysses_degree=world_size)
    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """
    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save(f"output_{args.cp_backend}.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
When I ran the above with `torchrun --nproc-per-node 2 check_unified_attention.py --cp-backend {ring,ulysses,unified}` (I am on a node of 2 GPUs), I got:
| Ring | Ulysses | Unified |
|---|---|---|
| (output image) | (output image) | (output image) |
I also changed to `cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)` on a node of 4 GPUs, and ran the code with `torchrun --nproc-per-node 4 check_unified_attention.py --cp-backend`. I got identical output.
I think that is perfect; I didn't know about that torch 2.9 specific. I will apply the diff. Thanks a lot for sharing your script and those amazing photos. Should I convert your script to a test and add it to tests? I think that would be good. Or replace the existing one? I can put some more time into cleaning it up and adding a standard test.
I will just do a final test on the lse in `TemplatedUnifiedAttention` and correct it if anything is wrong.
There is an issue similar to this earlier comment in the backward of `TemplatedUnifiedAttention`: it is missing one `None` in the output. Should I add it?
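To illustrate what I mean by the missing `None` (a generic toy example, not the PR's actual autograd function): `backward` must return one value per `forward` input, with `None` for non-differentiable arguments such as a process group or an index.

```python
# Toy torch.autograd.Function: forward takes three inputs, so backward must return
# three values, using None for the non-differentiable `scale` and `idx` arguments.
import torch


class ScaleByIdx(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, idx):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient per forward input: x gets a real gradient, the rest get None.
        return grad_out * ctx.scale, None, None


x = torch.randn(4, requires_grad=True)
ScaleByIdx.apply(x, 2.0, 0).sum().backward()
print(x.grad)  # tensor of 2.0s
```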
Should I convert your script to a test and add it to tests? I think that would be good. Or replace the existing one? I can put some more time into cleaning it up and adding a standard test.
We need to add dedicated testing for CP x attention backends, anyway. So, we can skip for now. Sufficient documentation should suffice.
There is an issue similar to this https://github.com/huggingface/diffusers/pull/12693#discussion_r2594286403 in the backward of `TemplatedUnifiedAttention`: it is missing one `None` in the output. Should I add it?
Sounds good!
@bot /style
Style bot fixed some files and pushed the changes.
Okay I will add the docs and then remove the test file.
Oops! So sorry for the force push. I just resolved a conflict in distributed_inference.md in the docs. I added the docs and removed the test file.
@bot /style
Style bot fixed some files and pushed the changes.
Yes sure, I can do the benchmarking for the three methods.
Cool, I will be curious to see the results!