Sage Attention for the diffusers library
**Is your feature request related to a problem?** No
**Describe the solution you'd like.** Incorporate a way to add Sage Attention to the diffusers library: Flux pipeline, Wan pipeline, etc.
**Describe alternatives you've considered.** None
**Additional context.** When I incorporated Sage Attention in the Flux pipeline (text-to-image) I measured a 16% speed advantage over running without it. My environment was identical across runs except for including/excluding Sage Attention in my 4-image benchmark.
How to incorporate Sage Attention? We must consider that it only applies to the transformer. With that in mind, I did the following to FluxPipeline. Obviously there should be a way to toggle this via some kind of flag so that we can enable or disable it:
We need some kind of indicator to decide whether to include it or not. This must be done before the denoising step in the model pipeline:

```python
import torch.nn.functional as F

# Swap in Sage Attention if it is installed; otherwise fall back silently.
sage_function = False
try:
    from sageattention import sageattn

    self.transformer.scaled_dot_product_attention = F.scaled_dot_product_attention = sageattn
    sage_function = True
except ImportError:
    pass

# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue
```

After the denoising step we must undo the Sage Attention patch, otherwise we get a VAE error, because Sage Attention only accepts torch.float16 or torch.bfloat16 inputs, which the VAE doesn't use:

```python
if output_type == "latent":
    image = latents
else:
    if sage_function:
        # Restore the stock PyTorch SDPA kernel before the VAE decode.
        self.transformer.scaled_dot_product_attention = F.scaled_dot_product_attention = (
            torch._C._nn.scaled_dot_product_attention
        )
```
Hopefully this helps.
Hey, adding different attention backends is going to be prioritized in the next release schedule! We're currently in the process of benchmarking and testing with different models to understand quality tradeoffs and performance gains.
For this specific case, using sage attention is as simple as replacing the F.scaled_dot_product_attention operation, and it should hopefully work for most, if not all, models out of the box.
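To keep the swap reversible, a minimal sketch could look like this (helper names are illustrative, and this assumes the sageattention package is installed):

```python
import torch.nn.functional as F
from sageattention import sageattn

# Keep a handle to the stock kernel so it can be restored later.
_original_sdpa = F.scaled_dot_product_attention

def enable_sage_attention():
    F.scaled_dot_product_attention = sageattn

def disable_sage_attention():
    F.scaled_dot_product_attention = _original_sdpa
```

Whether the direct assignment works unmodified depends on which keyword arguments a given model's attention processors pass (see the note on attn_mask/dropout_p further down).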
On my PC, Sage Attention needs to be disabled after the transformer does its work when using the diffusers library. That was my experience at least.
Why exactly does it need to be disabled after the transformer? Do you run into some error?
I ran into an error running the VAE afterwards. Sage Attention complained about the dtype, which is torch.float32. Apparently it was still patched into something the VAE uses. That was my experience; why no one else may have this problem, I can't tell you.
I ran a benchmark using Sage Attention on the Wan 2.1 T2V 14B diffusers pipeline. I made some minor adjustments to generate 1 image (i.e. text-to-image instead of text-to-video) to compare the effect of Sage Attention for this model. I was somewhat disappointed by the results, leading me to believe that some of the claims being made out there are bupkus. I did 5 test runs and threw out the first one, as I did for the Flux 1.D model, and only saw a 4.2% speed increase with Sage Attention vs without. My environments for Flux 1.D and Wan 2.1 are virtually the same: a quantized (qint8) transformer and T5 text encoder in both setups. Note: I am running 64GB of RAM and an RTX 4090 with 24GB of VRAM on Windows 10, Python 3.12.5, cuDNN 9.7, and torch 2.6.0+cu126.
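Roughly, the timing harness was along these lines (simplified sketch, not my exact script; the helper name is made up):

```python
import time
import torch

def benchmark(pipe, prompt, runs=5, **pipe_kwargs):
    # Time `runs` generations, discard the first (warmup) run, and average the rest.
    times = []
    for _ in range(runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        pipe(prompt, **pipe_kwargs)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    return sum(times[1:]) / (len(times) - 1)
```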
Replacing the attention function isn't a drop-in change: the two take different parameters, and sageattn doesn't accept an attention mask or support dropout_p.
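One workaround is a thin wrapper that only dispatches to sageattn when the call is compatible, and otherwise falls back to the stock kernel (which also sidesteps the float32 VAE issue mentioned above). A rough sketch, assuming the inputs are in the (batch, heads, seq, head_dim) layout that sageattn's default expects:

```python
import torch
import torch.nn.functional as F
from sageattention import sageattn

_original_sdpa = F.scaled_dot_product_attention

def sage_or_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
    # Only dispatch to Sage Attention for half-precision, mask-free, dropout-free calls
    # (i.e. the transformer); everything else (e.g. a float32 VAE) uses stock SDPA.
    if (
        query.dtype in (torch.float16, torch.bfloat16)
        and attn_mask is None
        and dropout_p == 0.0
        and not kwargs
    ):
        return sageattn(query, key, value, is_causal=is_causal)
    return _original_sdpa(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs
    )

F.scaled_dot_product_attention = sage_or_sdpa
```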
Is Sage Attention supported? It's very easy to install on Windows and I've used it with other apps; it works well. I checked but couldn't find any information at https://github.com/huggingface/diffusers/tree/main/docs/source/en/optimization
Sage attention support is being added in #11368
@varadrane1707 The PR is close to merge and only requires some discussion from other team members. For now, if you really need it, I would recommend creating your own fork from the branch and modifying it for your use case. It is actually just a single file you can use with any codebase, so feel free to copy-paste it and set `torch.nn.functional.scaled_dot_product_attention = dispatch_attention_fn`.
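For example (the module name below is just a placeholder for wherever you copy the file):

```python
import torch.nn.functional as F

# `attention_dispatch` is a placeholder name for the file copied from the PR branch.
from attention_dispatch import dispatch_attention_fn

F.scaled_dot_product_attention = dispatch_attention_fn
```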
Thanks @a-r-r-o-w, will try this out. Also, did you check its compatibility with Para-attn, since it's used for the FBCache implementation? I tried to use the original Sage Attention implementation but the output video quality was bad for the i2v 720p Wan 2.1 model.
Hey @varadrane1707, it's hard for me to look at and understand any specific custom code you might be using. The best I can do is point you to a minimal example that demonstrates how to use the attention dispatcher with context parallel (essentially the core of ParaAttention): https://gist.github.com/a-r-r-o-w/93b467ddf64bfe9df47fc12fc2ae4fac ; maybe you'll find it helpful.
We have native support for FBC coming in #11180, which works without issues in my tests with the dispatcher, so that should be covered natively soon too.
re: sage attention output being bad -- if you're using custom code from different codebases, it will be hard for me to help. With diffusers-specific code, I haven't noticed a problem in the outputs. If you could open an issue with a simple code example, I could try to help.