
[BUG] FMHA fwd kernel causal misbehaves when qlen != klen

Open ipiszy-x opened this issue 8 months ago • 3 comments

Describe the bug The Blackwell FMHA fwd kernel causal mask implementation doesn't behave correctly when qlen != klen.

Steps/Code to reproduce bug Unfortunately, the FMHA verifier is also buggy, because it reuses the same incorrect Mask implementation. With a correct reference implementation, something like ./77_blackwell_fmha_fp16 --b=1 --h=1 --d=128 --q=3 --k=128 --mask=causal --verify should reveal the issue.

Expected behavior Referenced from FA repo README (https://github.com/Dao-AILab/flash-attention):

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1


ipiszy-x avatar May 26 '25 19:05 ipiszy-x

https://github.com/NVIDIA/cutlass/blob/main/examples/77_blackwell_fmha/collective/fmha_fusion.hpp#L180

explains this and what to do about it. Both are, IMO, useful settings; we just made a different choice here. Do you think picking the other alignment as the default would be useful?

See https://github.com/flashinfer-ai/flashinfer/pull/1039/commits/676e2d2d920d3fb0fa011285ae5e4998a674d83b for a way to actually implement it.

v0i0 avatar May 27 '25 15:05 v0i0

I don't quite understand "Q is at the beginning / end of the matrix" in the comments, do you mean aligning causal masks to the upper left corner v.s. bottom right corner?

I feel picking the other way would be useful, since it matches the standard behavior of other attention APIs.

ipiszy-x avatar May 27 '25 17:05 ipiszy-x

I don't quite understand "Q is at the beginning / end of the matrix" in the comments, do you mean aligning causal masks to the upper left corner v.s. bottom right corner?

Right, that's not worded very well. I mean that the mask, in the q dimension, aligns with either the start or the end of the matrix (i.e., the upper-left or bottom-right corner).

I am a little wary of changing the existing causal mask code (same interface, different behavior), but I agree there should be a way to get the "standard" behavior. Would introducing a different mask class, e.g. a CausalForInference that implements it, be helpful?

v0i0 avatar May 28 '25 00:05 v0i0

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] avatar Jun 27 '25 01:06 github-actions[bot]

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

github-actions[bot] avatar Sep 25 '25 03:09 github-actions[bot]