torch 2.5 CuDNN backend for SDPA NaN error

Open wtyuan96 opened this issue 1 year ago • 2 comments

Describe the bug

With the recently released PyTorch 2.5, the default SDPA backend is CUDNN_ATTENTION. In the example CogVideoX LoRA training script, NaN gradients appear right at the first step. Using other SDPA backends, such as FLASH_ATTENTION or EFFICIENT_ATTENTION, does not lead to NaN issues.

After some preliminary investigation, I found that this might be related to the transpose and reshape operations that follow the SDPA call in attention_processor.py (see L1954):

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

Some related issues and PRs:

  • https://github.com/pytorch/pytorch/issues/134001
  • https://github.com/pytorch/pytorch/pull/134031
  • https://github.com/pytorch/pytorch/pull/138354

Furthermore, other attention processors in attention_processor.py, such as FluxAttnProcessor2_0, use the same transpose and reshape operations and could run into similar problems.

Reproduction

This issue can be reproduced by setting a breakpoint after the backward pass and printing the gradients:

loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
loss = loss.mean()
accelerator.backward(loss)
print([[name, param.grad] for name, param in transformer.named_parameters() if param.requires_grad])
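
As a quicker check than reading through the printed gradients, here is a minimal sketch (assuming the same transformer object as in the training script above) that flags any parameter with a non-finite gradient:

import torch

# Minimal sketch: list parameters whose gradients contain NaN/Inf after backward().
# Assumes `transformer` is the model from the training script above.
nan_params = [
    name
    for name, p in transformer.named_parameters()
    if p.requires_grad and p.grad is not None and not torch.isfinite(p.grad).all()
]
if nan_params:
    raise RuntimeError(f"Non-finite gradients in: {nan_params}")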

Forcing the SDPA backend to FLASH_ATTENTION or EFFICIENT_ATTENTION in attention_processor.py makes the NaN issue go away:

from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):  # or SDPBackend.EFFICIENT_ATTENTION
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
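
If patching attention_processor.py is inconvenient, an alternative sketch (not tested here) is to disable the cuDNN SDPA backend globally at the top of the training script, assuming torch.backends.cuda.enable_cudnn_sdp is available in your PyTorch build:

import torch

# Sketch: disable only the cuDNN SDPA backend; the flash, memory-efficient, and math
# backends stay enabled. Assumes torch.backends.cuda.enable_cudnn_sdp exists
# (it ships with the PyTorch versions that include the cuDNN attention backend).
torch.backends.cuda.enable_cudnn_sdp(False)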

Considering that PyTorch 2.5 is currently the default version available for installation, this issue may require some attention.

Logs

No response

System Info

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-5.10.134-010
  • Running on Google Colab?: No
  • Python version: 3.10.15
  • PyTorch version (GPU?): 2.5.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.1
  • Transformers version: 4.46.0
  • Accelerate version: 1.0.1
  • PEFT version: 0.13.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA H20, 97871 MiB (x8)
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@DN6 @a-r-r-o-w @yiyixuxu @sayakpaul

wtyuan96 avatar Oct 25 '24 04:10 wtyuan96

I think this is a PyTorch bug; they are working on a fix: https://github.com/pytorch/pytorch/pull/138354#issue-2598184802

Until they fix it, we should use PyTorch <2.5.0.

wangyanhui666 avatar Oct 25 '24 09:10 wangyanhui666

Related? https://github.com/huggingface/diffusers/issues/9704

sayakpaul avatar Oct 25 '24 14:10 sayakpaul

> I think this is a PyTorch bug; they are working on a fix: pytorch/pytorch#138354 (comment)
>
> Until they fix it, we should use PyTorch <2.5.0.

Yes, or disable the CuDNN attention backend for SDPA.

wtyuan96 avatar Oct 26 '24 02:10 wtyuan96

> Related? #9704

Yes, let's hope the PyTorch team fixes this in a future release.

wtyuan96 avatar Oct 26 '24 02:10 wtyuan96

I tried PyTorch 2.4.1 and it also has this bug, so disabling the cuDNN attention backend in the training code may be a good solution.

wangyanhui666 avatar Oct 26 '24 17:10 wangyanhui666

Thanks so much for the detailed issue, once again.

We have pinned the torch version: https://github.com/huggingface/diffusers/blob/c75431843f3b5b4915a57fe68a3e5420dc46a280/setup.py#L133

However, I think this is likely fixed with Torch 2.5.1, as the cuDNN backend isn't selected as the SDPA backend by default anymore. Could you give this a try?

sayakpaul avatar Nov 01 '24 03:11 sayakpaul

> Thanks so much for the detailed issue, once again.
>
> We have pinned the torch version:
>
> https://github.com/huggingface/diffusers/blob/c75431843f3b5b4915a57fe68a3e5420dc46a280/setup.py#L133
>
> However, I think this is likely fixed with Torch 2.5.1, as the cuDNN backend isn't selected as the SDPA backend by default anymore. Could you give this a try?

Yes, Torch 2.5.1 gives the cuDNN backend the lowest precedence in the backend list. I have tested Torch 2.5.1, and the NaN gradients mentioned above no longer appear.
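
For reference, a small sketch to check which SDPA backends are enabled in a given build (assuming the torch.backends.cuda query functions below exist in your version):

import torch

# Sketch: print the enabled state of each SDPA backend.
# Assumes these query functions are present (cudnn_sdp_enabled requires a build
# that ships the cuDNN attention backend).
print("flash:", torch.backends.cuda.flash_sdp_enabled())
print("mem_efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math:", torch.backends.cuda.math_sdp_enabled())
print("cudnn:", torch.backends.cuda.cudnn_sdp_enabled())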

wtyuan96 avatar Nov 02 '24 06:11 wtyuan96