tilelang

[Feature request] Support automatic upcasting

Status: Open · opened by LyricZhao 6 months ago · 0 comments

import torch
import tilelang
from tilelang import language as T


@tilelang.jit(
    # Pass configs: turn on the flags that disable warp specialization and
    # TMA lowering for this reproducer.
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel():
    """Build a minimal kernel that compares an ``int64`` value with an ``'int'`` var.

    Reproducer for the "automatic upcasting" feature request: ``x`` is
    declared with dtype ``'int64'`` while ``a`` is allocated with dtype
    ``'int'``, so ``x[pid] == a`` mixes integer dtypes of (presumably)
    different widths. The request is for the frontend to upcast the narrower
    operand automatically, taking signedness into account.

    Returns:
        The ``buggy_kernel`` prim_func. Callers receive whatever
        ``@tilelang.jit`` produces from it; the script below calls
        ``get_kernel_source()`` on that object and launches it on a tensor.
    """
    # Symbolic size variable, used both as the length of `x` and as the
    # kernel launch extent.
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, ), 'int64']):
        # Launch over `num_tokens` (presumably one thread block per token,
        # indexed by `pid`) with 128 threads per block.
        with T.Kernel(num_tokens, threads=128) as pid:
            # Two scalar vars of dtype 'int'; note that `a` is never
            # assigned before it is read below.
            a, b = T.alloc_var('int'), T.alloc_var('int')
            # Mixed-dtype comparison (int64 element vs 'int' var): the
            # expression that should be upcast automatically, sign-aware.
            # NOTE(review): at the Python level `b = ...` rebinds the name
            # `b`; whether the frontend lowers it to a store into the
            # allocated var is frontend-dependent — confirm.
            b = x[pid] == a

    return buggy_kernel


if __name__ == '__main__':
    # Compile the reproducer and dump its generated source for inspection.
    jitted = get_buggy_kernel()
    generated_source = jitted.get_kernel_source()
    print(generated_source)

    # A 128-element int64 input on the GPU; its length corresponds to the
    # symbolic `num_tokens` extent in the kernel signature.
    tokens = torch.zeros((128,), dtype=torch.int64, device='cuda')
    jitted(tokens)

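Please take signedness into account when upcasting. In the snippet above, `x[pid]` is `int64` while `a` is allocated as `'int'`. The narrower operand should therefore be widened with sign extension for signed types, and with zero extension only for unsigned ones.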

— LyricZhao, Oct 13 '25, 08:10