[core] Pyramid Attention Broadcast
What does this PR do?
Adds support for Pyramid Attention Broadcast.
- Paper: https://www.arxiv.org/abs/2408.12588
- Project Page: https://oahzxl.github.io/PAB/
- Code: https://github.com/NUS-HPC-AI-Lab/VideoSys
We only add the changes related to attention, and not sequence/CFG parallelism due to diffusers primarily being geared towards single-GPU inference.
Usage
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
# Load the CogVideoX-5b text-to-video pipeline in fp16 and move it to the GPU.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16)
pipe.to("cuda")
# Enable Pyramid Attention Broadcast: reuse (broadcast) spatial attention outputs
# instead of recomputing them every step, but only while the denoising timestep
# lies within the given range — a skip range of 2 means attention is recomputed
# every other step inside [100, 850].
pipe.enable_pyramid_attention_broadcast(
spatial_attn_skip_range=2,
spatial_attn_timestep_range=[100, 850],
)
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
# Run generation and write the frames out as a video file.
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
Benchmark
Code
import gc
import torch
from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline, LattePipeline
from diffusers.utils import export_to_video, load_image
from tabulate import tabulate
def reset_memory():
    """Collect garbage, release cached CUDA blocks, and zero the allocator stats.

    Called before each benchmark so that memory readings reflect only the run
    that follows, not whatever ran previously in the process.
    """
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
def pretty_print_results(results, precision: int = 3):
    """Print *results* as a one-row, pipe-delimited (markdown-style) table.

    Float values are rendered with *precision* decimal places; all other
    values are shown unchanged.
    """

    def _fmt(value):
        # Round only floats; pass every other type through untouched.
        return f"{value:.{precision}f}" if isinstance(value, float) else value

    row = {key: _fmt(val) for key, val in results.items()}
    print(tabulate([row], headers="keys", tablefmt="pipe", stralign="center"))
@torch.no_grad()
def test_cogvideox_5b():
    """Benchmark CogVideoX-5b text-to-video with and without Pyramid Attention Broadcast.

    Returns a dict of memory (GiB) and wall-clock (s) measurements for both the
    baseline run and the PAB-enabled run.
    """
    reset_memory()

    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    # Memory taken up by the model weights alone, before any inference.
    model_memory = torch.cuda.memory_allocated() / 1024**3

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    # Warmup pass (2 steps), excluded from timing.
    _ = pipe(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=2,
        generator=torch.Generator().manual_seed(31337),
    )

    # --- Baseline (no PAB) timed run -----------------------------------------
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    latent = pipe(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end_event.record()
    torch.cuda.synchronize()
    normal_time = start_event.elapsed_time(end_event) / 1000  # ms -> s
    normal_memory = torch.cuda.max_memory_reserved() / 1024**3

    video = pipe.decode_latents(latent)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/cogvideox_5b.mp4", fps=8)

    # --- PAB-enabled timed run (same seed, same settings) --------------------
    pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    latent = pipe(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end_event.record()
    torch.cuda.synchronize()
    pab_time = start_event.elapsed_time(end_event) / 1000  # ms -> s
    pab_memory = torch.cuda.max_memory_reserved() / 1024**3

    video = pipe.decode_latents(latent)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/cogvideox_pab_5b.mp4", fps=8)

    return {
        "model_memory": model_memory,
        "normal_memory": normal_memory,
        "pab_memory": pab_memory,
        "normal_time": normal_time,
        "pab_time": pab_time,
    }
@torch.no_grad()
def test_cogvideox_5b_i2v():
    """Benchmark CogVideoX-5b image-to-video with and without Pyramid Attention Broadcast.

    Returns a dict of memory (GiB) and wall-clock (s) measurements for both the
    baseline run and the PAB-enabled run.
    """
    reset_memory()

    pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    # Memory taken up by the model weights alone, before any inference.
    model_memory = torch.cuda.memory_allocated() / 1024**3

    prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
    image = load_image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
    )

    # Warmup pass (2 steps), excluded from timing.
    _ = pipe(
        prompt=prompt,
        image=image,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=2,
        generator=torch.Generator().manual_seed(31337),
    )

    # --- Baseline (no PAB) timed run -----------------------------------------
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    latent = pipe(
        prompt=prompt,
        image=image,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end_event.record()
    torch.cuda.synchronize()
    normal_time = start_event.elapsed_time(end_event) / 1000  # ms -> s
    normal_memory = torch.cuda.max_memory_reserved() / 1024**3

    video = pipe.decode_latents(latent)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/cogvideox_5b_i2v.mp4", fps=8)

    # --- PAB-enabled timed run (same seed, same settings) --------------------
    pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    latent = pipe(
        prompt=prompt,
        image=image,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end_event.record()
    torch.cuda.synchronize()
    pab_time = start_event.elapsed_time(end_event) / 1000  # ms -> s
    pab_memory = torch.cuda.max_memory_reserved() / 1024**3

    video = pipe.decode_latents(latent)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/cogvideox_pab_5b_i2v.mp4", fps=8)

    return {
        "model_memory": model_memory,
        "normal_memory": normal_memory,
        "pab_memory": pab_memory,
        "normal_time": normal_time,
        "pab_time": pab_time,
    }
@torch.no_grad()
def test_latte():
    """Benchmark Latte text-to-video with and without Pyramid Attention Broadcast.

    Latte uses both a spatial and a cross-attention skip range when PAB is
    enabled. Returns a dict of memory (GiB) and wall-clock (s) measurements.
    """
    reset_memory()

    pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    # Memory taken up by the model weights alone, before any inference.
    model_memory = torch.cuda.memory_allocated() / 1024**3

    prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."

    # Warmup pass (2 steps), excluded from timing.
    _ = pipe(
        prompt=prompt,
        video_length=16,
        num_inference_steps=2,
        generator=torch.Generator().manual_seed(31337),
    )

    # --- Baseline (no PAB) timed run -----------------------------------------
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    latent = pipe(
        prompt=prompt,
        num_inference_steps=50,
        video_length=16,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end_event.record()
    torch.cuda.synchronize()
    normal_time = start_event.elapsed_time(end_event) / 1000  # ms -> s
    normal_memory = torch.cuda.max_memory_reserved() / 1024**3

    video = pipe.decode_latents(latent, video_length=16)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/latte.mp4", fps=8)

    # --- PAB-enabled timed run (same seed, same settings) --------------------
    pipe.enable_pyramid_attention_broadcast(
        spatial_attn_skip_range=2,
        cross_attn_skip_range=6,
        spatial_attn_timestep_range=[100, 800],
        cross_attn_timestep_range=[100, 800],
    )

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    latent = pipe(
        prompt=prompt,
        video_length=16,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end_event.record()
    torch.cuda.synchronize()
    pab_time = start_event.elapsed_time(end_event) / 1000  # ms -> s
    pab_memory = torch.cuda.max_memory_reserved() / 1024**3

    video = pipe.decode_latents(latent, video_length=16)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/latte_pab.mp4", fps=8)

    return {
        "model_memory": model_memory,
        "normal_memory": normal_memory,
        "pab_memory": pab_memory,
        "normal_time": normal_time,
        "pab_time": pab_time,
    }
# Run every benchmark in sequence and print its results table.
for benchmark in [test_cogvideox_5b, test_cogvideox_5b_i2v, test_latte]:
    print(f"Running {benchmark.__name__}")
    pretty_print_results(benchmark())
    print()
- For CogVideoX, we are generating 49 frames at 720 x 480 resolution
- For Latte, we are generating 16 frames at 512 x 512 resolution
Following are the benchmarks based on the above script (note that no memory optimizations are enabled):
| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup |
|---|---|---|---|---|---|---|
| Cog-2b T2V | 12.55 | 35.342 | 35.342 | 86.915 | 63.914 | 1.359 |
| Cog-5b T2V | 19.66 | 40.945 | 40.945 | 246.152 | 168.59 | 1.460 |
| Cog-5b I2V | 19.764 | 42.74 | 42.74 | 246.867 | 170.111 | 1.451 |
| Latte | 11.007 | 25.594 | 25.594 | 28.026 | 24.073 | 1.164 |
| CogVideoX-2b T2V | |
| CogVideoX-5b T2V | |
| CogVideoX-5b I2V | |
| Latte | |
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@yiyixuxu @sayakpaul
@oahzxl for PAB, @zRzRzRzRzRzRzR for CogVideoX related changes, @maxin-cn for Latte related changes
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I can't seem to replicate the results for PAB on CogVideoX-5b T2V or I2V. This is what I get:
| CogVideoX-5b T2V | |
| CogVideoX-5b I2V | |
@oahzxl Would you be able to give this a review when free? I'm unable to figure out what I'm doing wrong that's causing poor results in these cases. Thank you!
sure, thanks for your code! i guess it may be related with pos embed or encoder concat of 5b model. i can have a look at the code soon!
hi, i have done some experiments and here are my conclusions:
i first try a simple implementation
the org attention is:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
for simplicity, i just add pab's logic here:
# in init
self.attn_count = 0
self.last_attn = None
...
# in forward
if (10 < self.attn_count < 45) and (self.attn_count % 2 != 0):
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
else:
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
self.last_attn = attn_hidden_states, attn_encoder_hidden_states
this should be exactly the same as the logic in pab processor.
then i find pab will be numerically unstable with fp16 for cogvideox-5b. so i change to bfloat16, and it works!
https://github.com/user-attachments/assets/28e6c4d7-0461-4c87-a869-e465ca3e72d1
->> so the first problem is float16!
- then i test pab processor
but fail even if i use bfloat16
i find that even if i set spatial_attn_skip_range to 1 (which means no broadcast), it will still generate random noise.
->> so i think the second problem is in processor, but no clue for now
hope it can help you!
Thank you so much for the investigation! I think I found the bug. This line checks if the processor signature supports a specific keyword arguments before passing them. In this case, since we replace the attention processor with PyramidAttentionBroadcastAttentionProcessor, which only has args and kwargs, it drops the image_rotary_emb kwargs necessary for generation. So, RoPE embeddings are not passed at all causing bad video.
glad i can help :) !
I will do some tests with Mochi, now that it is in, and push the relevant changes here if okay. If not, we could maybe do it in a separate PR
I think hooks can work to actually capture/cache intermediate values. Not sure if a Pipeline Mixin is the way to go for enabling/disabling the cache.
There seem to be a few methods for caching in video (PAB, AdaCache, FasterCache) etc. We might not need to support all of them, but it would be nice to be able to swap out caching mechanisms without having to add new Mixins to the Pipelines.
IMO we enable/disable the cache at the model level. e.g.
from diffusers import CogVideoXTransformer3DModel
model = CogVideoXTransformer3DModel.from_pretrained("...")
# method name TBD
model.enable_cached_inference(cache_type="pyramid_attention_broadcast", cache_kwargs={})
So we could create something like a CachedInferenceMixin that can apply the appropriate cache to the model and we add a class attribute in the model for supported cache types.
Enabling caching at the model level also lets us apply different cache mechanisms to different components. e.g.
from diffusers import AutoencoderKLCogVideoX
model = AutoencoderKLCogVideoX.from_pretrained("...")
model.enable_cached_inference(cache_type="context_parallel")
Although I'm not sure how easy it would be to rewrite context parallel cache with hooks?
CachedInferenceMixin is more preferable to me, too! Do you envision this as a base class that other caching mechanisms would have to subclass from?
I think we will have a better idea of the common methods each caching mechanism needs to implement and have them in the base CachedInferenceMixin class and override as needed.
Thoughts? @DN6
Jotting some points based on Slack convos and in-person convos.
It'd be prudent to consider an offloading mechanism for the caching utilities we are planning similar to how it's done in transformers. When CUDA is available, this should also be done with CUDA streams so as to be able to overlap computations and communications.
This can be made compatible with torch.compile() too as long as we're not mutating the internal model-level states. https://gist.github.com/gau-nernst/9408e13c32d3c6e7025d92cce6cba140 gives us an amazing example of making it all work with torch.compile().
@DN6 Addressed the review comments. Could you give this another look?
We need to make some more updates here before merging to address the case of using multiple hooks at once. The current implementation does not really work, if say both FP8 and PAB are enabled together. I will take it up in this PR before merging after layerwise upcasting is merged: https://github.com/huggingface/diffusers/pull/10347
This has already been addressed in group offloading PR but that will take some more time to complete: https://github.com/huggingface/diffusers/pull/10503
With the latest changes, it is now possible to use multiple forward-modifying hooks now. Here's an example with FP8 layerwise-upcasting and PAB:
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
# Verbose logging so cache hits/misses from the hooks are visible.
set_verbosity_debug()
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
# PAB config: recompute spatial attention every other step while the current
# denoising timestep is inside (150, 700); the callback lets the hook query the
# pipeline's current timestep at runtime.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(150, 700),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
# Stack a second forward-modifying hook: store weights in fp8, compute in bf16.
pipe.transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
# Generate with both hooks active and save the result.
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
I think we're good to merge now and also got the approval from Dhruv after working together on latest changes! Thanks for the patience and the reviews everyone :hugs: Will merge once CI is green and wrap up the open cache PRs
@oahzxl Congratulations on the success of your new work - Data centric parallel! I also really liked reading about the pyramid activation checkpointing that was introduced in VideoSys. Thanks for your patience and help, and also for your work that inspired multiple other papers researching caching mechanisms specific to video models. We will be sure to integrate as much as possible to make the methods more easily accessible :)