
autotune fails, but manually setting block_M, block_N, block_K works

Open a1600012888 opened this issue 6 months ago • 4 comments

Hi, tilelang team, thanks for such amazing work.

I am writing a simple fused GEMM kernel. If I manually select block_M, block_N, and block_K, it compiles and works, but when I turn on autotuning it fails on H100.

Here is the tilelang code:


import torch
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
from tilelang.autotuner import AutoTuner, autotune
import itertools


device = torch.cuda.current_device()
sm_major, sm_minor = torch.cuda.get_device_capability(device)
sm_version = sm_major * 10 + sm_minor


def get_configs():
    block_M_list = [64, 128]
    block_N_list = [64, 128, 256]
    block_K_list = [32, 64]
    num_stages_list = [2, 3]
    thread_num_list = [256]
    _configs = list(
        itertools.product(
            block_M_list,
            block_N_list,
            block_K_list,
            num_stages_list,
            thread_num_list,
        )
    )
    config_dict_list = []
    for c in _configs:
        config_dict_list.append(
            {
                "block_M": c[0],
                "block_N": c[1],
                "block_K": c[2],
                "threads": c[4],
                "num_stages": c[3],
            }
        )
    return config_dict_list


@autotune(configs=get_configs(), warmup=10, rep=10, timeout=100000)
@tilelang.jit(out_idx=[-2, -1])
def fused_two_mm_same_inp(
    BatchSize,
    M,
    N,
    K,
    block_M=128,
    block_N=128,
    block_K=64,
    threads=256,
    num_stages=3,
    w_dtype="bfloat16",
):

    x_dtype = "bfloat16"
    num_tiles = BatchSize * T.ceildiv(M, block_M) * T.ceildiv(N, block_N)
    sm_num = driver.get_num_sms()

    @T.prim_func
    def _main(
        W0: T.Tensor((BatchSize, M, K), w_dtype),
        W2: T.Tensor(
            (BatchSize, M, K), w_dtype
        ),  # note: M,K layout (we'll use transpose_A=True)
        X: T.Tensor((BatchSize, K, N), x_dtype),
        O1: T.Tensor((BatchSize, M, N), x_dtype),
        O2: T.Tensor((BatchSize, M, N), x_dtype),
    ):

        with T.Kernel(sm_num, threads=threads) as (block_id):

            W0_sh = T.alloc_shared((block_M, block_K), dtype="bfloat16")
            W2_sh = T.alloc_shared((block_M, block_K), dtype="bfloat16")
            X_sh = T.alloc_shared((block_K, block_N), dtype="bfloat16")

            # Y2_sh = T.alloc_shared((block_M, block_N), dtype="bfloat16")

            T.use_swizzle(panel_size=10, enable=True)
            # T.annotate_layout(
            #     {
            #         # W0_sh: make_mma_swizzle_layout(W0_sh),
            #         # W2_sh: make_mma_swizzle_layout(W2_sh),
            #         # X_sh: make_mma_swizzle_layout(X_sh),
            #         O_sh: make_mma_swizzle_layout(O_sh),
            #     }
            # )

            Y0_loc = T.alloc_fragment((block_M, block_N), dtype="float32")
            Y2_loc = T.alloc_fragment((block_M, block_N), dtype="float32")

            for bb, bx, by in T.Persistent(
                [BatchSize, T.ceildiv(M, block_M), T.ceildiv(N, block_N)],
                wave_size=sm_num,
                index=block_id,
            ):

                m0 = bx * block_M
                n0 = by * block_N

                T.clear(Y0_loc)
                T.clear(Y2_loc)

                for bk in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    # Cast-on-copy to matmul_dtype
                    T.copy(W0[bb, m0, bk * block_K], W0_sh)  # [block_M, block_K]
                    T.copy(W2[bb, m0, bk * block_K], W2_sh)  # [block_M, block_K]
                    T.copy(X[bb, bk * block_K, n0], X_sh)  # [block_K, block_N]
                    T.gemm(W0_sh, X_sh, Y0_loc)  # (M,K) @ (K,N) -> (M,N)
                    T.gemm(W2_sh, X_sh, Y2_loc)

                T.copy(Y0_loc, O1[bb, m0, n0])
                T.copy(Y2_loc, O2[bb, m0, n0])

    return _main


def two_mm_same_inp_interface(W0, W2, X, autotune=True):
    """
    Args:
        W0, W2: [B, M, K]
        X:      [B, K, N]
    Outs:
        O1 = W0 @ X
        O2 = W2 @ X
    """
    BatchSize, M, K = W0.shape
    w_dtype = str(W0.dtype).split(".")[-1]
    N = X.shape[-1]

    if sm_version in {90}:  # H100 has 228 KB SMEM.
        BLOCK_M = 128
        BLOCK_N = 128
        BLOCK_K = 64
        threads = 256
        num_stages = 3
    else:  # A100 has 192KB SMEM, need to reduce the block_K
        BLOCK_M = 128
        BLOCK_N = 128
        BLOCK_K = 64
        threads = 256
        num_stages = 3

    if not autotune:
        kernel = fused_two_mm_same_inp(
            BatchSize,
            M,
            N,
            K,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            threads,
            num_stages,
            w_dtype=w_dtype,
        )
    else:
        kernel = fused_two_mm_same_inp(
            BatchSize,
            M,
            N,
            K,
            w_dtype=w_dtype,
        )

    O1, O2 = kernel(W0.contiguous(), W2.contiguous(), X.contiguous())

    return O1, O2


if __name__ == "__main__":
    def make_inputs(B, H, D, L):
        W0 = torch.randn(
            B, H, L, device=device, dtype=torch.bfloat16, requires_grad=True
        )
        W2 = torch.randn(
            B, H, L, device=device, dtype=torch.bfloat16, requires_grad=True
        )
        X = torch.randn(
            B, L, D, device=device, dtype=torch.bfloat16, requires_grad=True
        )
        return W0, W2, X

    W0, W2, X = make_inputs(4, 256, 256, 2048)
    O1, O2 = two_mm_same_inp_interface(W0, W2, X)

With autotune turned off, everything works, but when autotune is enabled, it fails with:

  File "/sensei-fs/users/tianyuanz/projects/longlact/longlact/tile_lang/swiglu_ffn_kernels.py", line 567, in two_mm_same_inp_interface
    kernel = fused_two_mm_same_inp(
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 692, in wrapper
    artifact = autotuner.run()
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 530, in run
    raise RuntimeError(error_msg)
RuntimeError: Auto-tuning failed: No configuration successfully compiled and passed benchmarking/validation

And the autotuner.log looks like:

2025-10-05 21:00:05,147 INFO:Auto-tuning with 0.9 CPU utilizations, 192 CPUs available, 172 CPUs will be used
2025-10-05 21:00:21,560 WARNING:Tunable parameters ['block_M', 'block_N', 'block_K', 'threads', 'num_stages'] already provided during auto-tuning. Skipping compilation and using direct JIT
2025-10-05 21:00:36,009 INFO:Auto-tuning with 0.9 CPU utilizations, 192 CPUs available, 172 CPUs will be used
2025-10-05 21:00:52,554 INFO:An error occurred while testing config {'block_M': 64, 'block_N': 128, 'block_K': 64, 'threads': 256, 'num_stages': 3}, checkout autotuner.log for more details
2025-10-05 21:00:52,554 DEBUG:Error: Traceback (most recent call last):
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 503, in run
    latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 45, in run_with_timeout
    signal.signal(signal.SIGALRM, timeout_handler)
  File "/usr/lib/python3.10/signal.py", line 56, in signal
    handler = _signal.signal(_enum_to_int(signalnum), _enum_to_int(handler))
ValueError: signal only works in main thread of the main interpreter

2025-10-05 21:00:52,554 INFO:An error occurred while testing config {'block_M': 64, 'block_N': 256, 'block_K': 32, 'threads': 256, 'num_stages': 3}, checkout autotuner.log for more details
2025-10-05 21:00:52,555 DEBUG:Error: Traceback (most recent call last):
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 503, in run
    latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 45, in run_with_timeout
    signal.signal(signal.SIGALRM, timeout_handler)
  File "/usr/lib/python3.10/signal.py", line 56, in signal
    handler = _signal.signal(_enum_to_int(signalnum), _enum_to_int(handler))
ValueError: signal only works in main thread of the main interpreter

2025-10-05 21:00:52,555 INFO:An error occurred while testing config {'block_M': 64, 'block_N': 64, 'block_K': 64, 'threads': 256, 'num_stages': 2}, checkout autotuner.log for more details
2025-10-05 21:00:52,556 DEBUG:Error: Traceback (most recent call last):
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 503, in run
    latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 45, in run_with_timeout
    signal.signal(signal.SIGALRM, timeout_handler)
  File "/usr/lib/python3.10/signal.py", line 56, in signal
    handler = _signal.signal(_enum_to_int(signalnum), _enum_to_int(handler))
ValueError: signal only works in main thread of the main interpreter

..... [everything below looks the same]

I tried increasing the timeout, but it did not help.

a1600012888 · Oct 06 '25 04:10

Hi @a1600012888, I cannot reproduce your error with the latest TileLang by directly running your script on Hopper. Could you please provide more detail about your environment configuration and how you've run the script?

Rachmanino · Oct 06 '25 08:10

Hi @Rachmanino, thanks for the support. I looked into it more closely and realized that the autotune only fails when the kernel is called in the backward pass of a PyTorch function.
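
The attached scripts implement a SwiGLU FFN, but the pattern that triggers the failure boils down to something like the hypothetical sketch below (reusing two_mm_same_inp_interface and make_inputs from the first post, not the attached files): the autotuned kernel is launched only inside backward(), which PyTorch's autograd engine runs on a worker thread.

import torch


class TwoMMFn(torch.autograd.Function):
    # Hypothetical minimal repro: forward uses plain matmuls, and the
    # autotuned tilelang kernel runs only inside backward(), which the
    # autograd engine executes on a worker thread, not the main thread.
    @staticmethod
    def forward(ctx, W0, W2, X):
        ctx.save_for_backward(X)
        return W0 @ X, W2 @ X

    @staticmethod
    def backward(ctx, dO1, dO2):
        (X,) = ctx.saved_tensors
        # The first call from here kicks off auto-tuning off the main
        # thread, which is where the SIGALRM ValueError is raised.
        dW0, dW2 = two_mm_same_inp_interface(dO1, dO2, X.transpose(-1, -2))
        return dW0, dW2, None


W0, W2, X = make_inputs(4, 256, 256, 2048)
O1, O2 = TwoMMFn.apply(W0, W2, X)
(O1.sum() + O2.sum()).backward()  # fails during auto-tuning in the autograd thread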

Here are the full scripts to reproduce the error:

swiglu_ffn_kernels.py swiglu_ffn.py

You can put the two Python files into the same folder and run python3 swiglu_ffn.py; it will throw an error in the backward pass.

The autotuner.log looks like this:


2025-10-06 09:50:40,398 INFO:An error occurred while testing config {'block_M': 128, 'block_N': 256, 'block_K': 32, 'threads': 128, 'num_stages': 2}, checkout autotuner.log for more details
2025-10-06 09:50:40,398 DEBUG:Error: Traceback (most recent call last):
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 503, in run
    latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
  File "/opt/venv/lib/python3.10/site-packages/tilelang/autotuner/tuner.py", line 45, in run_with_timeout
    signal.signal(signal.SIGALRM, timeout_handler)
  File "/usr/lib/python3.10/signal.py", line 56, in signal
    handler = _signal.signal(_enum_to_int(signalnum), _enum_to_int(handler))
ValueError: signal only works in main thread of the main interpreter

Looking into it more, one reason might be that the backward pass is called by PyTorch from another thread (not the main thread), so the autotuner's signal-based timeout cannot be set up and autotuning cannot run there.
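
For reference, a minimal standalone sketch (independent of tilelang and PyTorch) that reproduces the same ValueError:

import signal
import threading


def try_install_alarm():
    # signal.signal may only be called from the main thread of the main
    # interpreter; from any other thread it raises ValueError.
    try:
        signal.signal(signal.SIGALRM, lambda signum, frame: None)
        print("SIGALRM handler installed (main thread)")
    except ValueError as e:
        print(f"{threading.current_thread().name}: {e}")


try_install_alarm()  # works on the main thread

t = threading.Thread(target=try_install_alarm)
t.start()  # prints: signal only works in main thread of the main interpreter
t.join()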

a1600012888 · Oct 06 '25 16:10

Update: it now seems clear to me that any kernel called during the backward pass of a PyTorch function triggers this error. I used some monkey patching to work around it:

import threading
import tilelang.autotuner.tuner as _tuner

_orig_rwt = _tuner.run_with_timeout


def _safe_run_with_timeout(target_fn, timeout, *args, **kwargs):
    if threading.current_thread() is threading.main_thread():
        return _orig_rwt(target_fn, timeout, *args, **kwargs)
    # Fallback: no SIGALRM in worker threads; just run the function directly.
    return target_fn(*args, **kwargs)


_tuner.run_with_timeout = _safe_run_with_timeout
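
If you still want a timeout when tuning happens off the main thread, a variant of the same patch could emulate it with a worker thread instead of SIGALRM. This is only a sketch, and note that the benchmarked function is not actually killed on timeout, just abandoned:

import concurrent.futures
import threading
import tilelang.autotuner.tuner as _tuner

_orig_rwt = _tuner.run_with_timeout


def _safe_run_with_timeout(target_fn, timeout, *args, **kwargs):
    if threading.current_thread() is threading.main_thread():
        return _orig_rwt(target_fn, timeout, *args, **kwargs)
    # Off the main thread: enforce the timeout via a single-worker pool.
    # On timeout the worker keeps running in the background.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(target_fn, *args, **kwargs).result(timeout=timeout)
    finally:
        pool.shutdown(wait=False)


_tuner.run_with_timeout = _safe_run_with_timeout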

a1600012888 · Oct 06 '25 17:10

@a1600012888

Thanks for your monkey patch.

I have hit the same ValueError: signal only works in main thread of the main interpreter issue, but only for the backward kernel; the forward kernel can be autotuned successfully.

lemyx · Nov 18 '25 09:11