# Attention Dispatcher
## Usage
```python
# test.py
import torch
from diffusers import Lumina2Pipeline, attention_backend

pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says 'Hello, World!' in a colorful park with flowers and trees"

with attention_backend("sage_varlen"):
    image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]

image.save("output.png")
```
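The valid backend names come from the `AttentionBackendName` enum (the same enum the benchmark script below uses to populate its `--attn_provider` choices). A quick way to list them:

```python
from diffusers import AttentionBackendName

# Backends prefixed with "_" are private/experimental variants
print(sorted(backend.value for backend in AttentionBackendName.__members__.values()))
```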
```shell
# fails because flex attention requires the head dimension to be a power of 2
DIFFUSERS_ATTN_PROVIDER="flex" CUDA_VISIBLE_DEVICES=3 python3 test.py

# dispatches to cuDNN internally in PyTorch, so it's the same as using "_native_cudnn" (see below)
DIFFUSERS_ATTN_PROVIDER="native" CUDA_VISIBLE_DEVICES=3 python3 test.py

DIFFUSERS_ATTN_PROVIDER="flash_varlen" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="sage_varlen" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="_native_cudnn" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="_native_efficient" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="xformers" CUDA_VISIBLE_DEVICES=3 python3 test.py
```
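As a side note, the `"native"`-dispatches-to-cuDNN claim above can be sanity-checked with PyTorch's own SDPA backend selector. A minimal sketch using `torch.nn.attention.sdpa_kernel` (a stock PyTorch 2.x API, independent of the dispatcher):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 12, 1024, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 12, 1024, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 12, 1024, 128, dtype=torch.bfloat16, device="cuda")

# Pin SDPA to the cuDNN backend and compare against the default dispatch
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out_cudnn = F.scaled_dot_product_attention(q, k, v)
out_default = F.scaled_dot_product_attention(q, k, v)
print(torch.equal(out_cudnn, out_default))
```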
## Attention-only benchmark
```python
import torch

from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

torch.manual_seed(0)

# Wan 1.3B/CogVideoX
batch = 1
num_heads = 12
head_dim = 128
dtype = torch.bfloat16

resolutions = [(1, 512, 512), (1, 1024, 1024), (49, 480, 720), (29, 1024, 1024), (81, 480, 832)]
# (frames, height, width) -> transformer sequence length: 4x temporal VAE compression,
# 8x8 spatial VAE compression, 2x2 patchification (the final // 4)
seq_lens = [((res[0] - 1) // 4 + 1) * res[1] * res[2] // 8 // 8 // 4 for res in resolutions]
print("Sequence lengths:", seq_lens)

for seq_len in seq_lens:
    # Two matmuls (QK^T and PV) at 2 FLOPs per multiply-accumulate -> 4 * b * h * d * s^2
    flops = 4 * batch * num_heads * head_dim * seq_len * seq_len

    torch.manual_seed(0)
    query = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    key = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    value = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")

    results = {}
    for backend in ["flash", "flash_varlen", "_native_flash", "_native_cudnn", "_native_efficient", "xformers", "_sage_qk_int8_pv_fp16_cuda"]:
        with attention_backend(backend):
            # Warmup
            for _ in range(5):
                _ = dispatch_attention_fn(query, key, value)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            result = dispatch_attention_fn(query, key, value)
            end.record()
            torch.cuda.synchronize()

            elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
            results[backend] = elapsed_time

    tflops_s_flash = flops / results["flash"] / 1e12
    tflops_s_flash_varlen = flops / results["flash_varlen"] / 1e12
    tflops_s_native_flash = flops / results["_native_flash"] / 1e12
    tflops_s_native_cudnn = flops / results["_native_cudnn"] / 1e12
    tflops_s_native_efficient = flops / results["_native_efficient"] / 1e12
    tflops_s_xformers = flops / results["xformers"] / 1e12
    tflops_s_sage_qk_int8_pv_fp16_cuda = flops / results["_sage_qk_int8_pv_fp16_cuda"] / 1e12

    print()
    print(f"Shape: {query.shape}")
    print(f"TFLOPs: {flops / 1e12:.2f}")
    print("===== TFLOPS =====")
    print(f"                     (flash): {tflops_s_flash:.2f}")
    print(f"              (flash_varlen): {tflops_s_flash_varlen:.2f}")
    print(f"              (native_flash): {tflops_s_native_flash:.2f}")
    print(f"              (native_cudnn): {tflops_s_native_cudnn:.2f}")
    print(f"          (native_efficient): {tflops_s_native_efficient:.2f}")
    print(f"                  (xformers): {tflops_s_xformers:.2f}")
    print(f"(_sage_qk_int8_pv_fp16_cuda): {tflops_s_sage_qk_int8_pv_fp16_cuda:.2f}")
    print("==========")
```
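Not part of the original script, but a useful companion: the same harness can double as a numerical sanity check against the default (native) dispatch, mirroring the max-abs-diff comparisons used later in this thread. A minimal sketch:

```python
import torch

from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

query = torch.randn(1, 12, 1024, 128, dtype=torch.bfloat16, device="cuda")
key = torch.randn(1, 12, 1024, 128, dtype=torch.bfloat16, device="cuda")
value = torch.randn(1, 12, 1024, 128, dtype=torch.bfloat16, device="cuda")

reference = dispatch_attention_fn(query, key, value)  # no context manager -> native attention
for backend in ["flash", "_native_cudnn", "xformers"]:
    with attention_backend(backend):
        out = dispatch_attention_fn(query, key, value)
    print(f"{backend}: max abs diff vs native = {torch.abs(out - reference).max().item():.6f}")
```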
## Model benchmark
```python
import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
import torch.nn.attention.flex_attention
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    LTXPipeline,
    MochiPipeline,
    WanPipeline,
    AttentionBackendName,
    attention_backend,
)
from diffusers.hooks import apply_group_offloading
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate

repo = git.Repo(path="/home/aryan/work/diffusers")
branch = repo.active_branch

torch.nn.attention.flex_attention.flex_attention = torch.compile(torch.nn.attention.flex_attention.flex_attention, mode="max-autotune", dynamic=False, fullgraph=True)
torch.nn.attention.flex_attention.create_block_mask = torch.compile(torch.nn.attention.flex_attention.create_block_mask, mode="max-autotune", dynamic=False, fullgraph=True)

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_accumulation = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)
    return elapsed_time, output
def prepare_allegro(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "rhymes-ai/Allegro"
    cache_dir = None
    pipe = AllegroPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats.",
        "height": 720,
        "width": 1280,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox_1_0(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "THUDM/CogVideoX-5b"
    cache_dir = None
    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=(
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
            "atmosphere of this unique musical performance."
        ),
        device="cuda",
        dtype=dtype,
    )
    pipe.text_encoder.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 480,
        "width": 720,
        "num_frames": 49,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs
def prepare_flux(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "black-forest-labs/FLUX.1-dev"
    cache_dir = "/raid/.cache/huggingface"
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.vae.enable_tiling()

    pipe.text_encoder.to("cuda")
    pipe.text_encoder_2.to("cuda")
    prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt="A cat holding a sign that says hello world", prompt_2=None, device="cuda"
    )
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")
    del pipe.text_encoder
    del pipe.text_encoder_2
    pipe.text_encoder = None
    pipe.text_encoder_2 = None
    pipe.to("cuda")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_hunyuan_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "hunyuanvideo-community/HunyuanVideo"
    cache_dir = None
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(
        model_id, transformer=transformer, torch_dtype=torch.float16, cache_dir=cache_dir
    )
    pipe.to("cuda")

    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
        prompt="A cat wearing sunglasses and working as a lifeguard at pool.", device="cuda", dtype=torch.float16
    )
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "prompt_attention_mask": prompt_attention_mask,
        "height": 320,
        "width": 512,
        "num_frames": 61,
        "num_inference_steps": 30,
    }

    return pipe, generation_kwargs
def prepare_latte(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "maxin-cn/Latte-1"
    cache_dir = None
    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt="A cat wearing sunglasses and working as a lifeguard at pool.",
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 512,
        "width": 512,
        "video_length": 16,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_ltx_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "a-r-r-o-w/LTX-Video-diffusers"
    cache_dir = None
    pipe = LTXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = pipe.encode_prompt(
        prompt="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
        negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "prompt_attention_mask": prompt_attention_mask,
        "negative_prompt_embeds": negative_prompt_embeds,
        "negative_prompt_attention_mask": negative_prompt_attention_mask,
        "width": 768,
        "height": 512,
        "num_frames": 161,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_mochi(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "genmo/mochi-1-preview"
    cache_dir = None
    pipe = MochiPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
        "height": 480,
        "width": 848,
        "num_frames": 85,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_wan(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    cache_dir = None
    pipe = WanPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)

    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "worst quality, low quality, blurry, distorted, out of focus, bad composition"

    pipe.text_encoder.to("cuda")
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")
    del pipe.text_encoder
    pipe.text_encoder = None
    pipe.to("cuda")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 480,
        "width": 832,
        "num_frames": 81,
        "guidance_scale": 5.0,
        "num_inference_steps": 30,
        **kwargs,
    }

    return pipe, generation_kwargs
def decode_allegro(pipe: AllegroPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_cogvideox_1_0(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_hunyuan_video(pipe: HunyuanVideoPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_latte(pipe: LattePipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents, video_length=kwargs["video_length"])
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_ltx_video(pipe: LTXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latent_num_frames = (kwargs["num_frames"] - 1) // pipe.vae_temporal_compression_ratio + 1
    latent_height = kwargs["height"] // pipe.vae_spatial_compression_ratio
    latent_width = kwargs["width"] // pipe.vae_spatial_compression_ratio
    latents = pipe._unpack_latents(
        latents,
        latent_num_frames,
        latent_height,
        latent_width,
        pipe.transformer_spatial_patch_size,
        pipe.transformer_temporal_patch_size,
    )
    latents = pipe._denormalize_latents(
        latents, pipe.vae.latents_mean, pipe.vae.latents_std, pipe.vae.config.scaling_factor
    )
    latents = latents.to(pipe.vae.dtype)
    timestep = None
    video = pipe.vae.decode(latents, timestep, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=24)
    return filename


def decode_mochi(pipe: MochiPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_wan(pipe: WanPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype)
    latents_mean = (
        torch.tensor(pipe.vae.config.latents_mean)
        .view(1, pipe.vae.config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(
        latents.device, latents.dtype
    )
    latents = latents / latents_std + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=16)
    return filename


def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
MODEL_MAPPING = {
    "allegro": {
        "prepare": prepare_allegro,
        "decode": decode_allegro,
    },
    "cogvideox-1.0": {
        "prepare": prepare_cogvideox_1_0,
        "decode": decode_cogvideox_1_0,
    },
    "flux": {
        "prepare": prepare_flux,
        "decode": decode_flux,
    },
    "hunyuan_video": {
        "prepare": prepare_hunyuan_video,
        "decode": decode_hunyuan_video,
    },
    "latte": {
        "prepare": prepare_latte,
        "decode": decode_latte,
    },
    "ltx_video": {
        "prepare": prepare_ltx_video,
        "decode": decode_ltx_video,
    },
    "mochi": {
        "prepare": prepare_mochi,
        "decode": decode_mochi,
    },
    "wan": {
        "prepare": prepare_wan,
        "decode": decode_wan,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator().manual_seed(181201)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output
from diffusers.hooks import ModelHook, HookRegistry
from accelerate.utils import send_to_device


class MoveToCUDAHook(ModelHook):
    def pre_forward(self, module, *args, **kwargs):
        args = send_to_device(args, "cuda")
        kwargs = send_to_device(kwargs, "cuda")
        return args, kwargs

    def post_forward(self, module, output):
        output = send_to_device(output, "cpu")
        return output
@torch.no_grad()
def main(model_id: str, output_dir: str, dtype: str, offloading_type: str, num_blocks_per_group: int, use_stream: bool, compile: bool, attn_provider: str, num_images_per_prompt: int):
    if attn_provider == "flex":
        import torch.nn.attention.flex_attention as flex_attention

        flex_attention.flex_attention = torch.compile(flex_attention.flex_attention, mode="max-autotune-no-cudagraphs", fullgraph=True)
        flex_attention.create_block_mask = torch.compile(flex_attention.create_block_mask, mode="max-autotune-no-cudagraphs", fullgraph=True)

    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    reset_memory()

    try:
        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)

        extra_keys = {}
        if model_id == "wan":
            extra_keys = {"num_videos_per_prompt": num_images_per_prompt}
        else:
            extra_keys = {"num_images_per_prompt": num_images_per_prompt}
        generation_kwargs.update(extra_keys)

        # 2. Apply group offloading
        if offloading_type == "model":
            pipe.enable_model_cpu_offload()
        elif offloading_type == "sequential":
            pipe.enable_sequential_cpu_offload()
        elif offloading_type in ["block_level", "leaf_level"]:
            apply_group_offloading(
                pipe.transformer,
                offload_type=offloading_type,
                num_blocks_per_group=num_blocks_per_group,
                offload_device=torch.device("cpu"),
                onload_device=torch.device("cuda"),
                non_blocking=True,
                use_stream=use_stream,
            )
        else:
            pipe.transformer.to("cuda")
            # registry = HookRegistry.check_if_exists_or_initialize(pipe.transformer)
            # registry.register_hook(MoveToCUDAHook(), "MoveToCUDAHook")

        pipe.vae.to("cuda")
        torch.cuda.synchronize()
        reset_memory()
        model_max_memory_reserved = round(torch.cuda.max_memory_allocated() / 1024**3, 3)

        if compile:
            pipe.transformer = torch.compile(
                pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False
            )

        registry_vae = HookRegistry.check_if_exists_or_initialize(pipe.vae.decoder)
        registry_vae.register_hook(MoveToCUDAHook(), "MoveToCUDAHook")

        # 3. Warmup
        num_warmups = 1
        original_num_inference_steps = generation_kwargs["num_inference_steps"]
        generation_kwargs["num_inference_steps"] = 2
        with attention_backend(attn_provider):
            for _ in range(num_warmups):
                run_inference(pipe, generation_kwargs)
        generation_kwargs["num_inference_steps"] = original_num_inference_steps

        # 4. Benchmark
        with attention_backend(attn_provider):
            time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_max_memory_reserved = round(torch.cuda.max_memory_allocated() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---attn_provider-{attn_provider}---dtype-{dtype}---offloading_type-{offloading_type}---num_blocks_per_group-{num_blocks_per_group}---use_stream-{use_stream}---compile-{compile}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            num_frames=generation_kwargs.get("num_frames", None),
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "attn_provider": attn_provider,
            "time": time,
            "offloading_type": offloading_type,
            "use_stream": use_stream,
            "num_blocks": num_blocks_per_group,
            "model_memory": model_max_memory_reserved,
            "inference_memory": inference_max_memory_reserved,
            "compile": compile,
            "compute_dtype": dtype,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }
    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "attn_provider": attn_provider,
            "time": None,
            "offloading_type": offloading_type,
            "use_stream": use_stream,
            "num_blocks": num_blocks_per_group,
            "model_memory": None,
            "inference_memory": None,
            "compile": compile,
            "compute_dtype": dtype,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox-1.0", "latte", "allegro", "hunyuan_video", "mochi", "ltx_video", "wan"],
        help="Model to run benchmark for.",
    )
    parser.add_argument("--attn_provider", type=str, default="native", choices=[x.value for x in AttentionBackendName.__members__.values()])
    parser.add_argument("--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt.")
    parser.add_argument(
        "--output_dir", required=True, type=str, help="Path where the benchmark artifacts and outputs are to be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument("--offloading_type", type=str, default="none", choices=["none", "model", "block_level", "leaf_level"], help="Type of offloading to use.")
    parser.add_argument("--num_blocks_per_group", type=int, default=None, help="Number of layers per group for group offloading.")
    parser.add_argument("--use_stream", action="store_true", default=False, help="Whether to use CUDA streams for offloading.")
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(
        args.model_id,
        args.output_dir,
        args.dtype,
        args.offloading_type,
        args.num_blocks_per_group,
        args.use_stream,
        args.compile,
        args.attn_provider,
        args.num_images_per_prompt,
    )
```
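For reference, an invocation of the script above might look like this (the script filename and output directory are placeholders; the flags mirror the `leaf_level` + stream rows in the tables below):

```shell
python3 benchmark_attn_dispatcher.py \
  --model_id wan \
  --attn_provider sage \
  --dtype bf16 \
  --offloading_type leaf_level \
  --use_stream \
  --output_dir ./attention-benchmarks
```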
## Results: 4090

Results with PyTorch 2.7 stable, CUDA 12.6

### Wan
| model_id | attn_provider | time (s) | offloading_type | use_stream | num_blocks | model_memory (GB) | inference_memory (GB) | compile |
|---|---|---|---|---|---|---|---|---|
| wan | flash | 142.816 | none | False | | 2.912 | 4.455 | False |
| wan | flash_varlen | 144.221 | none | False | | 2.912 | 4.455 | False |
| wan | flex | 146.176 | none | False | | 2.912 | 4.455 | False |
| wan | native | 144.692 | none | False | | 2.912 | 4.455 | False |
| wan | _native_cudnn | 144.901 | none | False | | 2.912 | 4.455 | False |
| wan | _native_efficient | 184.593 | none | False | | 2.912 | 4.455 | False |
| wan | _native_flash | 144.611 | none | False | | 2.912 | 4.455 | False |
| wan | sage | 102.281 | none | False | | 2.912 | 4.455 | False |
| wan | sage_varlen | 112.254 | none | False | | 2.912 | 4.455 | False |
| wan | xformers | 142.909 | none | False | | 2.912 | 4.455 | False |
| wan | flash | 147.230 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flash_varlen | 148.197 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flex | 150.197 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | native | 148.783 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_cudnn | 149.177 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_efficient | 188.643 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_flash | 148.753 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage | 106.032 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage_varlen | 116.081 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | xformers | 147.119 | leaf_level | True | | 0.249 | 1.819 | False |
## Results: A100

Results with PyTorch 2.7 stable, CUDA 12.2

### Wan
| model_id | attn_provider | time (s) | offloading_type | use_stream | num_blocks | model_memory (GB) | inference_memory (GB) | compile |
|---|---|---|---|---|---|---|---|---|
| wan | flash | 123.107 | none | False | | 2.912 | 4.455 | False |
| wan | flash_varlen | 125.355 | none | False | | 2.912 | 4.455 | False |
| wan | flex | 143.088 | none | False | | 2.912 | 4.455 | False |
| wan | native | 130.183 | none | False | | 2.912 | 4.455 | False |
| wan | _native_cudnn | 137.591 | none | False | | 2.912 | 4.455 | False |
| wan | _native_efficient | 183.795 | none | False | | 2.912 | 4.455 | False |
| wan | _native_flash | 131.384 | none | False | | 2.912 | 4.455 | False |
| wan | sage | 119.741 | none | False | | 2.912 | 4.455 | False |
| wan | sage_varlen | 131.515 | none | False | | 2.912 | 4.455 | False |
| wan | xformers | 125.414 | none | False | | 2.912 | 4.455 | False |
| wan | flash | 127.484 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flash_varlen | 129.351 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flex | 146.739 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | native | 133.718 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_cudnn | 141.970 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_efficient | 188.268 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_flash | 133.996 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage | 123.269 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage_varlen | 133.422 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | xformers | 127.743 | leaf_level | True | | 0.249 | 1.819 | False |
cc @DN6 @sayakpaul @yiyixuxu
The environment vars were initially only for my quick testing from the CLI instead of changing the code every time. We can get rid of them completely.

The intended API in my mind, and what currently exists in the PR, is with context managers:

```python
from diffusers import attention_provider

with attention_provider("sage_varlen"):
    model(...)
```

Can change once we finalize something.
For the attention config class (if we decide to proceed down that route), I was thinking of the following APIs:

```python
attn_config = AttentionConfig(
    attn_implementation="...",
    enable_gqa=...,
)
model.set_attn_config(attn_config)
```
@sayakpaul @DN6 How would you recommend we set a per-model attention backend? The backend info needs to be propagated to the attention dispatcher when the forward method is called. The easiest way, and how I've done it for training/CP, is to attach a simple pre-forward hook that sets the backend, cp_mesh, and any other attributes when the forward method is invoked. If you have recommendations, I'll modify the implementation accordingly.
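For concreteness, here is a minimal sketch of that pre-forward-hook idea, reusing the `ModelHook`/`HookRegistry` pattern from the benchmark script above; the `_attention_backend` attribute and how the dispatcher would read it are assumptions for illustration, not the actual implementation:

```python
from diffusers.hooks import HookRegistry, ModelHook


class SetAttentionBackendHook(ModelHook):
    def __init__(self, backend: str):
        super().__init__()
        self.backend = backend

    def pre_forward(self, module, *args, **kwargs):
        # Stash the backend on the module just before forward runs so the
        # dispatcher can pick it up (attribute name is hypothetical)
        module._attention_backend = self.backend
        return args, kwargs


# Registration on some diffusers model, e.g. `pipe.transformer`:
# registry = HookRegistry.check_if_exists_or_initialize(pipe.transformer)
# registry.register_hook(SetAttentionBackendHook("flash_varlen"), "SetAttentionBackendHook")
```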
Currently, you need to first replace the calls to `F.scaled_dot_product_attention` with `diffusers.models.attention_dispatch.dispatch_attention_fn` in the modeling code, and then invoke one or more models under the `attention_backend` context manager:

```python
from diffusers import attention_backend

with attention_backend("flash_varlen"):
    output = transformer(...)
```

If the context manager is not used, it defaults to the original behaviour of calling native torch attention.
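In modeling code, the swap itself is mechanical; a self-contained sketch (tensor shapes follow the `(batch, heads, seq_len, head_dim)` convention used in the benchmarks above):

```python
import torch

from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

query = torch.randn(1, 12, 256, 128, dtype=torch.bfloat16, device="cuda")
key = torch.randn(1, 12, 256, 128, dtype=torch.bfloat16, device="cuda")
value = torch.randn(1, 12, 256, 128, dtype=torch.bfloat16, device="cuda")

# before: out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
# after: identical call site, but the backend is selectable from the outside
with attention_backend("flash_varlen"):
    out = dispatch_attention_fn(query, key, value)
print(out.shape)
```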
> How would you recommend we set a per-model attention backend? The backend info needs to be propagated to the attention dispatcher when the forward method is called. The easiest way, and how I've done it for training/CP, is to attach a simple pre-forward hook that sets the backend, cp_mesh, and any other attributes when the forward method is invoked.
I was thinking that upon calling set_attn_config() we would set them? I prefer the set method over the context manager approach as I feel it's a bit more explicit.
@a-r-r-o-w I was able to run FA3 with your code and here are some results:
```
Sequence lengths: [1024, 4096, 17550, 32768, 32760]

Shape: torch.Size([1, 12, 1024, 128])
TFLOPs: 0.01
===== TFLOPS =====
       (flash): 67.79
(native_flash): 68.02
(native_cudnn): 60.44
==========

Shape: torch.Size([1, 12, 4096, 128])
TFLOPs: 0.10
===== TFLOPS =====
       (flash): 677.87
(native_flash): 325.90
(native_cudnn): 660.36
==========

Shape: torch.Size([1, 12, 17550, 128])
TFLOPs: 1.89
===== TFLOPS =====
       (flash): 740.64
(native_flash): 348.37
(native_cudnn): 626.17
==========

Shape: torch.Size([1, 12, 32768, 128])
TFLOPs: 6.60
===== TFLOPS =====
       (flash): 724.79
(native_flash): 363.38
(native_cudnn): 701.14
==========

Shape: torch.Size([1, 12, 32760, 128])
TFLOPs: 6.59
===== TFLOPS =====
       (flash): 669.26
(native_flash): 353.29
(native_cudnn): 586.65
==========
```
I can open a PR to your branch with the changes I had to make to get it working. LMK.
@sayakpaul Super cool, thanks! I hope you didn't face too much trouble with building FA3 😅
I actually already have the required changes for FA3 (and some other things like NPU and XLA) locally. I didn't benchmark yet though, so thanks for that; I can push my changes soon.
> I hope you didn't face too much trouble with building FA3 😅
It just took time. I used Docker instead of the default env of the cluster.
Pushed some changes to support FA3, NPU and XLA. They are all marked private since FA3 is a beta release, and NPU and XLA are untested.

PyTorch's cuDNN backend is close to FA3, but in almost all problem shapes the latter is faster, similar to FA2 built from source.
@sayakpaul @DN6 Based on our discussion, I've added support for set_attention_backend at the ModelMixin level. There are now two ways to enable the dispatcher.

**`set_attention_backend("...")` for diffusers-native implementations**
```python
import torch

from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.attention_processor import Attention


class MyModel(ModelMixin):
    def __init__(self):
        super().__init__()
        self.attention = Attention(
            query_dim=10,
            heads=2,
            dim_head=5,
        )
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            torch.nn.ReLU(),
            torch.nn.Linear(20, 10),
        )

    def forward(self, x: torch.Tensor):
        return x + self.mlp(x + self.attention(x))


dtype = torch.bfloat16
device = "cuda"

model = MyModel().to(device, dtype=dtype)
input = torch.randn(2, 64, 10).to(device, dtype=dtype)

output_native = model(input)

model.set_attention_backend("flash")
output_flash = model(input)

model.set_attention_backend("sage")
output_sage = model(input)

model.set_attention_backend("_native_math")
output_native_math = model(input)

diff1 = torch.abs(output_native - output_flash).max()
diff2 = torch.abs(output_native - output_sage).max()
diff3 = torch.abs(output_native - output_native_math).max()
print(diff1, diff2, diff3)
```
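Continuing the snippet above, switching back should just be another call with the default backend name (assuming `"native"` is the default, as the benchmark CLI's `--attn_provider` default suggests):

```python
model.set_attention_backend("native")
output_default = model(input)
```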
**Context manager for custom implementations**
```python
import torch

from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.attention_dispatch import dispatch_attention_fn, attention_backend


class AttentionProcessor:
    def __call__(
        self,
        attn,
        x: torch.Tensor,
    ) -> torch.Tensor:
        q, k, v = (y.unflatten(2, (attn.heads, -1)).permute(0, 2, 1, 3).contiguous() for y in attn.qkv(x).chunk(3, dim=-1))
        return attn.o(dispatch_attention_fn(q, k, v).permute(0, 2, 1, 3).flatten(2))


class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = 2
        self.qkv = torch.nn.Linear(10, 30)
        self.o = torch.nn.Linear(10, 10)
        self.processor = AttentionProcessor()

    def forward(self, x: torch.Tensor):
        return self.processor(self, x)


class MyModel(ModelMixin):
    def __init__(self):
        super().__init__()
        self.attention = Attention()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            torch.nn.ReLU(),
            torch.nn.Linear(20, 10),
        )

    def forward(self, x: torch.Tensor):
        return x + self.mlp(x + self.attention(x))


dtype = torch.bfloat16
device = "cuda"

model = MyModel().to(device, dtype=dtype)
input = torch.randn(2, 64, 10).to(device, dtype=dtype)

output_native = model(input)

with attention_backend("flash"):
    output_flash = model(input)

with attention_backend("sage"):
    output_sage = model(input)

with attention_backend("_native_math"):
    output_native_math = model(input)

diff1 = torch.abs(output_native - output_flash).max()
diff2 = torch.abs(output_native - output_sage).max()
diff3 = torch.abs(output_native - output_native_math).max()
print(diff1, diff2, diff3)
```
This is looking much much better IMO!
I think with proper documentation, we can make the differences between the scopes of set_attn_backend() and set_attn_processor() much clearer.
Continued in #11916