tilelang
tilelang copied to clipboard
[Feature request] Support automatic upcasting
import torch
import tilelang
from tilelang import language as T
# Reproducer for the "automatic upcasting" feature request: builds a kernel
# that compares an int64 tensor element against a variable allocated as 'int'.
# The two pass-config flags below are set to True; judging by their names they
# turn off warp specialization and TMA lowering -- TODO confirm against the
# TileLang pass-config documentation.
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel():
    # Symbolic extent: the tensor length is not fixed when the kernel is built.
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, ), 'int64']):
        # One kernel instance per element of x (index bound to `pid`),
        # each launched with threads=128.
        with T.Kernel(num_tokens, threads=128) as pid:
            # Both variables are allocated with dtype 'int'
            # (presumably narrower than int64 -- verify for the target).
            # NOTE(review): `a` is read below without ever being assigned.
            a, b = T.alloc_var('int'), T.alloc_var('int')
            # Mixed-dtype comparison: int64 element vs. 'int' variable. This
            # is the statement the issue is about (operand upcasting).
            b = x[pid] == a

    return buggy_kernel
if __name__ == '__main__':
    # Build the reproducer kernel and dump its generated source so the
    # emitted comparison can be inspected.
    compiled = get_buggy_kernel()
    generated_source = compiled.get_kernel_source()
    print(generated_source)

    # Run the kernel on a 128-element int64 tensor of zeros on the CUDA device.
    tokens = torch.zeros(128, dtype=torch.int64, device='cuda')
    compiled(tokens)
Please pay attention to signedness when upcasting (signed and unsigned operands must be extended differently).