tilelang
[Question] On the condition of causal mask
Required prerequisites
- [x] I have read the documentation https://tilelang.com.
- [x] I have searched the Issue Tracker and confirmed that this hasn't already been reported. (comment there if it has.)
Questions
Dear authors,
I noticed that the causal mask condition in https://github.com/tile-ai/tilelang/blob/7d389a439106b57f09faca45dd7273de849a6a9c/examples/flash_attention/example_gqa_fwd_varlen.py#L132 looks weird:
if is_causal:
    for i, j in T.Parallel(block_M, block_N):
        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
                                     (bx * block_M + i >= q_current_seqlen or
                                      k * block_N + j >= k_current_seqlen),
                                     -T.infinity(acc_s.dtype), 0)
I think the correct condition should be:
if is_causal:
    for i, j in T.Parallel(block_M, block_N):
        acc_s[i, j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
                                     (bx * block_M + i >= q_current_seqlen or
                                      k * block_N + j >= k_current_seqlen),
                                     -T.infinity(acc_s.dtype), 0)
Is the current condition intended, or am I misunderstanding it?
Best regards
This is indeed a bug. We really appreciate your feedback and will fix it later.
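For readers who want to see the difference concretely, here is a minimal NumPy sketch that evaluates both conditions on a single tile. It is not TileLang code: the helper name `masked_positions`, the `use_proposed` flag, and the tile sizes are hypothetical and chosen only for illustration. It reuses the same global index arithmetic as the snippets in the question (`bx * block_M + i` for the query row, `k * block_N + j` for the key column).

```python
import numpy as np


def masked_positions(block_M, block_N, bx, k, q_len, k_len, use_proposed):
    """Return a (block_M, block_N) boolean array.

    True marks positions where the kernel would write -inf into acc_s,
    False marks positions where it would write 0.
    """
    # Global query-row and key-column indices for this tile.
    q_idx = (bx * block_M + np.arange(block_M))[:, None]
    k_idx = (k * block_N + np.arange(block_N))[None, :]

    # Positions that fall outside the current (variable-length) sequence.
    out_of_range = (q_idx >= q_len) | (k_idx >= k_len)

    if use_proposed:
        # Proposed condition: mask future keys OR out-of-range positions.
        return (q_idx < k_idx) | out_of_range
    # Condition from the linked commit: mask only positions that are
    # causally visible AND out of range.
    return (q_idx >= k_idx) & out_of_range


# One 4x4 tile where only the first 3 queries and 3 keys are valid.
original = masked_positions(4, 4, bx=0, k=0, q_len=3, k_len=3, use_proposed=False)
proposed = masked_positions(4, 4, bx=0, k=0, q_len=3, k_len=3, use_proposed=True)

print(original.astype(int))
print(proposed.astype(int))
```

In this toy tile, the condition from the linked commit masks only the padded query row, so query 0 can still attend to key 1, which lies in its future. The proposed condition masks the strict upper triangle as well as the padded row and column, which is what a causal mask over variable-length sequences is expected to do.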
closed as now fixed