Which GPUs does this work on?
I'm assuming it only works on Ampere, Hopper, and Lovelace. Is that correct? If it is limited to certain GPU types, it would be nice to say so in the README.
Thank you @nbroad1881 for raising the issue! Since our kernels are written entirely in Triton, they share Triton's hardware compatibility. According to https://github.com/triton-lang/triton, that is:
Supported Platforms:
Linux
Supported Hardware:
NVIDIA GPUs (Compute Capability 7.0+)
AMD GPUs (ROCm 5.2+)
Under development: CPUs
We have only tested on Ampere and Hopper so far. We hope the community can help us test on a broader range of hardware in production. I will mark this as a TODO and gather community help. Thanks!
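To see where a given NVIDIA card falls relative to the compute capability 7.0 minimum quoted above, something like the following may help. It is only a sketch: `nvidia_arch_name`, `meets_triton_minimum`, and `describe_local_gpu` are hypothetical helper names (not part of Liger-Kernel or Triton), the architecture table is deliberately partial, and the capability query assumes an NVIDIA build of PyTorch.

```python
# Hypothetical helpers for illustration; not part of Liger-Kernel or Triton.

# Partial table of CUDA compute capabilities and their architecture families.
_ARCH_BY_CAPABILITY = {
    (7, 0): "Volta",
    (7, 5): "Turing",
    (8, 0): "Ampere",
    (8, 6): "Ampere",
    (8, 9): "Ada Lovelace",
    (9, 0): "Hopper",
}


def nvidia_arch_name(major: int, minor: int) -> str:
    """Return the architecture family for a capability, or 'unknown'."""
    return _ARCH_BY_CAPABILITY.get((major, minor), "unknown")


def meets_triton_minimum(major: int, minor: int) -> bool:
    """True if the capability is at least 7.0, the minimum quoted above."""
    return (major, minor) >= (7, 0)


def describe_local_gpu() -> str:
    """Describe the first visible CUDA device, if PyTorch can see one."""
    try:
        import torch  # imported lazily so the helpers above work without it
    except ImportError:
        return "PyTorch is not installed"
    if not torch.cuda.is_available():
        return "no CUDA device visible"
    major, minor = torch.cuda.get_device_capability()
    status = "meets" if meets_triton_minimum(major, minor) else "is below"
    return f"sm_{major}{minor} ({nvidia_arch_name(major, minor)}) {status} the 7.0 minimum"


if __name__ == "__main__":
    print(describe_local_gpu())
```

Meeting the minimum only means Triton targets the device; as the T4 report further down shows, individual dtypes such as bfloat16 can still require a newer architecture.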
Thanks!
@nbroad1881 Tested on an RTX 3070 earlier, and the kernels work as expected; see https://github.com/linkedin/Liger-Kernel/pull/47. However, some of the larger tests and benchmarks may fail because they run out of memory.
@ByronHsu Hi, thanks for the super amazing work!
The tests failed for me on a T4 (sm_75, Turing), which does not support certain features such as .bf16 (bfloat16) operations. These features require a GPU architecture of sm_80 (Ampere) or higher. Please check the following log for details, and let me know if I misunderstood anything. Appreciate it!
When I ran $ pytest ./test/transformers/test_geglu.py::test_correctness, it returned:
======================================= test session starts ========================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /content/Liger-Kernel
plugins: typeguard-4.3.0, anyio-3.7.1
collected 8 items
test/transformers/test_geglu.py ....FFFF [100%]
============================================= FAILURES =============================================
_____________________ test_correctness[dtype1-10000.0-0.006-2-2048-4096-11008] _____________________
bsz = 2, seq_len = 2048, hidden_size = 4096, intermediate_size = 11008, dtype = torch.bfloat16
atol = 10000.0, rtol = 0.006
    @pytest.mark.parametrize(
        "bsz, seq_len, hidden_size, intermediate_size",
        [
            (2, 2048, 4096, 11008),
            (2, 2048, 2048, 4096),
            # weird shapes
            (9, 41, 341, 4231),
            (6, 42, 256, 2048),
        ],
    )
    @pytest.mark.parametrize(
        "dtype, atol, rtol",
        [
            # atol is for small values: they have more difference, so set atol higher
            # rtol is for larger values: they are very close, so set rtol lower
            (torch.float32, 1e-0, 2e-6),
            (torch.bfloat16, 1e4, 6e-3),
        ],
    )
    def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol):
        _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype)
        x1 = _input.clone().requires_grad_(True)
        x2 = _input.clone().requires_grad_(True)
        # initialize weights
        G = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype)
        U = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype)
        D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype)
        llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to("cuda").to(dtype)
        llama_mlp.gate_proj.weight.data = G.T
        llama_mlp.up_proj.weight.data = U.T
        llama_mlp.down_proj.weight.data = D.T
        liger_mlp = LigerGEGLUMLP(config=LLAMA_CONFIG).to("cuda").to(dtype)
        liger_mlp.gate_proj.weight.data = G.T
        liger_mlp.up_proj.weight.data = U.T
        liger_mlp.down_proj.weight.data = D.T
        y1 = llama_mlp(x1)
>       y2 = liger_mlp(x2)
test/transformers/test_geglu.py:58:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541: in _call_impl
return forward_call(*args, **kwargs)
src/liger_kernel/transformers/geglu.py:22: in forward
LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:598: in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
src/liger_kernel/ops/utils.py:18: in wrapper
return fn(ctx, *args, **kwargs)
src/liger_kernel/ops/geglu.py:104: in forward
_geglu_tanh_forward_kernel[(n_rows,)](
/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:167: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:416: in run
self.cache[device][key] = compile(
/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py:193: in compile
next_module = compile_ir(module, metadata)
/usr/local/lib/python3.10/dist-packages/triton/compiler/backends/cuda.py:201: in <lambda>
stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src = '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.2\n.target sm_75\n.address_size 64\n\n\t// .globl\t_geglu_t...pes_start0:\n.b8 2\n.b8 0\n.b32 .debug_info\n.b32 191\n.b32 0\n$L__pubTypes_end0:\n\t}\n\t.section\t.debug_loc\t{\t}\n'
metadata = {'AMDGCN_ENABLE_DUMP': False, 'DISABLE_FAST_REDUCTION': False, 'DISABLE_MMA_V3': False, 'ENABLE_TMA': False, ...}
opt = CUDAOptions(num_warps=16, num_ctas=1, num_stages=3, cluster_dims=(1, 1, 1), ptx_version=None, enable_warp_specializati...logue=False, enable_fp_fusion=True, allow_fp8e4nv=False, max_num_imprecise_acc_default=0, extern_libs=None, debug=None)
capability = 75
    @staticmethod
    def make_cubin(src, metadata, opt, capability):
        metadata["name"] = get_kernel_name(src, pattern='// .globl')
        ptxas, _ = path_to_ptxas()
>       return compile_ptx_to_cubin(src, ptxas, capability, opt.enable_fp_fusion)
E RuntimeError: Internal Triton PTX codegen error:
E ptxas /tmp/compile-ptx-src-99e631, line 104; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 104; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 106; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 106; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 108; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 108; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 110; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 110; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 112; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 112; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 114; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 114; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 116; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 116; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 213; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 213; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 248; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 248; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 283; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 283; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 318; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 318; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 353; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 353; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 388; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 388; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 423; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 423; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 458; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 458; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 493; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 493; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 528; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 528; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 563; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 563; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 598; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 598; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 633; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 633; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 668; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 668; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 703; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 703; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 738; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 738; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 773; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 773; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 808; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 808; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 843; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 843; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 878; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 878; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 913; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 913; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 948; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 948; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 983; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 983; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1018; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1018; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1052; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1052; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1407; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1407; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1409; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1409; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1411; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1411; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1413; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1413; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1415; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1415; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1417; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1417; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1419; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1419; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1421; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1421; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1423; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1423; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1425; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1425; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1427; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1427; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1429; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1429; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1431; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1431; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1433; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1433; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1435; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1435; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1437; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1437; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1439; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1439; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1441; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1441; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1443; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1443; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1445; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1445; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1447; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1447; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1449; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1449; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1451; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1451; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1453; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1453; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1455; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1455; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1457; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1457; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1459; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1459; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1461; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1461; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1463; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1463; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1465; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1465; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1467; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1467; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1469; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1469; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1511; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1511; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1513; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1513; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1515; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1515; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1517; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1517; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1519; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1519; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1521; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1521; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1523; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1523; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1525; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1525; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1527; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1527; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1529; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1529; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1531; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1531; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1533; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1533; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1535; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1535; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1537; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1537; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1539; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1539; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1541; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1541; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1543; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1543; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1545; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1545; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1547; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1547; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1549; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1549; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1551; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1551; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1553; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1553; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1555; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1555; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1557; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1557; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1559; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1559; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1561; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1561; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1563; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1563; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1565; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1565; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1567; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1567; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1569; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1569; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1571; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1571; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1573; error : Feature '.bf16' requires .target sm_80 or higher
E ptxas /tmp/compile-ptx-src-99e631, line 1573; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E ptxas fatal : Ptx assembly aborted due to errors
@austin362667 This is expected! We should add pytest filtering based on GPU type. Are you interested in the task?
Sure! I would love to help. Opening an issue: https://github.com/linkedin/Liger-Kernel/issues/87. Feel free to add anything I may have missed.
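For reference, the filtering discussed above could be driven by a small capability check like the one below. This is a sketch, not what the linked issue ended up implementing: `bf16_skip_reason` and `select_cases` are hypothetical names, dtypes are written as plain strings so the snippet runs without PyTorch, and the sm_80 threshold is taken from the ptxas errors in the log. In a real test module the returned reason would more likely feed `pytest.mark.skipif(..., reason=...)` on the bfloat16 parameters.

```python
# Hypothetical helpers for illustration; not Liger-Kernel's actual test utilities.

# Taken from the ptxas log: "Feature '.bf16' requires .target sm_80 or higher".
BF16_MIN_CAPABILITY = (8, 0)


def bf16_skip_reason(capability):
    """Return a skip reason string, or None if bfloat16 cases can run."""
    if tuple(capability) < BF16_MIN_CAPABILITY:
        major, minor = capability
        return f"bfloat16 requires sm_80 or higher, found sm_{major}{minor}"
    return None


def select_cases(cases, capability):
    """Split (dtype_name, atol, rtol) cases into runnable and skipped lists."""
    reason = bf16_skip_reason(capability)
    runnable, skipped = [], []
    for case in cases:
        if case[0] == "bfloat16" and reason is not None:
            skipped.append((case, reason))
        else:
            runnable.append(case)
    return runnable, skipped


if __name__ == "__main__":
    # Same dtype/tolerance pairs as the failing test, with dtypes as strings.
    cases = [("float32", 1e-0, 2e-6), ("bfloat16", 1e4, 6e-3)]
    runnable, skipped = select_cases(cases, (7, 5))  # T4 is sm_75
    print("run:", runnable)
    print("skip:", skipped)
```

Keeping the decision in a plain function makes it easy to unit-test on machines without a GPU; the capability tuple itself would come from something like `torch.cuda.get_device_capability()` at collection time.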
I just attempted to train using axolotl on an instance with 8xMI300x and torch 2.4.0+ROCm6.1, and got this error. Has anyone here gotten Liger-Kernel to run on AMD?
Traceback (most recent call last):
File "/opt/conda/envs/py_3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/envs/py_3.10/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/workspace/axolotl/src/axolotl/cli/train.py", line 72, in <module>
fire.Fire(do_cli)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/workspace/axolotl/src/axolotl/cli/train.py", line 39, in do_cli
return do_train(parsed_cfg, parsed_cli_args)
File "/workspace/axolotl/src/axolotl/cli/train.py", line 67, in do_train
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
File "/workspace/axolotl/src/axolotl/train.py", line 188, in train
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 1948, in train
return inner_training_loop(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 2289, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 3328, in training_step
loss = self.compute_loss(model, inputs)
File "/workspace/axolotl/src/axolotl/core/trainer_builder.py", line 664, in compute_loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 3373, in compute_loss
outputs = model(**inputs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
loss = self.module(*inputs, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/liger_kernel/transformers/model/qwen2.py", line 81, in lce_forward
outputs = self.model(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 904, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/workspace/axolotl/src/axolotl/utils/gradient_checkpointing/__init__.py", line 10, in hf_grad_checkpoint_unsloth_wrapper
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 455, in decorate_fwd
return fwd(*args, **kwargs)
File "/workspace/axolotl/src/axolotl/utils/gradient_checkpointing/unsloth.py", line 32, in forward
output = forward_function(hidden_states, *args)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 669, in forward
hidden_states = self.mlp(hidden_states)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/liger_kernel/transformers/swiglu.py", line 21, in forward
LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/liger_kernel/ops/utils.py", line 18, in wrapper
return fn(ctx, *args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/liger_kernel/ops/swiglu.py", line 77, in forward
_swiglu_forward_kernel[(n_rows,)](
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 691, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/triton/backends/amd/driver.py", line 418, in __call__
self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
What is your Triton version? I used Triton 3.0.0 and successfully ran all tests on an AMD 7900 with ROCm 6.2. @DocShotgun
@DocShotgun Thanks for reporting the issue. Can you also share minimal code to reproduce it?
The Triton version is 3.0.0. I'm also running flash-attn 2.6.3 (built for gfx942 arch on torch 2.4.0+ROCm6.1), but I'm not sure if that's relevant.
Unfortunately I don't have a minimal code example, as I'm using a trainer called axolotl to do a full-weight finetune of Qwen2 72B. This trainer was able to successfully finetune a 34B Llama-like model (Yi-1.5-34B-32K) on 8xH100. I can share the configuration used if it would be helpful.
EDIT: Update - I installed this wheel of nightly pytorch triton rocm and managed to get a run started without the error on 1x MI300x. Will need to see if I can get it working for my full training run.
EDIT2: Got it to run on 8x MI300x, but with the pip version of liger-kernel. Getting some triton compilation errors trying to install from git.
EDIT3: Never mind, I'm unable to replicate my fully working setup, haha.
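Since the working and failing setups here seem to differ mainly in installed package versions, a version snapshot may make reports like this easier to compare. The sketch below uses only the standard library; `collect_versions` is a hypothetical helper, and the distribution names in the list (including `pytorch-triton-rocm` for the ROCm Triton wheel) are assumptions about how the packages were installed.

```python
# Hypothetical helper for illustration; uses only the standard library.
from importlib import metadata


def collect_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distribution names are assumptions; adjust them to match `pip list`.
    names = ["torch", "triton", "pytorch-triton-rocm", "liger-kernel", "flash-attn"]
    for name, version in collect_versions(names).items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting this output alongside the traceback would show at a glance whether the pip and git installs ended up with different Triton builds.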