
RuntimeError with Assertion failed: driver_result == cudaDriverEntryPointSuccess.

Open Salieri0515 opened this issue 5 months ago • 4 comments

Describe the bug

When running examples/llama/train_llama3_8b_fp8.sh, building the transformer layers of GPTModel fails with:

RuntimeError: /TransformerEngine/transformer_engine/common/util/cuda_driver.cpp:42 in function get_symbol: Assertion failed: driver_result == cudaDriverEntryPointSuccess. Could not find CUDA driver entry point for cuCtxGetCurrent when instantiating TERowParallelLinear when instantiating SelfAttentio

Steps/Code to reproduce bug

  1. conda create -n megatronlm python=3.12
  2. pip install torch==2.6.0
  3. pip install megatron-core
  4. pip install --no-build-isolation transformer-engine[pytorch]
  5. pip install regex six PyYAML psutil pybind11

Then run examples/llama/train_llama3_8b_fp8.sh.

Environment overview: Ubuntu 22.04, Python 3.12, torch 2.6.0+cu124, CUDA 12.4 (per nvcc), cuDNN 9.1.0, transformer-engine 2.8.0

Device details H100
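The package versions above can be gathered in one go when filing or updating the report. A minimal sketch using the standard-library `importlib.metadata`; the distribution names (`torch`, `transformer-engine`, `megatron-core`) are assumptions taken from the pip commands in the report, `collect_versions` is a hypothetical helper, and a package not found under that name is reported as `None`. CUDA, cuDNN and driver versions still have to be checked separately (nvcc, nvidia-smi).

```python
import importlib.metadata
import platform


def collect_versions(dist_names):
    """Return {distribution name: installed version, or None if not found}."""
    versions = {}
    for name in dist_names:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            # Not installed under this distribution name
            versions[name] = None
    return versions


if __name__ == "__main__":
    print(f"python {platform.python_version()}")
    # Distribution names are assumptions based on the pip commands in the report
    for name, version in collect_versions(
        ["torch", "transformer-engine", "megatron-core"]
    ).items():
        print(f"{name} {version}")
```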


Salieri0515, Oct 10 '25 07:10

I wonder if it's an issue with the CUDA Driver? I see you've compiled with CUDA 12.4, which requires CUDA Driver 525+ (see CUDA 12.4 release notes). One quick way to check is by running nvidia-smi.
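The nvidia-smi check can be scripted. A minimal sketch: `nvidia-smi --query-gpu=driver_version --format=csv,noheader` prints the driver version (e.g. `560.35.03`), and the helper compares its major number against the 525 minimum mentioned above. The helper names are hypothetical, and the function returns `None` on machines without nvidia-smi or a working GPU.

```python
import shutil
import subprocess

MIN_DRIVER_MAJOR = 525  # minimum driver branch for CUDA 12.x, per the comment above


def driver_major(version_str):
    """Parse the major number from a driver version string such as '560.35.03'."""
    return int(version_str.strip().split(".")[0])


def query_driver_version():
    """Return the driver version reported by nvidia-smi, or None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0 or not result.stdout.strip():
        return None
    # One line per GPU; all GPUs on a host share the same driver
    return result.stdout.splitlines()[0].strip()


if __name__ == "__main__":
    version = query_driver_version()
    if version is None:
        print("nvidia-smi not available")
    else:
        ok = driver_major(version) >= MIN_DRIVER_MAJOR
        print(f"driver {version}: {'OK' if ok else 'too old'} for CUDA 12.x")
```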

Thinking through the problem

The failure is happening when we try to access a CUDA Driver function: https://github.com/NVIDIA/TransformerEngine/blob/dd9433e7ad28c12f27da9770be54c9c584e85fa0/transformer_engine/common/util/cuda_driver.cpp#L22-L45

We do it in this indirect way because the CUDA Driver may be different at compile-time and run-time (see https://github.com/NVIDIA/TransformerEngine/pull/1240).

If you've built TE with CUDA 12.4, then cudaGetDriverEntryPointByVersion is not supported and it is using cudaGetDriverEntryPoint (see CUDA 12.4 docs).
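To see which branch applies on a given machine, one can ask a CUDA runtime library directly whether it exports the versioned entry point. A minimal sketch using `ctypes`, which resolves attributes through `dlopen`/`dlsym`; unlike TE's `dlsym(RTLD_DEFAULT, ...)`, it inspects one specific library rather than everything already loaded in the process. The library names tried are assumptions (a pip-installed PyTorch may load its own bundled `libcudart` rather than the one under `CUDA_HOME`), and `probe_symbols` is a hypothetical helper.

```python
import ctypes


def probe_symbols(lib_candidates, symbols):
    """Load the first library that opens and report which symbols it exports.

    Returns (library name, {symbol: bool}), or (None, {}) if none could be loaded.
    """
    for name in lib_candidates:
        try:
            lib = ctypes.CDLL(name)
        except OSError:
            continue  # not found on the loader path; try the next candidate
        # ctypes looks attributes up with dlsym and raises AttributeError if missing
        return name, {sym: hasattr(lib, sym) for sym in symbols}
    return None, {}


if __name__ == "__main__":
    print(probe_symbols(
        ["libcudart.so.12", "libcudart.so"],
        ["cudaGetDriverEntryPoint", "cudaGetDriverEntryPointByVersion"],
    ))
    # The driver library itself should export the symbol TE fails to find
    print(probe_symbols(["libcuda.so.1"], ["cuCtxGetCurrent"]))
```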

timmoon10, Oct 10 '25 22:10


TransformerEngine/transformer_engine/common/util/cuda_driver.cpp (lines 22 to 45 in dd9433e):

void *get_symbol(const char *symbol, int cuda_version) {
  constexpr char driver_entrypoint[] = "cudaGetDriverEntryPoint";
  constexpr char driver_entrypoint_versioned[] = "cudaGetDriverEntryPointByVersion";
  // We link to the libcudart.so already, so can search for it in the current context
  static GetEntryPoint driver_entrypoint_fun =
      reinterpret_cast<GetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint));
  static VersionedGetEntryPoint driver_entrypoint_versioned_fun =
      reinterpret_cast<VersionedGetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint_versioned));

  cudaDriverEntryPointQueryResult driver_result;
  void *entry_point = nullptr;
  if (driver_entrypoint_versioned_fun != nullptr) {
    // Found versioned entrypoint function
    NVTE_CHECK_CUDA(driver_entrypoint_versioned_fun(symbol, &entry_point, cuda_version,
                                                    cudaEnableDefault, &driver_result));
  } else {
    NVTE_CHECK(driver_entrypoint_fun != nullptr, "Error finding the CUDA Runtime-Driver interop.");
    // Versioned entrypoint function not found
    NVTE_CHECK_CUDA(driver_entrypoint_fun(symbol, &entry_point, cudaEnableDefault, &driver_result));
  }
  NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
             "Could not find CUDA driver entry point for ", symbol);
  return entry_point;
}
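The lookup in `get_symbol` can be reproduced outside TE to check whether it also fails without TE involved. A minimal sketch with `ctypes` that mirrors the two call paths above, assuming a CUDA 12.x `libcudart.so.12` on the loader path and the header values `cudaEnableDefault == 0` and `cudaDriverEntryPointSuccess == 0`. `lookup_driver_symbol` and its default `cuda_version=12040` are hypothetical (TE passes its own version), and the library loaded here may be a different copy from the one PyTorch/TE load in the training process.

```python
import ctypes

CUDA_ENABLE_DEFAULT = 0  # cudaEnableDefault
ENTRY_POINT_SUCCESS = 0  # cudaDriverEntryPointSuccess


def lookup_driver_symbol(symbol, cuda_version=12040, cudart_name="libcudart.so.12"):
    """Mimic get_symbol: return (cudaError_t, query result, entry-point address)."""
    cudart = ctypes.CDLL(cudart_name)  # raises OSError if the library is not found
    entry_point = ctypes.c_void_p()
    query_result = ctypes.c_int(-1)
    if hasattr(cudart, "cudaGetDriverEntryPointByVersion"):
        # Versioned lookup (runtimes that export it)
        fn = cudart.cudaGetDriverEntryPointByVersion
        fn.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_void_p), ctypes.c_uint,
                       ctypes.c_ulonglong, ctypes.POINTER(ctypes.c_int)]
        fn.restype = ctypes.c_int
        err = fn(symbol.encode(), ctypes.byref(entry_point), cuda_version,
                 CUDA_ENABLE_DEFAULT, ctypes.byref(query_result))
    else:
        # Unversioned fallback, as TE does when the versioned function is missing
        fn = cudart.cudaGetDriverEntryPoint
        fn.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_void_p),
                       ctypes.c_ulonglong, ctypes.POINTER(ctypes.c_int)]
        fn.restype = ctypes.c_int
        err = fn(symbol.encode(), ctypes.byref(entry_point),
                 CUDA_ENABLE_DEFAULT, ctypes.byref(query_result))
    return err, query_result.value, entry_point.value


if __name__ == "__main__":
    try:
        err, result, addr = lookup_driver_symbol("cuCtxGetCurrent")
    except OSError as exc:
        print(f"could not load libcudart: {exc}")
    else:
        ok = err == 0 and result == ENTRY_POINT_SUCCESS
        print(f"cudaError={err} queryResult={result} address={addr}: "
              f"{'found' if ok else 'NOT found'}")
```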


@timmoon10 Thanks for the quick reply. nvidia-smi shows driver version 560.35.03 and CUDA version 12.6 (it's an H800 GPU, sorry for the mistake above), which I think is compatible with CUDA 12.4. However, my environment is built from a Docker image with CUDA 12.1 (/usr/local/cuda-12.1), and I installed CUDA 12.4 manually (CUDA_HOME has been set accordingly). Should I also install the CUDA 12.4 driver (from the CUDA .run file)?
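With two toolkits on the machine (12.1 from the Docker image and a manually installed 12.4), it can help to confirm which CUDA libraries the training process actually maps, since `CUDA_HOME` is mainly consulted at build time while the dynamic loader decides what is loaded at run time. A minimal Linux-only sketch that reads `/proc/self/maps`; for a real check it would run inside the training process after `import torch` and TE. `loaded_cuda_libs` and its default prefixes are hypothetical.

```python
import os


def loaded_cuda_libs(maps_text, prefixes=("libcudart", "libcuda.so", "libcudnn")):
    """Return sorted unique paths of mapped files whose name starts with a prefix."""
    paths = set()
    for line in maps_text.splitlines():
        # Fields: address perms offset dev inode [path]; the path may contain spaces
        fields = line.split(maxsplit=5)
        if len(fields) < 6:
            continue  # anonymous mapping, no file path
        path = fields[5].strip()
        if os.path.basename(path).startswith(prefixes):
            paths.add(path)
    return sorted(paths)


if __name__ == "__main__":
    # In a real check, import torch / transformer_engine first so their libraries are mapped
    if os.path.exists("/proc/self/maps"):
        with open("/proc/self/maps") as f:
            for path in loaded_cuda_libs(f.read()):
                print(path)
```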

Salieri0515, Oct 13 '25 03:10

The CUDA driver and runtime are both backward compatible, so your config looks reasonable to me. We're currently using CUDA 13.0 in most of our use-cases, but CUDA 12.4 should be supported (see build requirements).
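The backward-compatibility rule can be made concrete with the integer encoding the CUDA version APIs use (`1000 * major + 10 * minor`, so 12.4 is `12040` and 12.6 is `12060`). A minimal sketch of the conservative check (driver CUDA version at least the runtime version); it deliberately ignores CUDA minor-version compatibility, under which a 12.x runtime can also run, with some restrictions, on an older 12.x driver. The helper names are hypothetical.

```python
def encode_cuda_version(major, minor):
    """Encode as the CUDA APIs do: 1000 * major + 10 * minor (12.4 -> 12040)."""
    return 1000 * major + 10 * minor


def decode_cuda_version(version):
    """Inverse of encode_cuda_version (12060 -> (12, 6))."""
    return version // 1000, (version % 1000) // 10


def driver_supports_runtime(driver_version, runtime_version):
    """Conservative check: the driver's CUDA version must be >= the runtime's."""
    return driver_version >= runtime_version


# Values from this thread: nvidia-smi reports CUDA 12.6, PyTorch/TE use 12.4
driver = encode_cuda_version(12, 6)
runtime = encode_cuda_version(12, 4)
print(decode_cuda_version(driver), decode_cuda_version(runtime),
      driver_supports_runtime(driver, runtime))
# prints: (12, 6) (12, 4) True
```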

I'm a little confused about what's causing the bug. I assume plain PyTorch works correctly for you? It uses similar CUDA driver logic when launching kernels, so if PyTorch works, that points to a bug in TE.

timmoon10, Oct 14 '25 00:10


@timmoon10 I checked whether PyTorch works for me:

import torch

print(f"CUDA version: {torch.version.cuda}")
print(f"PyTorch version: {torch.__version__}")

def test_pytorch_cuda():
    # Sanity check: can PyTorch launch CUDA kernels on this machine?
    if not torch.cuda.is_available():
        print("CUDA not available")
        return False

    try:
        # Basic float32 matmul
        a = torch.randn(100, 100, device='cuda')
        b = torch.randn(100, 100, device='cuda')
        c = a @ b

        # Matmul in each floating-point dtype (bf16 only if supported)
        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                continue
            a = torch.randn(10, 10, dtype=dtype, device='cuda')
            b = torch.randn(10, 10, dtype=dtype, device='cuda')
            c = a @ b
            print(c)

        # Repeated larger matmuls
        for _ in range(10):
            x = torch.randn(1000, 1000, device='cuda')
            y = torch.randn(1000, 1000, device='cuda')
            z = torch.mm(x, y)

        # Printing copies the result to the host, which waits for the kernels to finish
        print(z)
        return True

    except Exception as e:
        print(f"PyTorch CUDA test failed: {e}")
        return False

test_pytorch_cuda()

It runs correctly, printing "CUDA version: 12.4" and "PyTorch version: 2.6.0+cu124" followed by the tensor values. However, the bug still occurs when running https://github.com/NVIDIA/Megatron-LM/examples/llama/train_llama3_8b_fp8.sh. Here is a more complete log:

[rank0]:   File "/root/develop/Megatron-LM/megatron/core/extensions/transformer_engine.py", line 387, in __init__
[rank0]:     super().__init__(
[rank0]:   File "/data/users/jiangzm/envs/megatronlm/lib/python3.12/site-packages/transformer_engine/pytorch/module/linear.py", line 1203, in __init__
[rank0]:     self.reset_parameters(defer_init=device == "meta")
[rank0]:   File "/data/users/jiangzm/envs/megatronlm/lib/python3.12/site-packages/transformer_engine/pytorch/module/linear.py", line 1225, in reset_parameters
[rank0]:     super().reset_parameters(defer_init=defer_init)
[rank0]:   File "/data/users/jiangzm/envs/megatronlm/lib/python3.12/site-packages/transformer_engine/pytorch/module/base.py", line 1203, in reset_parameters
[rank0]:     param = quantizer(param)
[rank0]:             ^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/jiangzm/envs/megatronlm/lib/python3.12/site-packages/transformer_engine/pytorch/tensor/quantized_tensor.py", line 214, in __call__
[rank0]:     return self.quantize(tensor)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/jiangzm/envs/megatronlm/lib/python3.12/site-packages/transformer_engine/pytorch/tensor/quantized_tensor.py", line 202, in quantize
[rank0]:     return _QuantizeFunc.apply(tensor, self)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/jiangzm/envs/megatronlm/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/jiangzm/envs/megatronlm/lib/python3.12/site-packages/transformer_engine/pytorch/tensor/quantized_tensor.py", line 268, in forward
[rank0]:     return tex.quantize(tensor, quantizer)
[rank0]:            ^^^^^^^^^^^^
[rank0]: RuntimeError: /TransformerEngine/transformer_engine/common/util/cuda_driver.cpp:42 in function get_symbol: Assertion failed: driver_result == cudaDriverEntryPointSuccess. Could not find CUDA driver entry point for cuCtxGetCurrent when instantiating TERowParallelLinear when instantiating SelfAttention when instantiating TransformerLayer

I tried TE 2.5.0, 2.6.0post1, and 2.8.0 with CUDA 12.1 and 12.4 (with the corresponding torch versions), and the bug occurs in every combination. I can't use CUDA 13.0 at the moment. Later I will try a new environment with CUDA 12.4 installed normally (not manually) and CUDA 12.6 installed manually.

Salieri0515, Oct 14 '25 01:10