
Which GPUs does this work on?

Open nbroad1881 opened this issue 1 year ago • 10 comments

I'm assuming it only works on Ampere, Hopper, and Ada Lovelace. Is that correct? If it is limited to certain GPU types, it would be nice to specify that in the README.

nbroad1881, Aug 23 '24 01:08

Thank you @nbroad1881 for raising the issue! Since our kernels are written entirely in Triton, they have the same hardware compatibility as Triton itself, which https://github.com/triton-lang/triton lists as follows:

Supported Platforms:

- Linux

Supported Hardware:

- NVIDIA GPUs (Compute Capability 7.0+)
- AMD GPUs (ROCm 5.2+)
- Under development: CPUs

We have only tested on Ampere and Hopper so far. We hope the community can help us test on a broader range of hardware in production. I will mark this as a TODO and gather help from the community. Thanks!

ByronHsu, Aug 23 '24 02:08
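As a rough sketch of checking a machine against the Triton requirements listed above (NVIDIA compute capability 7.0+, ROCm 5.2+), the helpers below compare version tuples. The function names are hypothetical, and treating `torch.version.hip` as a stand-in for the ROCm release number is an assumption. `torch` is imported only inside `describe_current_device`, so the two pure helpers run without it.

```python
def nvidia_supported(capability: tuple[int, int]) -> bool:
    """Triton lists NVIDIA GPUs with compute capability 7.0+ as supported."""
    return capability >= (7, 0)


def rocm_supported(hip_version: str) -> bool:
    """Triton lists ROCm 5.2+; assumes the HIP version tracks the ROCm release."""
    # e.g. "6.1.40091-a8dbc0c19" -> (6, 1)
    major, minor = (int(part) for part in hip_version.split(".")[:2])
    return (major, minor) >= (5, 2)


def describe_current_device() -> str:
    import torch  # lazy import so the helpers above work without torch installed

    if not torch.cuda.is_available():
        return "no GPU visible to PyTorch"
    if torch.version.hip is not None:  # set only on ROCm builds of PyTorch
        verdict = "meets" if rocm_supported(torch.version.hip) else "is below"
        return f"ROCm/HIP {torch.version.hip} {verdict} Triton's 5.2 minimum"
    capability = torch.cuda.get_device_capability()
    verdict = "meets" if nvidia_supported(capability) else "is below"
    return f"compute capability {capability} {verdict} Triton's 7.0 minimum"
```

Calling `print(describe_current_device())` on the target machine gives a quick first answer. Meeting Triton's minimum does not guarantee that every dtype path compiles, as the T4 report later in this thread shows.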

Thanks!

nbroad1881, Aug 23 '24 21:08

@nbroad1881 I tested on an RTX 3070 earlier and the kernels work as expected; see https://github.com/linkedin/Liger-Kernel/pull/47. However, some of the larger tests/benchmarks may fail with out-of-memory errors.

lancerts, Aug 23 '24 21:08

@ByronHsu Hi, thanks for the super amazing work!

The tests fail on a T4 (sm_75, Turing), which does not support certain features such as .bf16 (bfloat16) operations; these require sm_80 (Ampere) or higher. The log below has the details. Let me know if I misunderstood anything. Much appreciated!

When I ran $ pytest ./test/transformers/test_geglu.py::test_correctness, it returned:

======================================= test session starts ========================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /content/Liger-Kernel
plugins: typeguard-4.3.0, anyio-3.7.1
collected 8 items                                                                                  

test/transformers/test_geglu.py ....FFFF                                                     [100%]

============================================= FAILURES =============================================
_____________________ test_correctness[dtype1-10000.0-0.006-2-2048-4096-11008] _____________________

bsz = 2, seq_len = 2048, hidden_size = 4096, intermediate_size = 11008, dtype = torch.bfloat16
atol = 10000.0, rtol = 0.006

    @pytest.mark.parametrize(
        "bsz, seq_len, hidden_size, intermediate_size",
        [
            (2, 2048, 4096, 11008),
            (2, 2048, 2048, 4096),
            # weird shapes
            (9, 41, 341, 4231),
            (6, 42, 256, 2048),
        ],
    )
    @pytest.mark.parametrize(
        "dtype, atol, rtol",
        [
            # atol is for small values: they have more difference, so set atol higher
            # rtol is for larger values: they are very close, so set rtol lower
            (torch.float32, 1e-0, 2e-6),
            (torch.bfloat16, 1e4, 6e-3),
        ],
    )
    def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol):
    
        _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype)
    
        x1 = _input.clone().requires_grad_(True)
        x2 = _input.clone().requires_grad_(True)
    
        # initialize weights
        G = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype)
        U = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype)
        D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype)
    
        llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to("cuda").to(dtype)
        llama_mlp.gate_proj.weight.data = G.T
        llama_mlp.up_proj.weight.data = U.T
        llama_mlp.down_proj.weight.data = D.T
    
        liger_mlp = LigerGEGLUMLP(config=LLAMA_CONFIG).to("cuda").to(dtype)
        liger_mlp.gate_proj.weight.data = G.T
        liger_mlp.up_proj.weight.data = U.T
        liger_mlp.down_proj.weight.data = D.T
    
        y1 = llama_mlp(x1)
>       y2 = liger_mlp(x2)

test/transformers/test_geglu.py:58: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
src/liger_kernel/transformers/geglu.py:22: in forward
    LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:598: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
src/liger_kernel/ops/utils.py:18: in wrapper
    return fn(ctx, *args, **kwargs)
src/liger_kernel/ops/geglu.py:104: in forward
    _geglu_tanh_forward_kernel[(n_rows,)](
/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:167: in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py:416: in run
    self.cache[device][key] = compile(
/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py:193: in compile
    next_module = compile_ir(module, metadata)
/usr/local/lib/python3.10/dist-packages/triton/compiler/backends/cuda.py:201: in <lambda>
    stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

src = '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.2\n.target sm_75\n.address_size 64\n\n\t// .globl\t_geglu_t...pes_start0:\n.b8 2\n.b8 0\n.b32 .debug_info\n.b32 191\n.b32 0\n$L__pubTypes_end0:\n\t}\n\t.section\t.debug_loc\t{\t}\n'
metadata = {'AMDGCN_ENABLE_DUMP': False, 'DISABLE_FAST_REDUCTION': False, 'DISABLE_MMA_V3': False, 'ENABLE_TMA': False, ...}
opt = CUDAOptions(num_warps=16, num_ctas=1, num_stages=3, cluster_dims=(1, 1, 1), ptx_version=None, enable_warp_specializati...logue=False, enable_fp_fusion=True, allow_fp8e4nv=False, max_num_imprecise_acc_default=0, extern_libs=None, debug=None)
capability = 75

    @staticmethod
    def make_cubin(src, metadata, opt, capability):
        metadata["name"] = get_kernel_name(src, pattern='// .globl')
        ptxas, _ = path_to_ptxas()
>       return compile_ptx_to_cubin(src, ptxas, capability, opt.enable_fp_fusion)
E       RuntimeError: Internal Triton PTX codegen error: 
E       ptxas /tmp/compile-ptx-src-99e631, line 104; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 104; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 106; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 106; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 108; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 108; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 110; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 110; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 112; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 112; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 114; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 114; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 116; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 116; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 213; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 213; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 248; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 248; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 283; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 283; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 318; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 318; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 353; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 353; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 388; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 388; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 423; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 423; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 458; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 458; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 493; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 493; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 528; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 528; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 563; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 563; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 598; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 598; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 633; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 633; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 668; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 668; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 703; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 703; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 738; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 738; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 773; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 773; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 808; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 808; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 843; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 843; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 878; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 878; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 913; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 913; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 948; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 948; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 983; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 983; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1018; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1018; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1052; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1052; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1407; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1407; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1409; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1409; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1411; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1411; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1413; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1413; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1415; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1415; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1417; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1417; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1419; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1419; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1421; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1421; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1423; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1423; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1425; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1425; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1427; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1427; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1429; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1429; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1431; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1431; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1433; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1433; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1435; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1435; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1437; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1437; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1439; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1439; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1441; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1441; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1443; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1443; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1445; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1445; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1447; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1447; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1449; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1449; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1451; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1451; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1453; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1453; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1455; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1455; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1457; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1457; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1459; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1459; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1461; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1461; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1463; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1463; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1465; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1465; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1467; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1467; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1469; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1469; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1511; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1511; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1513; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1513; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1515; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1515; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1517; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1517; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1519; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1519; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1521; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1521; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1523; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1523; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1525; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1525; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1527; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1527; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1529; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1529; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1531; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1531; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1533; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1533; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1535; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1535; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1537; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1537; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1539; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1539; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1541; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1541; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1543; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1543; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1545; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1545; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1547; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1547; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1549; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1549; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1551; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1551; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1553; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1553; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1555; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1555; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1557; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1557; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1559; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1559; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1561; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1561; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1563; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1563; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1565; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1565; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1567; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1567; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1569; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1569; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1571; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1571; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1573; error   : Feature '.bf16' requires .target sm_80 or higher
E       ptxas /tmp/compile-ptx-src-99e631, line 1573; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
E       ptxas fatal   : Ptx assembly aborted due to errors

austin362667, Aug 25 '24 03:08
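Every ptxas error in the log above says the same thing: `.bf16` instructions need `.target sm_80` or higher, and the T4 is sm_75. The small lookup below, a sketch with an illustrative rather than exhaustive device list, maps a compute capability to its architecture name and to whether native bf16 PTX is available. It also places the GPUs mentioned in this thread (T4, RTX 3070) relative to the sm_80 cutoff. The names `ARCHITECTURES`, `architecture`, `native_bf16`, and `summarize` are hypothetical.

```python
# Compute capability -> architecture, with example GPUs (not an exhaustive list).
ARCHITECTURES = {
    (7, 0): "Volta",         # e.g. V100
    (7, 5): "Turing",        # e.g. T4, RTX 20xx
    (8, 0): "Ampere",        # e.g. A100
    (8, 6): "Ampere",        # e.g. RTX 30xx such as the RTX 3070, A10
    (8, 9): "Ada Lovelace",  # e.g. RTX 40xx, L4
    (9, 0): "Hopper",        # e.g. H100
}


def architecture(capability: tuple[int, int]) -> str:
    return ARCHITECTURES.get(capability, "unknown")


def native_bf16(capability: tuple[int, int]) -> bool:
    """ptxas rejects .bf16 instructions for targets older than sm_80."""
    return capability >= (8, 0)


def summarize(capability: tuple[int, int]) -> str:
    major, minor = capability
    bf16 = "yes" if native_bf16(capability) else "no"
    return f"sm_{major}{minor} ({architecture(capability)}), native bf16: {bf16}"
```

On a live machine, `summarize(torch.cuda.get_device_capability())` would report something like `sm_75 (Turing), native bf16: no` on a T4.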

@austin362667 This is expected! We should add pytest filtering based on GPU type. Are you interested in taking on the task?

ByronHsu, Aug 25 '24 04:08
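One possible shape for the pytest filtering suggested above, assuming pytest's standard `skipif` marker and `pytest.param`: the bf16 parameter set from `test_correctness` gets a skip mark when the device is older than sm_80, so the float32 cases still run on a T4. `requires_bf16`, `DTYPE_PARAMS`, and `test_example` are hypothetical names, and the dtypes are strings only so the sketch imports without torch; a real test would pass `torch.float32` and `torch.bfloat16`. This is not necessarily how the project ended up implementing it.

```python
import pytest


def device_capability():
    """Return (major, minor) of the current CUDA device, or None if unavailable."""
    try:
        import torch  # lazy import: collecting tests should not require a GPU
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


def supports_bf16(capability) -> bool:
    # bf16 PTX instructions need sm_80 (Ampere) or newer.
    return capability is not None and capability >= (8, 0)


# Reusable skip mark for bf16-only tests or parameter sets.
requires_bf16 = pytest.mark.skipif(
    not supports_bf16(device_capability()),
    reason="bf16 kernels need compute capability 8.0+ (sm_80)",
)

# Same (dtype, atol, rtol) tolerances as test_correctness; only bf16 is marked.
DTYPE_PARAMS = [
    ("float32", 1e-0, 2e-6),
    pytest.param("bfloat16", 1e4, 6e-3, marks=requires_bf16),
]


@pytest.mark.parametrize("dtype_name, atol, rtol", DTYPE_PARAMS)
def test_example(dtype_name, atol, rtol):
    assert atol > 0 and rtol > 0
```

Marking individual parameters rather than the whole test keeps float32 coverage on pre-Ampere GPUs while reporting the bf16 cases as skipped instead of failed.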

Sure, I would love to help! I've opened an issue: https://github.com/linkedin/Liger-Kernel/issues/87. Feel free to add anything I may have missed.

austin362667, Aug 26 '24 01:08

I just tried training with axolotl on an 8x MI300X instance (torch 2.4.0+ROCm6.1) and got the error below. Has anyone here gotten Liger-Kernel to run on AMD?

Traceback (most recent call last):
  File "/opt/conda/envs/py_3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/py_3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/axolotl/src/axolotl/cli/train.py", line 72, in <module>
    fire.Fire(do_cli)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/workspace/axolotl/src/axolotl/cli/train.py", line 39, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/workspace/axolotl/src/axolotl/cli/train.py", line 67, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/workspace/axolotl/src/axolotl/train.py", line 188, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 1948, in train
    return inner_training_loop(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 2289, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 3328, in training_step
    loss = self.compute_loss(model, inputs)
  File "/workspace/axolotl/src/axolotl/core/trainer_builder.py", line 664, in compute_loss
    return super().compute_loss(model, inputs, return_outputs=return_outputs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 3373, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1846, in forward
    loss = self.module(*inputs, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/liger_kernel/transformers/model/qwen2.py", line 81, in lce_forward
    outputs = self.model(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 904, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/workspace/axolotl/src/axolotl/utils/gradient_checkpointing/__init__.py", line 10, in hf_grad_checkpoint_unsloth_wrapper
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 455, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/workspace/axolotl/src/axolotl/utils/gradient_checkpointing/unsloth.py", line 32, in forward
    output = forward_function(hidden_states, *args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 669, in forward
    hidden_states = self.mlp(hidden_states)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/liger_kernel/transformers/swiglu.py", line 21, in forward
    LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/liger_kernel/ops/utils.py", line 18, in wrapper
    return fn(ctx, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/liger_kernel/ops/swiglu.py", line 77, in forward
    _swiglu_forward_kernel[(n_rows,)](
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/triton/backends/amd/driver.py", line 418, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]:  Code: 1, Messsage: invalid argument

DocShotgun avatar Aug 27 '24 22:08 DocShotgun

What is your Triton version? I used Triton 3.0.0 and successfully ran all tests on an AMD 7900 with ROCm 6.2. @DocShotgun

helloworld1 avatar Aug 27 '24 22:08 helloworld1
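Version mismatches (Triton, PyTorch, ROCm, GPU architecture) are the first thing compared in replies like the one above. A small sketch for collecting them in one go, using only the standard library's `importlib.metadata` plus an optional `torch` check. The distribution names listed are assumptions (ROCm builds of Triton are commonly shipped as `pytorch-triton-rocm` rather than `triton`), and anything not installed is reported as `None`.

```python
import importlib.metadata as md
import platform


def package_version(name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return md.version(name)
    except md.PackageNotFoundError:
        return None


def collect_env():
    # Assumed distribution names; extend the list to match your own stack.
    names = ["triton", "pytorch-triton-rocm", "torch", "liger-kernel",
             "flash-attn", "transformers"]
    env = {"python": platform.python_version()}
    env.update({name: package_version(name) for name in names})
    try:
        import torch
    except ImportError:
        return env  # torch missing: report package versions only
    # torch.version.hip is a version string on ROCm builds and None on CUDA builds.
    env["torch.version.hip"] = torch.version.hip
    env["torch.version.cuda"] = torch.version.cuda
    if torch.cuda.is_available():  # also True for ROCm devices
        props = torch.cuda.get_device_properties(0)
        env["device"] = props.name
        # gcnArchName (e.g. "gfx942") is only present on some ROCm builds, hence getattr.
        env["gcn_arch"] = getattr(props, "gcnArchName", None)
    return env


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```

Pasting this output into the issue makes it easy to spot, for example, a flash-attn build targeting a different architecture than the installed Triton.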

@DocShotgun Thanks for reporting the issue. Can you also share a minimal code example that reproduces it?

lancerts avatar Aug 27 '24 22:08 lancerts
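A minimal reproduction can skip axolotl and DeepSpeed entirely and call only the op at the top of the failing stack: `LigerSiLUMulFunction` in `liger_kernel/ops/swiglu.py`, as shown in the traceback. The sketch below assumes that import path and a 2D `(rows, cols)` input; the default `cols=29568` is assumed to match Qwen2-72B's `intermediate_size`, so adjust both values to your config. It compares against a plain PyTorch `silu(gate) * up` reference and returns a "skipped" status instead of crashing when torch, liger-kernel, or a GPU is missing.

```python
def run_swiglu_repro(rows=4096, cols=29568, dtype_name="bfloat16"):
    """Run Liger's SwiGLU forward once and compare it with plain PyTorch."""
    try:
        import torch
        import torch.nn.functional as F
        # Import path taken from the traceback above (liger_kernel/ops/swiglu.py).
        from liger_kernel.ops.swiglu import LigerSiLUMulFunction
    except ImportError as exc:
        return f"skipped: {exc}"
    if not torch.cuda.is_available():  # True on both CUDA and ROCm builds
        return "skipped: no GPU visible to torch"

    dtype = getattr(torch, dtype_name)
    torch.manual_seed(0)
    # On ROCm, HIP devices are addressed through the "cuda" device type.
    gate = torch.randn(rows, cols, device="cuda", dtype=dtype)
    up = torch.randn(rows, cols, device="cuda", dtype=dtype)

    out = LigerSiLUMulFunction.apply(gate, up)  # the launch that failed in the log
    ref = F.silu(gate) * up                     # unfused reference
    torch.cuda.synchronize()                    # surface asynchronous launch errors here

    max_err = (out.float() - ref.float()).abs().max().item()
    return f"ok: max abs error {max_err:.3e}"


if __name__ == "__main__":
    print(run_swiglu_repro())
```

If this raises the same `Triton Error [HIP]` on its own, the problem is isolated to the kernel launch rather than to the trainer; if it passes, the failure likely depends on something in the full training setup.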

The Triton version is 3.0.0. I'm also running flash-attn 2.6.3 (built for gfx942 arch on torch 2.4.0+ROCm6.1), but I'm not sure if that's relevant.

Unfortunately I don't have a minimal code example, as I'm using a trainer called axolotl to do a full-weight finetune of Qwen2 72B. This trainer successfully finetuned a 34B Llama-like model (Yi-1.5-34B-32K) on 8xH100. I can share the configuration I used if that would help.

EDIT: Update: I installed a nightly pytorch-triton-rocm wheel and managed to start a run on 1x MI300X without the error. I still need to check whether it works for my full training run.

EDIT2: Got it running on 8x MI300X, but only with the pip release of liger-kernel. Installing from git gives me Triton compilation errors.

EDIT3: Never mind, I'm unable to replicate my fully working setup, haha.

DocShotgun avatar Aug 27 '24 22:08 DocShotgun