[BUG] Auto Vectorization on AtomicAdd Ignores BufferLayout and LowerArgs
Required prerequisites
- [x] I have read the documentation https://tilelang.com.
- [x] I have searched the Issue Tracker that this hasn't already been reported. (comment there if it has.)
What version of TileLang are you using?
0.1.6.post2+cuda.git14afc718
System information
Python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0]
Platform: linux
TileLang: 0.1.6.post2+cuda.git14afc718
PyTorch: 2.7.0+cu128
Problem description
The loop planner in AtomicAddNode::Lower ignores the layouts of the other buffers involved, which can lead to invalid vectorization on a fragment buffer.
https://github.com/tile-ai/tilelang/blob/d88594a32a52a46b5dc09ca9f17ed7f22569d179/src/op/atomic_add.cc#L486-L489
@yyttt6
To reproduce the bug, check out PR 1367 and run the attention_sink example:

```bash
gh pr checkout 1367
python examples/attention_sink/example_gqa_sink_bwd_bhsd.py
```
The bug occurs because the atomic planner ignores the buffer layout of the global buffer dQ, which has a specialized layout (see make_dq_layout in the script below). We should check both the src and dst buffers when planning vectorization.
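To make the mismatch concrete, here is a minimal standalone sketch (helper names are mine; the shapes are taken from the args in the repro script below): under make_dq_layout, two logically adjacent elements along d land 32 elements apart in the transformed dQ buffer, so a 2-wide vectorized atomic over d can never touch contiguous addresses.

```python
# Minimal sketch (assumed shapes: batch=1, heads=64, seq_len=4096, dim=128).
# make_dq_layout maps (b, h, l, d) -> [b, h, l // 8, d // 8, d % 2, 4 * (l % 8) + (d % 8) // 2].
def dq_layout_index(b, h, l, d):
    return [b, h, l // 8, d // 8, d % 2, 4 * (l % 8) + (d % 8) // 2]

def flatten(idx, shape):
    off = 0
    for i, s in zip(idx, shape):
        off = off * s + i
    return off

shape = [1, 64, 4096 // 8, 128 // 8, 2, 32]  # extents of the transformed dQ buffer
a = flatten(dq_layout_index(0, 0, 0, 0), shape)
b = flatten(dq_layout_index(0, 0, 0, 1), shape)
print(b - a)  # 32, not 1: d and d+1 are not contiguous, so a 2-wide vectorized atomic over d is invalid
```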
Reproduce Steps
- In phase.py, dump the IR right after LowerTileOp and Simplify:

```python
# Lower high-level tile operations to low-level operations
mod = tilelang.transform.LowerTileOp()(mod)
mod = tilelang.transform.Simplify()(mod)
print(mod)
# Lower l2 persistent map
mod = tilelang.transform.LowerL2Persistent()(mod)
# Legalize vectorized loops to ensure they are valid
```
- Reproduce script:

```python
import tilelang
import tilelang.language as T
import torch


def get_bwd_configs():
    sm_major, sm_minor = torch.cuda.get_device_capability()
    sm_version = sm_major * 10 + sm_minor
    if sm_version == 80:
        return 64, 32, 1, 128
    elif sm_version == 90:
        return 128, 32, 2, 256
    else:
        raise ValueError(f"Unsupported SM version: {sm_version}")


def make_dq_layout(dQ):
    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    return T.Layout(dQ.shape,
                    lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])


@tilelang.jit(pass_configs={
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch,
                  heads,
                  seq_len,
                  dim,
                  groups,
                  window_size=None,
                  sm_scale=None,
                  dtype="float16"):  # None for full attention
    if sm_scale is None:
        sm_scale = (1.0 / dim)**0.5
    scale = sm_scale * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [batch, heads, seq_len, dim]
    kv_shape = [batch, head_kv, seq_len, dim]
    accum_dtype = "float"
    block_M, block_N, num_stages, threads = get_bwd_configs()
    if window_size is not None:
        assert window_size % block_N == 0, "window_size must be divisible by block_N"

    @T.prim_func
    def flash_bwd(
            Q: T.Tensor(q_shape, dtype),  # type: ignore
            K: T.Tensor(kv_shape, dtype),  # type: ignore
            V: T.Tensor(kv_shape, dtype),  # type: ignore
            dO: T.Tensor(q_shape, dtype),  # type: ignore
            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
            dK: T.Tensor(kv_shape, accum_dtype),  # type: ignore
            dV: T.Tensor(kv_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim], dtype)
            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
            q = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_M, dim], dtype)
            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
            lse_shared = T.alloc_shared([block_N], accum_dtype)
            delta = T.alloc_shared([block_N], accum_dtype)
            do = T.alloc_shared([block_N, dim], dtype)
            dv = T.alloc_fragment([block_M, dim], accum_dtype)
            dk = T.alloc_fragment([block_M, dim], accum_dtype)
            dq = T.alloc_fragment([block_N, dim], accum_dtype)
            dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
            dk_shared = T.alloc_shared([block_M, dim], accum_dtype)

            T.annotate_layout({
                dQ: make_dq_layout(dQ),
                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
            })
            T.copy(K[bz, bx // groups, by * block_M:(by + 1) * block_M, :], K_shared)
            T.copy(V[bz, bx // groups, by * block_M:(by + 1) * block_M, :], V_shared)
            T.clear(dv)
            T.clear(dk)
            loop_st = T.floordiv(by * block_M, block_N)
            loop_ed = T.min(
                T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
                    seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
                T.clear(qkT)
                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                for i, j in T.Parallel(block_M, block_N):
                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                for i, j in T.Parallel(block_M, block_N):
                    if window_size is not None:
                        qkT[i, j] = T.if_then_else(
                            by * block_M + i <= k * block_N + j and
                            by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0)
                    else:
                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                                   0)
                T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do)
                T.clear(dsT)
                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
                for i, j in T.Parallel(block_M, block_N):
                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
                T.copy(dsT_cast, dsT_shared)
                T.clear(dq)
                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq)
            T.copy(dv, dv_shared)
            T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared)
            T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dk_shared)

    return flash_bwd


if __name__ == '__main__':
    args = {
        'batch': 1,
        'heads': 64,
        'seq_len': 4096,
        'dim': 128,
        'groups': 8,
        'window_size': None,
        'sm_scale': None,
        'dtype': 'float16'
    }
    flashattn_bwd(**args)
```
Wrong Output
```python
for i, j in T.grid(1, 2):
    T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_local.data, 0, B_local.data, j * 8, dq.data, j * 8, T.bool(False))
    T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_local.data, 0, B_local.data, j * 8 + 4, dq.data, j * 8 + 4, T.bool(False))
for i in T.unroll(8, annotations={"pragma_unroll_explicit": T.bool(False)}):
    # HERE! AtomicAddx2 adds two consecutive fragment elements (d and d+1) to two
    # consecutive dQ_1 addresses, but make_dq_layout places them 32 elements apart.
    T.call_extern("float32", "AtomicAddx2", T.address_of(dQ_1[0, bx, k * 4 + thread_binding % 64 // 32 * 2 + i // 4, thread_binding // 64 * 4 + i % 4, 0, thread_binding % 32]), T.address_of(dq[i % 4 * 4 + i // 4 * 2]), 0)
for i in T.unroll(32, annotations={"pragma_unroll_explicit": T.bool(False)}):
    for vec in T.vectorized(2):
        dv_shared[i // 8, thread_binding // 32 * 2 + i % 2, thread_binding % 32 // 4 * 32 + ((i // 2 * 8 + thread_binding % 4 * 2 + vec) % 32 // 16 + thread_binding % 32 // 16) % 2 * 16 + (thread_binding % 16 // 8 + (i // 2 * 8 + thread_binding % 4 * 2 + vec) % 16 // 8) % 2 * 8 + (thread_binding % 8 // 4 + thread_binding % 4 // 2) % 2 * 4 + thread_binding % 2 * 2 + vec] = dv[i * 2 + vec]
```
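As a closing note on the proposed direction ("check both the src and dst buffers"), here is a simplified Python model of the idea; this is illustrative only, not the actual C++ planner in atomic_add.cc. The planned vector width should be the minimum of the contiguous widths of every buffer the atomic touches, which under the annotated dQ layout would fall back to width 1 (scalar atomic adds) instead of AtomicAddx2.

```python
# Simplified model of the proposed check (illustrative only, not the real planner).
def contiguous_width(flat_offset, extent, max_width=4):
    """Largest power-of-two w <= max_width such that the first w elements along the
    innermost loop map to consecutive flattened offsets."""
    w = 1
    while w * 2 <= min(extent, max_width) and all(
            flat_offset(x) == flat_offset(0) + x for x in range(w * 2)):
        w *= 2
    return w

def plan_vector_width(src_offset, dst_offset, extent, max_width=4):
    # Take the minimum over *both* src and dst, instead of only one of them.
    return min(contiguous_width(src_offset, extent, max_width),
               contiguous_width(dst_offset, extent, max_width))
```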