[Bug] Why is VSA slower than FlashAttention when running the training script examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/T2V-14B-VSA.slurm?
Describe the bug
The phenomenon is as follows:
[Screenshot: training log showing step time with VSA]
When using VSA, each step takes 11.91 seconds.
[Screenshot: training log showing step time with FlashAttention]
Thank you very much; I'm looking forward to your reply.
Reproduction
Run examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/T2V-14B-VSA.slurm; only the environment variable FASTVIDEO_ATTENTION_BACKEND was changed between the two runs.
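For reference, the same toggle can also be expressed from Python before the training pipeline is built. The backend value strings below are assumptions, so verify them against the attention backends registered in your FastVideo version:

```python
import os

# Assumed backend names; check the attention backend registry in your
# FastVideo version for the exact supported strings.
os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "VIDEO_SPARSE_ATTN"   # VSA run
# os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "FLASH_ATTN"        # FlashAttention run
```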
Environment
GPU: NVIDIA L40S; CUDA: 12.8; Driver Version: 535.230.02
If you’re using sparsity decay, then at the beginning of training sparsity is zero, so the model computes full attention, which is typically slower than a FlashAttention implementation.
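To make this concrete, here is a minimal sketch of such a warmup schedule, where sparsity ramps linearly from zero to a target value. The function and parameter names (vsa_sparsity, target_sparsity, decay_steps) are hypothetical illustrations, not FastVideo's actual implementation:

```python
# Minimal sketch of a sparsity-decay schedule. All names here are
# hypothetical and NOT FastVideo's actual API: sparsity starts at 0
# (full/dense attention) and ramps linearly toward a target value.

def vsa_sparsity(step: int, target_sparsity: float = 0.9,
                 decay_steps: int = 1000) -> float:
    """Return the attention sparsity to apply at a given training step."""
    if step >= decay_steps:
        return target_sparsity
    return target_sparsity * step / decay_steps

# At step 0 the schedule returns 0.0, i.e. dense attention, so early steps
# can easily be slower than FlashAttention; speedups appear as sparsity grows.
for step in (0, 250, 500, 1000):
    print(step, vsa_sparsity(step))
```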
Thank you very much for your answer.