openpi
LIBERO inference
Hi team, I'm trying to run inference and evaluation on the LIBERO dataset. However, while loading the pi-0-fast-libero checkpoint, I get the following cuda_dnn error:
username@my_host_machine$ uv run scripts/serve_policy.py --env LIBERO
Built draccus @ git+https://github.com/dlwh/draccus.git@9b690730ca108930519f48cc5dead72a72fd27cb
Uninstalled 1 package in 6.23s
Installed 1 package in 9.61s
INFO:root:Loading model...
E0330 18:23:43.985579 78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
E0330 18:23:43.985789 78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
INFO:2025-03-30 18:23:43,986:jax._src.xla_bridge:945: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-03-30 18:23:43,987:jax._src.xla_bridge:945: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:absl:orbax-checkpoint version: 0.11.1
INFO:absl:Created BasePyTreeCheckpointHandler: pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=None
INFO:absl:Restoring checkpoint from /scratch_net/biwidl313_second/chenqing/openpi/openpi-assets/checkpoints/pi0_fast_libero/params.
INFO:absl:[thread=MainThread] Failed to get flag value for EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.
INFO:absl:[process=0] /jax/checkpoint/read/bytes_per_sec: 327.5 MiB/s (total bytes: 5.4 GiB) (time elapsed: 17 seconds) (per-host)
INFO:absl:Finished restoring checkpoint from /scratch_net/biwidl313_second/chenqing/openpi/openpi-assets/checkpoints/pi0_fast_libero/params.
INFO:absl:[process=0][thread=MainThread] Skipping global process sync, barrier name: Checkpointer:restore
E0330 18:24:01.090760 78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
E0330 18:24:01.091413 78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
Traceback (most recent call last):
File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 123, in <module>
main(tyro.cli(Args))
File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 101, in main
policy = create_policy(args)
^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 97, in create_policy
return create_default_policy(args.env, default_prompt=args.default_prompt)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 83, in create_default_policy
return _policy_config.create_trained_policy(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/src/openpi/policies/policy_config.py", line 56, in create_trained_policy
model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/src/openpi/models/model.py", line 228, in load
model = nnx.eval_shape(self.create, jax.random.key(0))
^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 218, in key
return _key('key', seed, impl)
^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 198, in _key
return prng.random_seed(seed, impl=impl)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 529, in random_seed
seeds_arr = jnp.asarray(np.int64(seeds))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5820, in asarray
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5653, in array
out_array: Array = lax_internal._convert_element_type(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 612, in _convert_element_type
return convert_element_type_p.bind(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 463, in bind
return self.bind_with_trace(prev_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 3254, in _convert_element_type_bind_with_trace
operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 468, in bind_with_trace
return trace.process_primitive(self, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 954, in process_primitive
return primitive.impl(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/dispatch.py", line 89, in apply_primitive
outs = fun(*args)
^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 340, in cache_miss
pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 198, in _python_pjit_helper
out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1669, in _pjit_call_impl_python
).compile()
^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2419, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2922, in from_hlo
xla_executable = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2723, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 464, in compile_or_get_cached
return _compile_and_write_cache(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 665, in _compile_and_write_cache
executable = backend_compile(
^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 321, in backend_compile
raise e
File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 315, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
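One way to tell whether this is an openpi problem or a JAX/CUDA setup problem is to run the same first operation from the traceback (`jax.random.key(0)`) in isolation. The sketch below is a hypothetical `probe_jax` helper, not part of openpi. It relies only on public JAX calls and on the `JAX_PLATFORMS` environment variable, which JAX reads at import time to restrict which backends it initializes.

```python
import os
from typing import Optional


def probe_jax(platform: Optional[str] = None) -> dict:
    """Run a trivial JAX computation and report the outcome.

    Hypothetical diagnostic helper. `platform` (e.g. "cpu" or "cuda") is written
    to JAX_PLATFORMS, which only takes effect if jax has not been imported yet
    in this process.
    """
    if platform is not None:
        os.environ["JAX_PLATFORMS"] = platform
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as exc:
        return {"ok": False, "error": f"jax is not importable: {exc}"}
    try:
        # Same call that fails in openpi's model.py; on a GPU backend it
        # triggers an XLA compilation, which is where cuDNN init failed above.
        jax.random.key(0)
        total = jnp.ones(3).sum().block_until_ready()
        return {
            "ok": True,
            "backend": jax.default_backend(),
            "devices": [str(d) for d in jax.devices()],
            "sum": float(total),
        }
    except Exception as exc:  # e.g. jaxlib's XlaRuntimeError
        return {"ok": False, "error": f"{type(exc).__name__}: {exc}"}


if __name__ == "__main__":
    print(probe_jax())
```

Running it with `uv run python probe_jax.py` from the openpi checkout uses the same virtualenv. On this machine it would presumably fail with the same `DNN library initialization failed` message, while `probe_jax("cpu")` in a fresh process should succeed; that combination would point at the GPU stack rather than at openpi.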
I installed the dependencies correctly and didn't run into any problems. I'm using CUDA 11.4 with driver version 470.82.01 on a cluster.
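`cudaErrorInsufficientDriver` means the CUDA runtime loaded by the process is newer than what the installed driver supports. NVIDIA's CUDA release notes list the minimum Linux driver per major release (450.80.02 for CUDA 11.x, 525.60.13 for CUDA 12.x), so an R470 driver can serve CUDA 11.x but not 12.x. A minimal stdlib sketch of that check follows; the helper names are hypothetical, and whether openpi's environment actually loads a CUDA 12 runtime is an assumption to verify.

```python
# Minimum Linux driver versions for CUDA minor-version compatibility,
# as listed in NVIDIA's CUDA release notes.
MIN_LINUX_DRIVER = {
    11: (450, 80, 2),
    12: (525, 60, 13),
}


def parse_driver(version: str) -> tuple[int, ...]:
    """Turn a driver string like '470.256.02' into a comparable tuple."""
    return tuple(int(part) for part in version.split("."))


def driver_supports(driver_version: str, cuda_major: int) -> bool:
    """True if the driver meets the minimum for the given CUDA major version."""
    if cuda_major not in MIN_LINUX_DRIVER:
        raise ValueError(f"no minimum-driver entry for CUDA {cuda_major}")
    return parse_driver(driver_version) >= MIN_LINUX_DRIVER[cuda_major]


if __name__ == "__main__":
    # Both driver versions that appear in this report.
    for drv in ("470.82.01", "470.256.02"):
        for major in (11, 12):
            print(drv, f"CUDA {major}.x:", driver_supports(drv, major))
```

Tuples compare element by element, so `(470, 256, 2) >= (525, 60, 13)` is false because 470 < 525, regardless of the later fields.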
Current setup:
nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.256.02   Driver Version: 470.256.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA TITAN X ...  On   | 00000000:04:00.0 Off |                  N/A |
| 23%   27C    P8     8W / 250W |      1MiB / 12196MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
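The `CUDA Version: 11.4` field in the nvidia-smi banner is not an installed toolkit; it is the newest CUDA runtime version the installed driver supports. A rough sketch for extracting it follows. The regex assumes the standard banner layout, and the helper names are hypothetical.

```python
import re
import shutil
import subprocess

# Assumes the standard banner layout:
# "NVIDIA-SMI <ver>   Driver Version: <ver>   CUDA Version: <major>.<minor>"
_BANNER = re.compile(
    r"NVIDIA-SMI\s+[\d.]+\s+Driver Version:\s+(?P<driver>[\d.]+)"
    r"\s+CUDA Version:\s+(?P<major>\d+)\.(?P<minor>\d+)"
)


def parse_banner(text: str) -> tuple[str, tuple[int, int]]:
    """Return (driver_version, newest_supported_cuda) from nvidia-smi output."""
    match = _BANNER.search(text)
    if match is None:
        raise ValueError("nvidia-smi banner not found")
    return match["driver"], (int(match["major"]), int(match["minor"]))


def major_supported(runtime_major: int, newest_cuda: tuple[int, int]) -> bool:
    """False when the runtime comes from a newer major release than the driver reports.

    Within one major release, NVIDIA's minor-version compatibility usually
    applies. NVIDIA's forward-compatibility packages are ignored here.
    """
    return runtime_major <= newest_cuda[0]


if __name__ == "__main__":
    if shutil.which("nvidia-smi"):
        out = subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
        driver, newest = parse_banner(out)
        print(f"driver {driver}, newest CUDA {newest}, CUDA 12 ok: {major_supported(12, newest)}")
```

For scripting, `nvidia-smi --query-gpu=driver_version --format=csv,noheader` prints just the driver version, which avoids parsing the table.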
nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_Oct_11_21:27:02_PDT_2021
Cuda compilation tools, release 11.4, V11.4.152
Build cuda_11.4.r11.4/compiler.30521435_0
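`nvcc -V` reports the cluster's system toolkit, which a pip-installed JAX does not necessarily use. JAX's CUDA wheels typically pull NVIDIA's libraries in as separate PyPI packages whose names end in `-cuXX` (for example `nvidia-cudnn-cu12`), plus, in recent releases, a `jax-cudaXX-plugin` package. The following hedged sketch lists which CUDA major versions the active environment ships, assuming those naming conventions. The helper names are hypothetical.

```python
import re
from importlib import metadata

# Assumed naming conventions: nvidia-<lib>-cuXX (e.g. nvidia-cudnn-cu12)
# and jax-cudaXX-plugin / jax-cudaXX-pjrt.
_PATTERNS = (
    re.compile(r"^nvidia-[a-z0-9-]+-cu(\d+)$"),
    re.compile(r"^jax-cuda(\d+)-(?:plugin|pjrt)$"),
)


def cuda_majors(package_names) -> set[int]:
    """Collect CUDA major versions implied by distribution names."""
    majors = set()
    for raw in package_names:
        # Normalize like PyPI does: lowercase, runs of '-', '_', '.' -> '-'.
        name = re.sub(r"[-_.]+", "-", raw).lower()
        for pattern in _PATTERNS:
            match = pattern.match(name)
            if match:
                majors.add(int(match.group(1)))
    return majors


def installed_cuda_majors() -> set[int]:
    """Inspect the Python environment this interpreter runs in."""
    names = (dist.metadata["Name"] or "" for dist in metadata.distributions())
    return cuda_majors(names)


if __name__ == "__main__":
    print(installed_cuda_majors())
```

Run it with `uv run python` from the openpi checkout. If it prints `{12}`, the environment ships a CUDA 12 runtime, which would be consistent with the driver error given the R470 driver shown by nvidia-smi.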
Have you tried running with Docker? If it works, JAX is likely picking up the wrong CUDA dependencies.
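The "wrong CUDA dependencies" hypothesis can be partly checked in place before switching to Docker. One plausible cause is an `LD_LIBRARY_PATH`, often set by a cluster `module load` script, that points at a system CUDA or cuDNN install and may be searched before the libraries bundled with pip-installed JAX. This cause is an assumption and has not been confirmed for this cluster. A hypothetical helper to spot such entries:

```python
import os
from typing import List, Optional


def cuda_dirs_on_ld_path(ld_path: Optional[str] = None) -> List[str]:
    """Return LD_LIBRARY_PATH entries that look like CUDA or cuDNN installs.

    Hypothetical helper; it only matches on the substrings "cuda" and "cudnn",
    so it can miss installs in unconventionally named directories.
    """
    if ld_path is None:
        ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    hits = []
    for entry in ld_path.split(":"):  # LD_LIBRARY_PATH is colon-separated
        lowered = entry.lower()
        if entry and ("cuda" in lowered or "cudnn" in lowered):
            hits.append(entry)
    return hits


if __name__ == "__main__":
    for entry in cuda_dirs_on_ld_path():
        print("possible override:", entry)
```

If it reports entries, rerunning with the variable cleared (e.g. `env -u LD_LIBRARY_PATH uv run scripts/serve_policy.py --env LIBERO`) is a cheap experiment. It cannot help if the runtime in the virtualenv needs a newer driver than R470, because `cudaErrorInsufficientDriver` concerns the driver itself.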