ParallelFold

jaxlib.xla_extension.XlaRuntimeError

Open · Xiaojun928 opened this issue · 0 comments

Hi Dr. Zhong,

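Thank you for developing ParaFold, it is a great tool! I have not run into any problems installing ParaFold so far, and the first step, feature generation, completed successfully. However, when I tried to run the second step, structure prediction, I hit a JAX-related error. I would really appreciate any suggestions you could offer.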

My GPU configuration is as follows:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.12              Driver Version: 550.90.12      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H800 PCIe               Off |   00000000:34:00.0 Off |                    0 |
| N/A   52C    P0             89W /  350W |       1MiB /  81559MiB |      3%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

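I installed everything following the "How to install" section of the README, and JAX is version 0.3.25, as specified in the README. I also installed cuda-nvcc as suggested in issue #39, but the same problem persists.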

Here is the error I get:

2024-11-21 11:37:59.402301: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:231] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0
2024-11-21 11:37:59.402323: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:234] Used ptxas at ptxas
2024-11-21 11:37:59.404084: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:628] failed to get PTX kernel "shift_right_logical" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2024-11-21 11:37:59.404116: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: INTERNAL: Could not find the corresponding function
Traceback (most recent call last):
  File "/home/software/ParallelFold/run_alphafold.py", line 491, in <module>
    app.run(main)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/software/ParallelFold/run_alphafold.py", line 464, in main
    predict_structure(
  File "/home/software/ParallelFold/run_alphafold.py", line 239, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/software/ParallelFold/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/random.py", line 132, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 267, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 580, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 712, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 592, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 597, in random_seed_impl_base
    return seed(seeds)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/prng.py", line 832, in threefry_seed
    lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 515, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/core.py", line 712, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 115, in apply_primitive
    return compiled_fun(*args)
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 200, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/conda_envs/parafold/lib/python3.8/site-packages/jax/_src/dispatch.py", line 895, in _execute_compiled
    out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Could not find the corresponding function
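For anyone trying to narrow this down: the traceback shows the failure occurs at the very first JAX operation on the GPU (`jax.random.PRNGKey` → `threefry_seed` → `lax.shift_right_logical`), before any AlphaFold model code actually runs. The minimal sketch below should exercise the same calls outside ParallelFold. It assumes it is run inside the same `parafold` conda environment, and the input values are arbitrary. If the GPU/JAX combination is the problem, I would expect it to fail with the same `XlaRuntimeError`; on a CPU-only setup it should simply print a key.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Show which JAX build is installed and which devices it sees
# (on the machine in this issue the H800 should appear as a GPU device).
print("jax version:", jax.__version__)
print("devices:", jax.devices())

# Same call as alphafold/model/model.py line 167; internally it runs
# threefry_seed, which calls lax.shift_right_logical (see traceback above).
key = jax.random.PRNGKey(0)
print("PRNGKey(0):", key)

# Call the primitive named in the CUDA error directly, isolated from the
# PRNG code. Both operands must share a dtype; 256 >> 4 == 16.
shifted = lax.shift_right_logical(jnp.uint32(256), jnp.uint32(4))
print("shift_right_logical(256, 4):", shifted)
```

Running the same script with `JAX_PLATFORM_NAME=cpu` (or `JAX_PLATFORMS=cpu` on newer JAX releases) should succeed, which would help separate a GPU/CUDA toolchain problem from a broken installation.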

It looks like the H800 is not compatible with JAX 0.3.25. Would upgrading JAX work?

Many thanks!

Xiaojun928, Nov 21 '24 05:11