[BUG] FMHA fwd kernel causal misbehaves when qlen != klen
Describe the bug The Blackwell FMHA fwd kernel causal mask implementation doesn't behave correctly when qlen != klen.
Steps/Code to reproduce bug Unfortunately, the FMHA verifier is also buggy because it reuses the same incorrect mask implementation. A command like ./77_blackwell_fmha_fp16 --b=1 --h=1 --d=128 --q=3 --k=128 --mask=causal --verify would reveal the issue if the reference implementation were correct.
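Since the built-in verifier shares the mask code with the kernel, an independent reference is needed to check results. Below is a minimal pure-Python sketch of one, assuming the bottom-right convention quoted under "Expected behavior". The names `reference_attention` and `keep_bottom_right` are hypothetical and not part of CUTLASS. Returning zeros for a fully masked row is also an assumption, not something this thread confirms.

```python
import math

def keep_bottom_right(i, j, seqlen_q, seqlen_k):
    # Bottom-right aligned causal mask: the last query row sees every key.
    return j <= i + (seqlen_k - seqlen_q)

def reference_attention(q, k, v, keep):
    """Naive masked softmax attention over plain Python lists.

    q: [seqlen_q][d], k: [seqlen_k][d], v: [seqlen_k][dv]
    keep(i, j, seqlen_q, seqlen_k) -> True if query i may attend to key j.
    """
    seqlen_q, seqlen_k = len(q), len(k)
    d, dv = len(q[0]), len(v[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i in range(seqlen_q):
        cols = [j for j in range(seqlen_k) if keep(i, j, seqlen_q, seqlen_k)]
        if not cols:
            # Assumption: a fully masked row yields zeros.
            out.append([0.0] * dv)
            continue
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in cols]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v[j][c] for w, j in zip(weights, cols)) / total
            for c in range(dv)
        ])
    return out
```

Comparing the kernel output for a small shape such as --q=3 --k=128 against a reference like this, rather than against the example's own verifier, avoids both sides sharing the same mask logic.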
Expected behavior Referenced from FA repo README (https://github.com/Dao-AILab/flash-attention):
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
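For reference, the two masks above can be reproduced with a few lines of Python. This is only a sketch of the convention described in the FA README; the function name `causal_mask_bottom_right` is made up for illustration.

```python
def causal_mask_bottom_right(seqlen_q, seqlen_k):
    """Return a seqlen_q x seqlen_k mask (1 = keep, 0 = masked out).

    Query row i may attend to key column j when j <= i + (seqlen_k - seqlen_q),
    so the diagonal ends in the bottom-right corner.
    """
    offset = seqlen_k - seqlen_q
    return [
        [1 if j <= i + offset else 0 for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]

for row in causal_mask_bottom_right(2, 5):
    print(*row)
# 1 1 1 1 0
# 1 1 1 1 1
```

When seqlen_q > seqlen_k the offset is negative, which is why the first rows in the second example are fully masked.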
The comment at https://github.com/NVIDIA/cutlass/blob/main/examples/77_blackwell_fmha/collective/fmha_fusion.hpp#L180
explains this, and what to do about it. Both are IMO useful settings; we just make a different choice here. Do you think making the other alignment the default would be useful?
See https://github.com/flashinfer-ai/flashinfer/pull/1039/commits/676e2d2d920d3fb0fa011285ae5e4998a674d83b for a way to actually implement it.
I don't quite understand "Q is at the beginning / end of the matrix" in the comments. Do you mean aligning causal masks to the upper left corner vs. the bottom right corner?
I feel picking the other way would be useful since it's like the standard way implemented by all attention APIs.
I don't quite understand "Q is at the beginning / end of the matrix" in the comments. Do you mean aligning causal masks to the upper left corner vs. the bottom right corner?
Right, that's not worded very well. I mean that the mask, in the Q dimension, aligns with either the start or the end of the matrix (i.e. the upper left or the bottom right corner).
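To make the two alignments concrete, here is a small Python sketch that builds both masks from a single offset. It assumes, based on this thread, that the existing causal mask corresponds to the upper left alignment; the `align` parameter and the function names are hypothetical and do not mirror the CUTLASS interface.

```python
def causal_mask(seqlen_q, seqlen_k, align="bottom_right"):
    """Build a seqlen_q x seqlen_k causal mask (1 = keep, 0 = masked out).

    align="top_left":     Q sits at the start of the K range (diagonal from the upper left).
    align="bottom_right": Q sits at the end of the K range (diagonal into the bottom right).
    """
    if align == "top_left":
        offset = 0
    elif align == "bottom_right":
        offset = seqlen_k - seqlen_q
    else:
        raise ValueError(f"unknown alignment: {align!r}")
    return [
        [1 if j <= i + offset else 0 for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]

def kept_per_row(mask):
    # Number of keys each query row is allowed to attend to.
    return [sum(row) for row in mask]

# Shape from the reproduction command (--q=3 --k=128):
print(kept_per_row(causal_mask(3, 128, "top_left")))      # [1, 2, 3]
print(kept_per_row(causal_mask(3, 128, "bottom_right")))  # [126, 127, 128]
```

The two alignments coincide only when qlen == klen, which matches the condition in the issue title.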
I am a little worried about changing the existing causal mask code (same interface, different behavior), but I agree there should be a way to do it the "standard" way. Would introducing a different mask class, e.g. CausalForInference, that implements it be helpful?