[BUG] Dynamic shape compile hangs forever
Required prerequisites
- [x] I have read the documentation https://tilelang.com.
- [x] I have searched the Issue Tracker that this hasn't already been reported. (comment there if it has.)
What version of TileLang are you using?
0.1.6.post1+2d4b848fcb34d18a701331c87d5c575de530ebed
System information
Envs:
tilelang: 0.1.6.post1+2d4b848fcb34d18a701331c87d5c575de530ebed torch==2.6.0
Both A100 and H100 show the same problem.
Problem description
Hi, TileLang team, I am calling a previously working simple matmul kernel with dynamic shapes, and it fails to compile (it hangs forever). With static shapes, the same code works fine.
Moreover, I found a weird correlation: while playing with several of my previous kernels using dynamic shapes, I noticed that whenever a kernel has more than one output tensor, the dynamic-shape compile hangs.
I provide the code to reproduce it below:
Reproducible example code
The Python snippets:
import torch
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
from tilelang.autotuner import AutoTuner, autotune
import itertools
from tilelang.intrinsics import make_mma_swizzle_layout
import threading
import tilelang.autotuner.tuner as _tuner
# Patch TileLang's autotuner timeout helper: its SIGALRM-based timeout only
# works on the main thread, so worker threads fall back to a plain call.
_orig_rwt = _tuner.run_with_timeout


def _safe_run_with_timeout(target_fn, timeout, *args, **kwargs):
    """Run ``target_fn``, enforcing ``timeout`` only on the main thread."""
    on_main_thread = threading.current_thread() is threading.main_thread()
    if not on_main_thread:
        # Fallback: no SIGALRM in worker threads; just run the function directly.
        return target_fn(*args, **kwargs)
    return _orig_rwt(target_fn, timeout, *args, **kwargs)


_tuner.run_with_timeout = _safe_run_with_timeout
# Resolve the compute-capability "SM version" of the current GPU
# (e.g. 80 for A100, 90 for H100); used below to pick tile sizes.
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability(device)
sm_major, sm_minor = capability
sm_version = 10 * sm_major + sm_minor
@tilelang.jit(
    out_idx=[3, 4],  # O1 and O2 (argument indices 3 and 4) are allocated as outputs
    pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
)
def fused_two_mm_same_inp(
    BatchSize,
    M,
    N,
    K,
    block_M=128,  # =128,
    block_N=128,  # =128,
    block_K=64,  # =64,
    threads=256,  # =256,
    num_stages=3,  # =3,
    w_dtype="bfloat16",
):
    """
    Build a persistent TileLang kernel computing two batched matmuls that
    share the same right-hand operand:

        O1 = W0 @ X
        O2 = W2 @ X

    BatchSize/M/N/K may be ints or T.symbolic(...) handles for dynamic shapes.
    NOTE(review): per the bug report, compilation hangs when symbolic dims are
    combined with more than one output tensor — this function is the repro.
    """
    x_dtype = "bfloat16"
    # Total number of (batch, M-tile, N-tile) work items; currently unused below.
    num_tiles = BatchSize * T.ceildiv(M, block_M) * T.ceildiv(N, block_N)
    sm_num = driver.get_num_sms()

    @T.prim_func
    def _main(
        W0: T.Tensor((BatchSize, M, K), w_dtype),
        W2: T.Tensor(
            (BatchSize, M, K), w_dtype
        ),  # note: M,K layout (we'll use transpose_A=True)
        X: T.Tensor((BatchSize, K, N), x_dtype),
        O1: T.Tensor((BatchSize, M, N), x_dtype),
        O2: T.Tensor((BatchSize, M, N), x_dtype),
    ):
        # Launch one persistent block per SM; T.Persistent below distributes
        # the tile grid across them.
        with T.Kernel(sm_num, threads=threads) as (block_id):
            # Shared-memory staging tiles for the three operands.
            W0_sh = T.alloc_shared((block_M, block_K), dtype="bfloat16")
            W2_sh = T.alloc_shared((block_M, block_K), dtype="bfloat16")
            X_sh = T.alloc_shared((block_K, block_N), dtype="bfloat16")
            # Y2_sh = T.alloc_shared((block_M, block_N), dtype="bfloat16")
            # T.use_swizzle(panel_size=10, enable=True)
            # fp32 accumulators (fragments) for the two outputs.
            Y0_loc = T.alloc_fragment((block_M, block_N), dtype="float32")
            Y2_loc = T.alloc_fragment((block_M, block_N), dtype="float32")
            # Persistent loop over the (batch, M-tile, N-tile) grid.
            for bb, bx, by in T.Persistent(
                [BatchSize, T.ceildiv(M, block_M), T.ceildiv(N, block_N)],
                wave_size=sm_num,
                index=block_id,
            ):
                # Top-left corner of this output tile.
                m0 = bx * block_M
                n0 = by * block_N
                T.clear(Y0_loc)
                T.clear(Y2_loc)
                # K-reduction, software-pipelined num_stages deep.
                for bk in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(W0[bb, m0, bk * block_K], W0_sh)  # [block_M, block_K]
                    T.copy(W2[bb, m0, bk * block_K], W2_sh)  # [block_M, block_K]
                    T.copy(X[bb, bk * block_K, n0], X_sh)  # [block_K, block_N]
                    T.gemm(W0_sh, X_sh, Y0_loc)  # (M,K) @ (N,K)^T -> (M,N)
                    T.gemm(W2_sh, X_sh, Y2_loc)
                # Write the accumulated tiles back to global memory.
                T.copy(Y0_loc, O1[bb, m0, n0])
                T.copy(Y2_loc, O2[bb, m0, n0])

    return _main
@torch._dynamo.disable
def two_mm_same_inp_interface_v2(W0, W2, X):
    """
    Compute O1 = W0 @ X and O2 = W2 @ X with the fused TileLang kernel.

    Args:
        W0, W2: [B, M, K], weight dtype can be bf16/fp16/etc. (cast-on-copy to bf16 for MMA)
        X: [B, K, N]
    Outs:
        O1 = W0 @ X
        O2 = W2 @ X
        of shape [B, M, N]
    """
    BatchSize, M, K = W0.shape
    w_dtype = str(W0.dtype).split(".")[-1]
    # Computed but not forwarded: the N dimension is left symbolic below.
    N = X.shape[2]

    # Per-architecture tile configuration.
    if sm_version == 90:  # H100
        tile_m, tile_n, tile_k = 128, 128, 64
        n_threads, stages = 256, 3
    else:  # A100 etc.
        tile_m, tile_n, tile_k = 128, 128, 32
        n_threads, stages = 256, 2

    # Batch and N are dynamic (symbolic); M and K stay static.
    kernel = fused_two_mm_same_inp(
        T.symbolic("b"),
        M,
        T.symbolic("n"),
        K,
        tile_m,
        tile_n,
        tile_k,
        n_threads,
        stages,
        w_dtype=w_dtype,
    )
    out_pair = kernel(W0.contiguous(), W2.contiguous(), X.contiguous())
    O1, O2 = out_pair
    return O1, O2
if __name__ == "__main__":
    # Quick correctness test against PyTorch reference on CUDA.
    torch.manual_seed(0)
    device = "cuda"
    B, M, N, K = 2, 1024, 1024, 1024

    def _rand(*shape):
        # bf16 random tensor on the test device.
        return torch.randn(*shape, device=device, dtype=torch.bfloat16)

    # Interface expects A_transpose=True (W: [B, K, M]) and B_transpose=False (X: [B, K, N])
    W0 = _rand(B, K, M)
    W1 = _rand(B, K, M)
    X0 = _rand(B, M, N)

    O1, O2 = two_mm_same_inp_interface_v2(W0, W1, X0)
    print(f"Output shapes: O1.shape={O1.shape}, O2.shape={O2.shape}")
Traceback
python3 test_example.py
2025-10-19 15:06:23 [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `_main` with `out_idx=[3, 4]`
Expected behavior
No response
Additional context
No response
likely something related to T.Persistent, cc @chengyupku