diffusers icon indicating copy to clipboard operation
diffusers copied to clipboard

[core] Pyramid Attention Broadcast

Open a-r-r-o-w opened this issue 1 year ago • 6 comments

What does this PR do?

Adds support for Pyramid Attention Broadcast.

  • Paper: https://www.arxiv.org/abs/2408.12588
  • Project Page: https://oahzxl.github.io/PAB/
  • Code: https://github.com/NUS-HPC-AI-Lab/VideoSys

We only add the attention-related changes, not sequence/CFG parallelism, because diffusers is primarily geared towards single-GPU inference.

Usage

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load in bfloat16: with Pyramid Attention Broadcast, float16 was observed to be numerically
# unstable on CogVideoX-5b (producing noise), while bfloat16 generates correct videos.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Recompute spatial attention only once every `spatial_attn_skip_range` steps and reuse the
# cached output in between, but only while the timestep lies within `spatial_attn_timestep_range`.
pipe.enable_pyramid_attention_broadcast(
    spatial_attn_skip_range=2,
    spatial_attn_timestep_range=[100, 850],
)

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)

Benchmark

Code
import gc

import torch
from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline, LattePipeline
from diffusers.utils import export_to_video, load_image
from tabulate import tabulate


def reset_memory():
    """Release unreferenced objects and cached CUDA blocks, then zero the CUDA memory counters."""
    gc.collect()
    cuda = torch.cuda
    cuda.empty_cache()
    # Clear both the peak and the accumulated statistics so the next measurement starts fresh.
    for reset_stats in (cuda.reset_peak_memory_stats, cuda.reset_accumulated_memory_stats):
        reset_stats()


def pretty_print_results(results, precision: int = 3):
    """Print ``results`` as a single-row table in tabulate's "pipe" format.

    Float values are rendered with ``precision`` digits after the decimal point;
    every other value is passed through unchanged.
    """
    row = {}
    for key, value in results.items():
        row[key] = f"{value:.{precision}f}" if isinstance(value, float) else value
    print(tabulate([row], headers="keys", tablefmt="pipe", stralign="center"))


@torch.no_grad()
def test_cogvideox_5b():
    """Benchmark CogVideoX-5b text-to-video with and without Pyramid Attention Broadcast (PAB).

    Runs a 2-step warmup, then a 50-step generation without PAB and a 50-step generation with
    PAB enabled. Both decoded videos are written under ``outputs/``, which must already exist.

    Returns:
        dict with ``model_memory`` (GiB allocated after moving the pipeline to CUDA),
        ``normal_memory`` / ``pab_memory`` (peak GiB reserved during each timed denoising run)
        and ``normal_time`` / ``pab_time`` (seconds taken by each timed denoising run).
    """
    reset_memory()

    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    model_memory = torch.cuda.memory_allocated() / 1024**3

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    def run_pipeline(num_inference_steps, **kwargs):
        # Identical inputs and seed for every call so the runs are directly comparable.
        return pipe(
            prompt=prompt,
            guidance_scale=6,
            use_dynamic_cfg=True,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator().manual_seed(31337),
            **kwargs,
        )

    def benchmark(output_path):
        """Time one 50-step latent generation, then decode and save it; return (seconds, peak GiB)."""
        # Without this reset, the peak counter would still include the warmup (which decodes full
        # videos), earlier VAE decoding and any previous timed run, so PAB memory could never be
        # lower than the baseline figure.
        reset_memory()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        latent = run_pipeline(50, output_type="latent").frames
        end.record()
        torch.cuda.synchronize()

        elapsed = start.elapsed_time(end) / 1000
        peak_memory = torch.cuda.max_memory_reserved() / 1024**3

        # Decode only after reading the counters so VAE decoding is excluded from both figures.
        video = pipe.decode_latents(latent)
        video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
        export_to_video(video, output_path, fps=8)
        return elapsed, peak_memory

    # Warmup
    run_pipeline(2)

    normal_time, normal_memory = benchmark("outputs/cogvideox_5b.mp4")

    pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])
    pab_time, pab_memory = benchmark("outputs/cogvideox_pab_5b.mp4")

    return {
        "model_memory": model_memory,
        "normal_memory": normal_memory,
        "pab_memory": pab_memory,
        "normal_time": normal_time,
        "pab_time": pab_time,
    }


@torch.no_grad()
def test_cogvideox_5b_i2v():
    """Benchmark CogVideoX-5b image-to-video with and without Pyramid Attention Broadcast (PAB).

    Runs a 2-step warmup, then a 50-step generation without PAB and a 50-step generation with
    PAB enabled. Both decoded videos are written under ``outputs/``, which must already exist.

    Returns:
        dict with ``model_memory`` (GiB allocated after moving the pipeline to CUDA),
        ``normal_memory`` / ``pab_memory`` (peak GiB reserved during each timed denoising run)
        and ``normal_time`` / ``pab_time`` (seconds taken by each timed denoising run).
    """
    reset_memory()

    pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    model_memory = torch.cuda.memory_allocated() / 1024**3

    prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
    image = load_image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
    )

    def run_pipeline(num_inference_steps, **kwargs):
        # Identical inputs and seed for every call so the runs are directly comparable.
        return pipe(
            prompt=prompt,
            image=image,
            guidance_scale=6,
            use_dynamic_cfg=True,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator().manual_seed(31337),
            **kwargs,
        )

    def benchmark(output_path):
        """Time one 50-step latent generation, then decode and save it; return (seconds, peak GiB)."""
        # Without this reset, the peak counter would still include the warmup (which decodes full
        # videos), earlier VAE decoding and any previous timed run, so PAB memory could never be
        # lower than the baseline figure.
        reset_memory()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        latent = run_pipeline(50, output_type="latent").frames
        end.record()
        torch.cuda.synchronize()

        elapsed = start.elapsed_time(end) / 1000
        peak_memory = torch.cuda.max_memory_reserved() / 1024**3

        # Decode only after reading the counters so VAE decoding is excluded from both figures.
        video = pipe.decode_latents(latent)
        video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
        export_to_video(video, output_path, fps=8)
        return elapsed, peak_memory

    # Warmup
    run_pipeline(2)

    normal_time, normal_memory = benchmark("outputs/cogvideox_5b_i2v.mp4")

    pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])
    pab_time, pab_memory = benchmark("outputs/cogvideox_pab_5b_i2v.mp4")

    return {
        "model_memory": model_memory,
        "normal_memory": normal_memory,
        "pab_memory": pab_memory,
        "normal_time": normal_time,
        "pab_time": pab_time,
    }


@torch.no_grad()
def test_latte():
    """Benchmark Latte-1 (16 frames) with and without Pyramid Attention Broadcast (PAB).

    Runs a 2-step warmup, then a 50-step generation without PAB and a 50-step generation with
    PAB enabled (spatial and cross attention). Both decoded videos are written under
    ``outputs/``, which must already exist.

    Returns:
        dict with ``model_memory`` (GiB allocated after moving the pipeline to CUDA),
        ``normal_memory`` / ``pab_memory`` (peak GiB reserved during each timed denoising run)
        and ``normal_time`` / ``pab_time`` (seconds taken by each timed denoising run).
    """
    reset_memory()

    pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    model_memory = torch.cuda.memory_allocated() / 1024**3

    prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."

    def run_pipeline(num_inference_steps, **kwargs):
        # Identical inputs and seed for every call so the runs are directly comparable.
        return pipe(
            prompt=prompt,
            video_length=16,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator().manual_seed(31337),
            **kwargs,
        )

    def benchmark(output_path):
        """Time one 50-step latent generation, then decode and save it; return (seconds, peak GiB)."""
        # Without this reset, the peak counter would still include the warmup (which decodes full
        # videos), earlier VAE decoding and any previous timed run, so PAB memory could never be
        # lower than the baseline figure.
        reset_memory()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        latent = run_pipeline(50, output_type="latent").frames
        end.record()
        torch.cuda.synchronize()

        elapsed = start.elapsed_time(end) / 1000
        peak_memory = torch.cuda.max_memory_reserved() / 1024**3

        # Decode only after reading the counters so VAE decoding is excluded from both figures.
        video = pipe.decode_latents(latent, video_length=16)
        video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
        export_to_video(video, output_path, fps=8)
        return elapsed, peak_memory

    # Warmup
    run_pipeline(2)

    normal_time, normal_memory = benchmark("outputs/latte.mp4")

    pipe.enable_pyramid_attention_broadcast(
        spatial_attn_skip_range=2,
        cross_attn_skip_range=6,
        spatial_attn_timestep_range=[100, 800],
        cross_attn_timestep_range=[100, 800],
    )
    pab_time, pab_memory = benchmark("outputs/latte_pab.mp4")

    return {
        "model_memory": model_memory,
        "normal_memory": normal_memory,
        "pab_memory": pab_memory,
        "normal_time": normal_time,
        "pab_time": pab_time,
    }


import os

# Every benchmark writes its videos to outputs/, so make sure the directory exists up front.
os.makedirs("outputs", exist_ok=True)

for fn in [test_cogvideox_5b, test_cogvideox_5b_i2v, test_latte]:
    print(f"Running {fn.__name__}")
    results = fn()
    # Speedup of PAB over the baseline (baseline time / PAB time), the "speedup" column
    # shown in the reported benchmark table.
    results["speedup"] = results["normal_time"] / results["pab_time"]
    pretty_print_results(results)
    print()
  • For CogVideoX, we are generating 49 frames at 720 x 480 resolution
  • For Latte, we are generating 16 frames at 512 x 512 resolution

Following are the benchmarks based on the above script (note that no memory optimizations are enabled):

| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| Cog-2b T2V | 12.55 | 35.342 | 35.342 | 86.915 | 63.914 | 1.359 |
| Cog-5b T2V | 19.66 | 40.945 | 40.945 | 246.152 | 168.59 | 1.460 |
| Cog-5b I2V | 19.764 | 42.74 | 42.74 | 246.867 | 170.111 | 1.451 |
| Latte | 11.007 | 25.594 | 25.594 | 28.026 | 24.073 | 1.164 |
CogVideoX-2b T2V
CogVideoX-5b T2V
CogVideoX-5b I2V
Latte

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@yiyixuxu @sayakpaul

@oahzxl for PAB, @zRzRzRzRzRzRzR for CogVideoX related changes, @maxin-cn for Latte related changes

a-r-r-o-w avatar Oct 01 '24 01:10 a-r-r-o-w

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

I can't seem to replicate the results for PAB on CogVideoX-5b T2V or I2V. This is what I get:

CogVideoX-5b T2V
CogVideoX-5b I2V

@oahzxl Would you be able to give this a review when free? I'm unable to figure out what I'm doing wrong that's causing poor results in these cases. Thank you!

a-r-r-o-w avatar Oct 03 '24 07:10 a-r-r-o-w

Sure, thanks for your code! I guess it may be related to the positional embeddings or the encoder concatenation in the 5b model. I can have a look at the code soon!

oahzxl avatar Oct 03 '24 07:10 oahzxl

Hi, I have done some experiments and here are my conclusions:

i first try a simple implementation

the org attention is:

        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

For simplicity, I just added PAB's logic here:

        # in init
        self.attn_count = 0
        self.last_attn = None
        
        ...

        # in forward
        if (10 < self.attn_count < 45) and (self.attn_count % 2 != 0):
            attn_hidden_states, attn_encoder_hidden_states = self.last_attn
        else:
            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
            )
            self.last_attn = attn_hidden_states, attn_encoder_hidden_states

this should be exactly the same as the logic in pab processor.

Then I found that PAB is numerically unstable with fp16 for CogVideoX-5b. So I switched to bfloat16, and it works!

https://github.com/user-attachments/assets/28e6c4d7-0461-4c87-a869-e465ca3e72d1

-> So the first problem is float16!

  2. Then I tested the PAB processor,

but it fails even with bfloat16.

I found that even when I set spatial_attn_skip_range to 1 (which means no broadcast), it still generates random noise.

-> So I think the second problem is in the processor, but I have no clue yet.

Hope this helps!

oahzxl avatar Oct 03 '24 15:10 oahzxl

Thank you so much for the investigation! I think I found the bug. This line checks whether the processor's signature supports specific keyword arguments before passing them. In this case, we replace the attention processor with PyramidAttentionBroadcastAttentionProcessor, whose signature only has `*args` and `**kwargs`. As a result, the `image_rotary_emb` keyword argument needed for generation is dropped. The RoPE embeddings are never passed, which causes the bad videos.

a-r-r-o-w avatar Oct 03 '24 15:10 a-r-r-o-w

Glad I could help :)!

oahzxl avatar Oct 03 '24 15:10 oahzxl

I will do some tests with Mochi, now that it is in, and push the relevant changes here if okay. If not, we could maybe do it in a separate PR

a-r-r-o-w avatar Nov 09 '24 16:11 a-r-r-o-w

I think hooks can work to actually capture/cache intermediate values. Not sure if a Pipeline Mixin is the way to go for enabling/disabling the cache.

There are several caching methods for video models (e.g. PAB, AdaCache, FasterCache). We might not need to support all of them, but it would be nice to be able to swap out caching mechanisms without having to add new Mixins to the Pipelines.

IMO we should enable/disable the cache at the model level, e.g.:

from diffusers import CogVideoXTransformer3DModel

# Proposed API: caching is enabled on the model itself instead of through a pipeline mixin.
model = CogVideoXTransformer3DModel.from_pretrained("...")
# method name TBD
model.enable_cached_inference(cache_type="pyramid_attention_broadcast", cache_kwargs={})

So we could create something like a CachedInferenceMixin that can apply the appropriate cache to the model and we add a class attribute in the model for supported cache types.

Enabling caching at the model level also lets us apply different cache mechanisms to different components, e.g.:

from diffusers import AutoencoderKLCogVideoX

# Proposed API (method name not final): a cache mechanism chosen per model component,
# here the VAE, independently of whatever the transformer uses.
model = AutoencoderKLCogVideoX.from_pretrained("...")
cache_type = "context_parallel"
model.enable_cached_inference(cache_type=cache_type)

Although I'm not sure how easy it would be to rewrite the context-parallel cache with hooks.

DN6 avatar Nov 11 '24 10:11 DN6

CachedInferenceMixin is more preferable to me, too! Do you envision this as a base class that other caching mechanisms would have to subclass from?

I think we will have a better idea of the common methods each caching mechanism needs to implement and have them in the base CachedInferenceMixin class and override as needed.

Thoughts? @DN6

sayakpaul avatar Nov 11 '24 11:11 sayakpaul

Jotting some points based on Slack convos and in-person convos.

It'd be prudent to consider an offloading mechanism for the caching utilities we are planning, similar to how it's done in transformers. When CUDA is available, this should also be done with CUDA streams so that computation can overlap with data transfers.

This can be made compatible with torch.compile() too as long as we're not mutating the internal model-level states. https://gist.github.com/gau-nernst/9408e13c32d3c6e7025d92cce6cba140 gives us an amazing example of making it all work with torch.compile().

sayakpaul avatar Dec 04 '24 09:12 sayakpaul

@DN6 Addressed the review comments. Could you give this another look?

a-r-r-o-w avatar Jan 16 '25 21:01 a-r-r-o-w

We need to make some more updates here before merging to address the case of using multiple hooks at once. The current implementation does not really work if, say, both FP8 and PAB are enabled together. I will take this up in this PR after layerwise upcasting is merged: https://github.com/huggingface/diffusers/pull/10347

This has already been addressed in group offloading PR but that will take some more time to complete: https://github.com/huggingface/diffusers/pull/10503

a-r-r-o-w avatar Jan 21 '25 14:01 a-r-r-o-w

With the latest changes, it is now possible to use multiple forward-modifying hooks together. Here's an example with FP8 layerwise upcasting and PAB:

import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug

prompt = "".join(
    (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. ",
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other ",
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, ",
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. ",
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical ",
        "atmosphere of this unique musical performance.",
    )
)

# Turn on debug-level logging from diffusers.
set_verbosity_debug()

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# PAB settings: skip range 2 for spatial attention while the timestep is within (150, 700).
# The callback lets the cache read the pipeline's current timestep on demand.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(150, 700),
    current_timestep_callback=lambda: pipe.current_timestep,
)

# Two forward-modifying hooks on the same transformer: PAB caching, then FP8 weight storage
# with bfloat16 compute.
transformer = pipe.transformer
transformer.enable_cache(config)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe.to("cuda")

output = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50)
video = output.frames[0]
export_to_video(video, "output.mp4", fps=8)

a-r-r-o-w avatar Jan 22 '25 21:01 a-r-r-o-w

I think we're good to merge now and also got the approval from Dhruv after working together on latest changes! Thanks for the patience and the reviews everyone :hugs: Will merge once CI is green and wrap up the open cache PRs

@oahzxl Congratulations on the success of your new work, Data-Centric Parallel! I also really liked reading about the pyramid activation checkpointing introduced in VideoSys. Thanks for your patience and help, and also for your work, which inspired multiple other papers researching caching mechanisms specific to video models. We will be sure to integrate as much as possible to make these methods more easily accessible :)

a-r-r-o-w avatar Jan 27 '25 22:01 a-r-r-o-w