tilelang

[BUG] Auto Vectorization on AtomicAdd Ignores BufferLayout and LowerArgs

Open · kurisu6912 opened this issue 4 months ago · 1 comment

Required prerequisites

What version of TileLang are you using?

0.1.6.post2+cuda.git14afc718

System information

3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] linux 0.1.6.post2+cuda.git14afc718 2.7.0+cu128

Problem description

The loop planner in AtomicAddNode::Lower ignores the layout of other buffers, which may lead to invalid vectorization on a fragment buffer.

https://github.com/tile-ai/tilelang/blob/d88594a32a52a46b5dc09ca9f17ed7f22569d179/src/op/atomic_add.cc#L486-L489

kurisu6912 avatar Dec 02 '25 10:12 kurisu6912

@yyttt6

LeiWang1999 avatar Dec 02 '25 10:12 LeiWang1999

To reproduce the bug, check out PR #1367 and run the attention_sink example:

gh pr checkout 1367
python examples/attention_sink/example_gqa_sink_bwd_bhsd.py

kurisu6912 avatar Dec 05 '25 09:12 kurisu6912

This bug occurs because the atomic-add loop planner ignores the layout of the global buffer dQ, which carries a specialized (annotated) layout. Vectorization must take the layouts of both the src and dst buffers into account.

Reproduce Step

  1. In phase.py, print the IR right after LowerTileOp and Simplify:
    # Lower high-level tile operations to low-level operations
    mod = tilelang.transform.LowerTileOp()(mod)
    mod = tilelang.transform.Simplify()(mod)
    print(mod)
    # Lower l2 persistent map
    mod = tilelang.transform.LowerL2Persistent()(mod)
    # Legalize vectorized loops to ensure they are valid
  2. Run the reproduction script:
import tilelang
import tilelang.language as T
import torch

def get_bwd_configs(sm_version=None):
    """Return the backward-kernel tiling parameters for a GPU architecture.

    Args:
        sm_version: Compute capability encoded as ``major * 10 + minor``
            (e.g. 80 or 90). When ``None`` (the default), it is read from
            the current CUDA device via ``torch.cuda.get_device_capability()``.
            Passing it explicitly lets callers pick a configuration without
            querying a GPU.

    Returns:
        A ``(block_M, block_N, num_stages, threads)`` tuple.

    Raises:
        ValueError: If no configuration exists for ``sm_version``.
    """
    if sm_version is None:
        sm_major, sm_minor = torch.cuda.get_device_capability()
        sm_version = sm_major * 10 + sm_minor
    if sm_version == 80:
        return 64, 32, 1, 128
    if sm_version == 90:
        return 128, 32, 2, 256
    raise ValueError(f"Unsupported SM version: {sm_version}")

def make_dq_layout(dQ):
    """Return the reordered physical layout annotated onto the global ``dQ``.

    atomicAdd on dQ is not vectorized, so dQ is reordered to match the 8x8
    GEMM fragment. The logical index ``(b, h, s, d)`` maps to
    ``[b, h, s // 8, d // 8, d % 2, 4 * (s % 8) + (d % 8) // 2]``:

    * rows and columns are grouped into 8x8 tiles,
    * each tile is split by column parity into two planes of 32 entries,
    * within a plane, the entry index is ``4 * (s % 8) + (d % 8) // 2``.
      Presumably this is the warp lane that owns the element in the MMA
      accumulator fragment (TODO confirm).

    Because ``d`` and ``d + 1`` land in different parity planes, they are
    NOT adjacent in memory. Writes through this layout must not be
    vectorized along ``d``.
    """
    def _forward(b, h, s, d):
        parity_plane = d % 2
        slot_in_plane = 4 * (s % 8) + (d % 8) // 2
        return [b, h, s // 8, d // 8, parity_plane, slot_in_plane]

    return T.Layout(dQ.shape, _forward)

@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch,
                  heads,
                  seq_len,
                  dim,
                  groups,
                  window_size=None,  # None for full attention
                  sm_scale=None,
                  dtype="float16"):
    """Build the attention backward kernel (dQ, dK, dV) with grouped KV heads.

    Args:
        batch: Batch size.
        heads: Number of query heads.
        seq_len: Sequence length, shared by queries and keys.
        dim: Head dimension.
        groups: Number of query heads per KV head. K, V, dK and dV have
            ``heads // groups`` heads.
        window_size: ``None`` for full causal attention. Otherwise a
            sliding-window size, which must be a multiple of ``block_N``.
        sm_scale: Softmax scale. Defaults to ``1 / sqrt(dim)``.
        dtype: Element type of Q, K, V and dO. lse, Delta, dQ, dK and dV
            use ``"float"``.

    Returns:
        The ``flash_bwd`` prim_func. The ``tilelang.jit`` decorator handles
        compilation, with fast math enabled through ``pass_configs``.

    Notes:
        * dQ, dK and dV are only updated with ``T.atomic_add``, so the
          caller is expected to pass zero-initialized buffers.
        * dQ is annotated with ``make_dq_layout``, so its contents are stored
          in a reordered physical layout. Presumably a post-processing step,
          not shown here, maps it back to ``[batch, heads, seq_len, dim]``.
        * ``lse`` is subtracted inside ``exp2`` after scaling by log2(e), so
          it presumably must be the base-2 log-sum-exp from the forward
          pass (TODO confirm). ``Delta`` presumably holds rowsum(dO * O).
    """
    if sm_scale is None:
        sm_scale = (1.0 / dim)**0.5
    scale = sm_scale * 1.44269504  # log2(e)

    head_kv = heads // groups
    q_shape = [batch, heads, seq_len, dim]
    kv_shape = [batch, head_kv, seq_len, dim]
    accum_dtype = "float"

    # Tile sizes chosen from the current GPU's compute capability.
    block_M, block_N, num_stages, threads = get_bwd_configs()

    if window_size is not None:
        # NOTE(review): `assert` is stripped under `python -O`. Raise a
        # ValueError instead if this check must always run.
        assert window_size % block_N == 0, "window_size must be divisible by block_N"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(kv_shape, dtype),  # type: ignore
            V: T.Tensor(kv_shape, dtype),  # type: ignore
            dO: T.Tensor(q_shape, dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(kv_shape, accum_dtype),  # type: ignore
            dV: T.Tensor(kv_shape, accum_dtype),  # type: ignore
    ):
        # Grid: bx = query head, by = block of block_M key/value rows,
        # bz = batch index.
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_M, dim], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim], dtype)
            dv = T.alloc_fragment([block_M, dim], accum_dtype)
            dk = T.alloc_fragment([block_M, dim], accum_dtype)
            dq = T.alloc_fragment([block_N, dim], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
            dk_shared = T.alloc_shared([block_M, dim], accum_dtype)

            # dQ gets the reordered make_dq_layout. In that layout, columns d
            # and d + 1 are not adjacent in memory.
            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })
            # K/V head shared by `groups` query heads.
            T.copy(K[bz, bx // groups, by * block_M:(by + 1) * block_M, :], K_shared)
            T.copy(V[bz, bx // groups, by * block_M:(by + 1) * block_M, :], V_shared)
            T.clear(dv)
            T.clear(dk)

            # Causal: start at the query block containing this key block's
            # first row. With a window, stop once queries are more than
            # window_size past this key block.
            loop_st = T.floordiv(by * block_M, block_N)
            loop_ed = T.min(
                T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
                    seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)

            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                T.clear(qkT)
                # qkT = K_blk @ Q_blk^T: transposed scores, [keys, queries].
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                # Recompute the transposed probabilities from the saved lse.
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                # Zero entries whose key index exceeds the query index (and,
                # with a window, keys at or before query - window_size).
                # `window_size` is a Python value, so this `if` is presumably
                # resolved while the kernel is traced.
                for i, j in T.Parallel(block_M, block_N):
                    if window_size is not None:
                        qkT[i, j] = T.if_then_else(
                            by * block_M + i <= k * block_N + j and
                            by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0)
                    else:
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                                   0)
                T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do)
                T.clear(dsT)
                # dsT = V_blk @ dO_blk^T (transposed dP).
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                # dV += P^T @ dO.
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)

                # dS^T = P^T * (dP^T - Delta) * sm_scale, then dK += dS^T @ Q.
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                # dQ_blk = dS @ K_blk (transpose_A on dS^T). Several key blocks
                # contribute to the same query rows, hence the atomic add.
                # Because of dQ's layout annotation, this add must not be
                # vectorized along the head dimension.
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq)

            # The `groups` query heads sharing one KV head (bx // groups) all
            # add into the same dK/dV rows.
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dk_shared)

    return flash_bwd

if __name__ == '__main__':
    # Build the backward kernel for the configuration used in this
    # reproduction: 64 query heads grouped 8 per KV head, full causal
    # attention, and the default softmax scale.
    flashattn_bwd(
        batch=1,
        heads=64,
        seq_len=4096,
        dim=128,
        groups=8,
        window_size=None,
        sm_scale=None,
        dtype='float16',
    )

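Wrong output. The `# HERE!` line emits `AtomicAddx2`, a 2-wide vectorized atomic add into the physical `dQ_1` buffer. Under `make_dq_layout`, logical columns `d` and `d + 1` are stored in different `d % 2` planes, so they are not adjacent in memory and this access must not be vectorized: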

                            for i, j in T.grid(1, 2):
                                T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_local.data, 0, B_local.data, j * 8, dq.data, j * 8, T.bool(False))
                                T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_local.data, 0, B_local.data, j * 8 + 4, dq.data, j * 8 + 4, T.bool(False))
                    for i in T.unroll(8, annotations={"pragma_unroll_explicit": T.bool(False)}):
                        # HERE!
                        T.call_extern("float32", "AtomicAddx2", T.address_of(dQ_1[0, bx, k * 4 + thread_binding % 64 // 32 * 2 + i // 4, thread_binding // 64 * 4 + i % 4, 0, thread_binding % 32]), T.address_of(dq[i % 4 * 4 + i // 4 * 2]), 0)
                for i in T.unroll(32, annotations={"pragma_unroll_explicit": T.bool(False)}):
                    for vec in T.vectorized(2):
                        dv_shared[i // 8, thread_binding // 32 * 2 + i % 2, thread_binding % 32 // 4 * 32 + ((i // 2 * 8 + thread_binding % 4 * 2 + vec) % 32 // 16 + thread_binding % 32 // 16) % 2 * 16 + (thread_binding % 16 // 8 + (i // 2 * 8 + thread_binding % 4 * 2 + vec) % 16 // 8) % 2 * 8 + (thread_binding % 8 // 4 + thread_binding % 4 // 2) % 2 * 4 + thread_binding % 2 * 2 + vec] = dv[i * 2 + vec]

kurisu6912 avatar Dec 05 '25 10:12 kurisu6912