
[Question] On the condition of causal mask

Open • lemyx opened this issue 6 months ago • 1 comment


Questions

Dear authors,

I noticed that the causal mask condition in https://github.com/tile-ai/tilelang/blob/7d389a439106b57f09faca45dd7273de849a6a9c/examples/flash_attention/example_gqa_fwd_varlen.py#L132 seems weird:

if is_causal:
    for i, j in T.Parallel(block_M, block_N):
        # -inf only where the key is NOT in the future AND a position is out of range
        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
                                     (bx * block_M + i >= q_current_seqlen or
                                      k * block_N + j >= k_current_seqlen),
                                     -T.infinity(acc_s.dtype), 0)
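A quick way to see what this predicate does is to evaluate it in plain Python for a single tile. This is a sketch, not TileLang code: `original_masked`, `render`, the 4x4 tile size, and the sequence lengths are hypothetical, chosen only for illustration. `True` stands for the `-T.infinity(...)` branch and `False` for `0`.

```python
# Hypothetical tile sizes, chosen only for illustration (not from the example).
BLOCK_M, BLOCK_N = 4, 4


def original_masked(bx, k, i, j, q_len, k_len):
    """Condition as quoted above: True means the score is set to -inf."""
    q_idx = bx * BLOCK_M + i  # global query position
    k_idx = k * BLOCK_N + j   # global key position
    return (q_idx >= k_idx) and (q_idx >= q_len or k_idx >= k_len)


def render(pred, bx, k, q_len, k_len):
    """One string per query row: 'x' = masked (-inf), '.' = kept (0)."""
    return [
        "".join("x" if pred(bx, k, i, j, q_len, k_len) else "." for j in range(BLOCK_N))
        for i in range(BLOCK_M)
    ]


# First tile (bx = 0, k = 0) of a sequence with 3 valid query and key tokens.
rows = render(original_masked, bx=0, k=0, q_len=3, k_len=3)
print(rows)  # ['....', '....', '....', 'xxxx']
```

With `q_len = k_len = 3`, only the padded query row 3 is masked. Valid queries 0 to 2 still see their future keys and even the padded key column `j = 3`, so for in-range rows neither the causal mask nor the padding mask takes effect.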

I think the correct condition should be:

if is_causal:
    for i, j in T.Parallel(block_M, block_N):
        # -inf where the key is in the future (causal) OR a position is past its length (padding)
        acc_s[i, j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
                                     (bx * block_M + i >= q_current_seqlen or
                                      k * block_N + j >= k_current_seqlen),
                                     -T.infinity(acc_s.dtype), 0)
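The proposed condition can be checked the same way. The sketch below (again plain Python with hypothetical names and sizes, not TileLang) compares it against an independent statement of the intended rule: a score is kept only when `k_idx <= q_idx`, `q_idx < q_len`, and `k_idx < k_len`. It walks every cell of every tile of a short sequence.

```python
# Hypothetical tile sizes, chosen only for illustration (not from the example).
BLOCK_M, BLOCK_N = 4, 4


def proposed_masked(q_idx, k_idx, q_len, k_len):
    """Proposed condition: True means the score is set to -inf."""
    return (q_idx < k_idx) or (q_idx >= q_len or k_idx >= k_len)


def reference_kept(q_idx, k_idx, q_len, k_len):
    """Intended rule: key not in the future and both positions in range."""
    return k_idx <= q_idx and q_idx < q_len and k_idx < k_len


def check(q_len, k_len, num_q_blocks, num_k_blocks):
    """True if 'masked' is exactly 'not kept' in every cell of every tile."""
    for bx in range(num_q_blocks):
        for k in range(num_k_blocks):
            for i in range(BLOCK_M):
                for j in range(BLOCK_N):
                    q_idx = bx * BLOCK_M + i  # global query position
                    k_idx = k * BLOCK_N + j   # global key position
                    masked = proposed_masked(q_idx, k_idx, q_len, k_len)
                    kept = reference_kept(q_idx, k_idx, q_len, k_len)
                    if masked == kept:  # each cell must be exactly one of the two
                        return False
    return True


# 7 valid tokens over two 4-wide tiles per axis, so the last tile is padded.
print(check(q_len=7, k_len=7, num_q_blocks=2, num_k_blocks=2))  # True

# First tile (bx = 0, k = 0) with 3 valid tokens: 'x' = -inf, '.' = 0.
first_tile = [
    "".join("x" if proposed_masked(i, j, 3, 3) else "." for j in range(BLOCK_N))
    for i in range(BLOCK_M)
]
print(first_tile)  # ['.xxx', '..xx', '...x', 'xxxx']
```

The proposed condition is exactly the negation of the reference rule by De Morgan's law (`not (a and b and c)` is `not a or not b or not c`), which is why the original `and` between the causal test and the bounds test has to become `or` once the causal comparison is flipped. The first tile shows the expected lower-triangular pattern, with the padded row and column fully masked.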

How should the original condition be understood?

Best regards

lemyx • Oct 28 '25 03:10

This is indeed a bug. We really appreciate your feedback; we'll fix it later.

Rachmanino • Oct 28 '25 09:10

Closed, as this is now fixed.

LeiWang1999 • Nov 20 '25 12:11