tilelang

[BUG] Dynamic shape compile hangs forever

Open · a1600012888 opened this issue 4 months ago · 1 comment

Required prerequisites

What version of TileLang are you using?

0.1.6.post1+2d4b848fcb34d18a701331c87d5c575de530ebed

System information

Envs:

tilelang: 0.1.6.post1+2d4b848fcb34d18a701331c87d5c575de530ebed torch==2.6.0

Both A100 and H100 show the same problem.

Problem description

Hi TileLang team, I am calling a simple matmul kernel that I wrote earlier with dynamic shapes, and compilation never finishes (it hangs forever). With static shapes, the same code works fine.

Moreover, I noticed a weird correlation: while experimenting with several of my earlier kernels using dynamic shapes, I found that whenever a kernel has more than one output tensor, compiling it with dynamic shapes hangs.

Code to reproduce the issue is provided below:

Reproducible example code

The Python snippet:

import torch
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver

from tilelang.autotuner import AutoTuner, autotune
import itertools

from tilelang.intrinsics import make_mma_swizzle_layout

import threading
import tilelang.autotuner.tuner as _tuner


# Keep a handle on TileLang's original implementation so the wrapper can
# delegate to it on the main thread.
_orig_rwt = _tuner.run_with_timeout


def _safe_run_with_timeout(target_fn, timeout, *args, **kwargs):
    """Thread-tolerant stand-in for ``tilelang.autotuner.tuner.run_with_timeout``.

    On the main thread the call is forwarded unchanged to the original
    implementation, including ``timeout``. On any other thread,
    ``target_fn(*args, **kwargs)`` is called directly and ``timeout`` is
    ignored.

    The original implementation is assumed to rely on SIGALRM. Python only
    lets the main thread install signal handlers.
    """
    if threading.current_thread() is not threading.main_thread():
        # Worker thread: no timeout is enforced on this path.
        return target_fn(*args, **kwargs)
    return _orig_rwt(target_fn, timeout, *args, **kwargs)


# Install the wrapper on the tuner module. Presumably the tuner looks the name
# up via its module globals at call time; code that imported the function
# object directly keeps the original.
_tuner.run_with_timeout = _safe_run_with_timeout


# Compute capability of the active CUDA device, packed as one integer
# (major * 10 + minor, e.g. 90 for capability 9.0). The interface function
# in this file selects its tile configuration from this value.
device = torch.cuda.current_device()
sm_major, sm_minor = torch.cuda.get_device_capability(device)
sm_version = 10 * sm_major + sm_minor


@tilelang.jit(
    out_idx=[3, 4],
    pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
)
def fused_two_mm_same_inp(
    BatchSize,
    M,
    N,
    K,
    block_M=128,
    block_N=128,
    block_K=64,
    threads=256,
    num_stages=3,
    w_dtype="bfloat16",
):
    """Build a persistent TileLang kernel for two batched GEMMs sharing X.

        O1 = W0 @ X
        O2 = W2 @ X

    Tensor layouts (as declared in the prim_func signature below):
        W0, W2: (BatchSize, M, K), dtype ``w_dtype``
        X:      (BatchSize, K, N), bfloat16
        O1, O2: (BatchSize, M, N), bfloat16; these are the kernel outputs
                selected by ``out_idx=[3, 4]``

    Args:
        BatchSize, M, N, K: problem sizes. The interface in this file passes
            ``T.symbolic`` variables for ``BatchSize`` and ``N`` and Python
            ints for ``M`` and ``K``.
        block_M, block_N, block_K: tile sizes along M, N and the K reduction.
        threads: threads per thread block.
        num_stages: pipelining depth of the K loop (``T.Pipelined``).
        w_dtype: dtype string of W0/W2. The shared-memory tiles are bfloat16;
            presumably ``T.copy`` casts W tiles on load. TODO confirm.

    NOTE(review): per the accompanying report, compiling this kernel hangs
    when BatchSize/N are symbolic, while static shapes compile fine. A
    maintainer suspects ``T.Persistent``. The hang is not addressed here.
    """
    x_dtype = "bfloat16"
    # Used both as the launch grid size and as the persistent loop's wave size.
    sm_num = driver.get_num_sms()

    @T.prim_func
    def _main(
        W0: T.Tensor((BatchSize, M, K), w_dtype),
        W2: T.Tensor((BatchSize, M, K), w_dtype),
        X: T.Tensor((BatchSize, K, N), x_dtype),
        O1: T.Tensor((BatchSize, M, N), x_dtype),
        O2: T.Tensor((BatchSize, M, N), x_dtype),
    ):

        # Launch exactly sm_num blocks; block_id identifies this block.
        with T.Kernel(sm_num, threads=threads) as (block_id):

            # Shared-memory staging tiles for the two weights and the shared X.
            W0_sh = T.alloc_shared((block_M, block_K), dtype="bfloat16")
            W2_sh = T.alloc_shared((block_M, block_K), dtype="bfloat16")
            X_sh = T.alloc_shared((block_K, block_N), dtype="bfloat16")

            # fp32 accumulators, one per output.
            Y0_loc = T.alloc_fragment((block_M, block_N), dtype="float32")
            Y2_loc = T.alloc_fragment((block_M, block_N), dtype="float32")

            # Iterate the (batch, M-tile, N-tile) grid. block_id and
            # wave_size=sm_num select the tiles handled by this block.
            for bb, bx, by in T.Persistent(
                [BatchSize, T.ceildiv(M, block_M), T.ceildiv(N, block_N)],
                wave_size=sm_num,
                index=block_id,
            ):

                m0 = bx * block_M
                n0 = by * block_N

                T.clear(Y0_loc)
                T.clear(Y2_loc)

                for bk in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):

                    T.copy(W0[bb, m0, bk * block_K], W0_sh)  # [block_M, block_K]
                    T.copy(W2[bb, m0, bk * block_K], W2_sh)  # [block_M, block_K]
                    T.copy(X[bb, bk * block_K, n0], X_sh)  # [block_K, block_N]
                    # Accumulate (block_M, block_K) @ (block_K, block_N) -> (block_M, block_N);
                    # the X tile is reused by both products.
                    T.gemm(W0_sh, X_sh, Y0_loc)
                    T.gemm(W2_sh, X_sh, Y2_loc)

                T.copy(Y0_loc, O1[bb, m0, n0])
                T.copy(Y2_loc, O2[bb, m0, n0])

    return _main


@torch._dynamo.disable
def two_mm_same_inp_interface_v2(W0, W2, X):
    """Compute ``O1 = W0 @ X`` and ``O2 = W2 @ X`` with the fused TileLang kernel.

    Args:
        W0, W2: ``[B, M, K]`` weight tensors with identical shape and dtype.
            The weight dtype is taken from ``W0`` (e.g. bf16/fp16). The
            kernel's shared tiles are bf16, so a cast on copy is presumably
            performed. TODO confirm.
        X: ``[B, K, N]`` input tensor. The kernel declares it as bfloat16.

    Returns:
        ``(O1, O2)``, each of shape ``[B, M, N]`` (bfloat16 per the kernel
        signature).

    Raises:
        ValueError: if the inputs are not 3-D, if W0 and W2 differ in shape or
            dtype, or if X's batch / K dimensions do not match W0.

    ``B`` and ``N`` are passed to the kernel factory as ``T.symbolic``
    (dynamic) dimensions. ``M``, ``K``, the weight dtype and the tile
    configuration are passed as concrete values.
    """
    if W0.dim() != 3 or X.dim() != 3:
        raise ValueError(
            f"expected 3-D W0 [B, M, K] and X [B, K, N], got "
            f"{tuple(W0.shape)} and {tuple(X.shape)}"
        )
    if W2.shape != W0.shape or W2.dtype != W0.dtype:
        # The kernel declares W0 and W2 with the same shape and the same w_dtype.
        raise ValueError(
            f"W0 and W2 must match in shape and dtype, got "
            f"{tuple(W0.shape)}/{W0.dtype} and {tuple(W2.shape)}/{W2.dtype}"
        )

    BatchSize, M, K = W0.shape
    if X.shape[0] != BatchSize or X.shape[1] != K:
        raise ValueError(f"X must be [B={BatchSize}, K={K}, N], got {tuple(X.shape)}")

    w_dtype = str(W0.dtype).split(".")[-1]  # e.g. "torch.bfloat16" -> "bfloat16"

    if sm_version == 90:  # H100
        BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
        threads, num_stages = 256, 3
    else:  # A100 etc.
        BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
        threads, num_stages = 256, 2

    # Batch size and N are dynamic (symbolic); M and K are concrete ints.
    kernel = fused_two_mm_same_inp(
        T.symbolic("b"),
        M,
        T.symbolic("n"),
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        threads,
        num_stages,
        w_dtype=w_dtype,
    )

    O1, O2 = kernel(W0.contiguous(), W2.contiguous(), X.contiguous())
    return O1, O2


if __name__ == "__main__":

    # Quick correctness test against a PyTorch reference on CUDA.
    torch.manual_seed(0)
    # Use a dedicated name so the module-level `device` (an int) is not rebound.
    dev = "cuda"
    B, M, N, K = 2, 1024, 1024, 1024

    # Layouts follow two_mm_same_inp_interface_v2: W0, W2 are [B, M, K] and
    # X is [B, K, N].
    W0 = torch.randn(B, M, K, device=dev, dtype=torch.bfloat16)
    W2 = torch.randn(B, M, K, device=dev, dtype=torch.bfloat16)
    X0 = torch.randn(B, K, N, device=dev, dtype=torch.bfloat16)

    O1, O2 = two_mm_same_inp_interface_v2(
        W0,
        W2,
        X0,
    )
    print(f"Output shapes: O1.shape={O1.shape}, O2.shape={O2.shape}")

    # fp32 reference. The kernel accumulates in fp32 but stores bf16 outputs,
    # so compare with bf16-level tolerances.
    ref1 = torch.bmm(W0.float(), X0.float())
    ref2 = torch.bmm(W2.float(), X0.float())
    torch.testing.assert_close(O1.float(), ref1, rtol=1.6e-2, atol=1e-1)
    torch.testing.assert_close(O2.float(), ref2, rtol=1.6e-2, atol=1e-1)
    print("Correctness check passed.")

Traceback

python3 test_example.py
2025-10-19 15:06:23  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `_main` with `out_idx=[3, 4]`

Expected behavior

No response

Additional context

No response

a1600012888 avatar Oct 19 '25 22:10 a1600012888

This is likely related to `T.Persistent`; cc @chengyupku

LeiWang1999 avatar Oct 20 '25 08:10 LeiWang1999