openpi

LIBERO inference

Open · Jerryisqx opened this issue 10 months ago · 2 comments

Hi team, I'm trying to run inference and evaluation on the LIBERO dataset. However, while loading the pi-0-fast-libero checkpoint, I hit a cuda_dnn error:

username@my_host_machine$ uv run scripts/serve_policy.py --env LIBERO
      Built draccus @ git+https://github.com/dlwh/draccus.git@9b690730ca108930519f48cc5dead72a72fd27cb
Uninstalled 1 package in 6.23s
Installed 1 package in 9.61s
INFO:root:Loading model...
E0330 18:23:43.985579   78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
E0330 18:23:43.985789   78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
INFO:2025-03-30 18:23:43,986:jax._src.xla_bridge:945: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-03-30 18:23:43,987:jax._src.xla_bridge:945: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:absl:orbax-checkpoint version: 0.11.1
INFO:absl:Created BasePyTreeCheckpointHandler: pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=None
INFO:absl:Restoring checkpoint from /scratch_net/biwidl313_second/chenqing/openpi/openpi-assets/checkpoints/pi0_fast_libero/params.
INFO:absl:[thread=MainThread] Failed to get flag value for EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.
INFO:absl:[process=0] /jax/checkpoint/read/bytes_per_sec: 327.5 MiB/s (total bytes: 5.4 GiB) (time elapsed: 17 seconds) (per-host)
INFO:absl:Finished restoring checkpoint from /scratch_net/biwidl313_second/chenqing/openpi/openpi-assets/checkpoints/pi0_fast_libero/params.
INFO:absl:[process=0][thread=MainThread] Skipping global process sync, barrier name: Checkpointer:restore
E0330 18:24:01.090760   78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
E0330 18:24:01.091413   78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
Traceback (most recent call last):
  File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 123, in <module>
    main(tyro.cli(Args))
  File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 101, in main
    policy = create_policy(args)
             ^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 97, in create_policy
    return create_default_policy(args.env, default_prompt=args.default_prompt)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 83, in create_default_policy
    return _policy_config.create_trained_policy(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/src/openpi/policies/policy_config.py", line 56, in create_trained_policy
    model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/src/openpi/models/model.py", line 228, in load
    model = nnx.eval_shape(self.create, jax.random.key(0))
                                        ^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 218, in key
    return _key('key', seed, impl)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 198, in _key
    return prng.random_seed(seed, impl=impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 529, in random_seed
    seeds_arr = jnp.asarray(np.int64(seeds))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5820, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5653, in array
    out_array: Array = lax_internal._convert_element_type(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 612, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 463, in bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 3254, in _convert_element_type_bind_with_trace
    operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 954, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/dispatch.py", line 89, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 340, in cache_miss
    pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 198, in _python_pjit_helper
    out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1669, in _pjit_call_impl_python
    ).compile()
      ^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2419, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2922, in from_hlo
    xla_executable = _cached_compilation(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2723, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 464, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 665, in _compile_and_write_cache
    executable = backend_compile(
                 ^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 321, in backend_compile
    raise e
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 315, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
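The final XlaRuntimeError only says to look at the errors above. The actual cause is in the earlier cuda_dnn.cc lines: error 35, cudaErrorInsufficientDriver, repeated four times. For long logs, a small helper can pull out the distinct errors. This is a hypothetical sketch: `extract_cuda_dnn_errors` is an illustrative name, not part of openpi or JAX, and the regex assumes the glog-style line format shown in this log.

```python
import re
from typing import List, NamedTuple

# Assumed line format, copied from the log above:
# E0330 18:23:43.985579   78452 cuda_dnn.cc:502] There was an error before creating
# cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient ...
_CUDA_DNN_ERROR = re.compile(
    r"^E\d{4} .*?cuda_dnn\.cc:\d+\].*?\((?P<code>\d+)\): (?P<name>\w+) : (?P<message>.+)$"
)


class CudaDnnError(NamedTuple):
    code: int      # CUDA runtime error code, e.g. 35
    name: str      # CUDA runtime error name, e.g. cudaErrorInsufficientDriver
    message: str   # human-readable message that follows the name


def extract_cuda_dnn_errors(log_text: str) -> List[CudaDnnError]:
    """Return the distinct cuda_dnn.cc errors in a log, in first-seen order."""
    seen = {}
    for line in log_text.splitlines():
        match = _CUDA_DNN_ERROR.match(line.strip())
        if match:
            err = CudaDnnError(int(match["code"]), match["name"], match["message"].strip())
            # Keep only the first occurrence of each (code, name) pair.
            seen.setdefault((err.code, err.name), err)
    return list(seen.values())
```

Applied to the log above, it would collapse the four E-level lines into a single entry for error 35. Deduplicating on the (code, name) pair keeps the summary short, since glog repeats the same error for each attempt to create a cuDNN handle.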

I installed the dependencies correctly and didn't run into any problems. I'm using CUDA 11.4 with driver version 470.82.01 on a cluster.

Jerryisqx · Mar 29 '25 18:03
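The error text itself ("CUDA driver version is insufficient for CUDA runtime version") points at the driver. Suppose the virtual environment installed JAX's CUDA 12 wheels (for example `jax[cuda12]`); this is an assumption, since the issue does not show the lockfile. In that case, the runtime needs an NVIDIA Linux driver of at least 525.60.13, while the 470 series supports CUDA only up to 11.4. A minimal check follows, with that threshold hard-coded as an assumption. The helper names are hypothetical.

```python
from typing import Tuple

# Assumption: CUDA 12.x runtimes require an NVIDIA Linux driver >= 525.60.13.
MIN_DRIVER_FOR_CUDA12: Tuple[int, ...] = (525, 60, 13)


def parse_driver_version(version: str) -> Tuple[int, ...]:
    """Turn a driver string such as '470.82.01' into a comparable tuple (470, 82, 1)."""
    return tuple(int(part) for part in version.strip().split("."))


def driver_supports_cuda12(version: str) -> bool:
    """True if the driver meets the assumed minimum for a CUDA 12 runtime."""
    # Tuples compare element by element, so (470, 256, 2) < (525, 60, 13).
    return parse_driver_version(version) >= MIN_DRIVER_FOR_CUDA12


if __name__ == "__main__":
    # The two driver versions reported in this thread.
    for reported in ("470.82.01", "470.256.02"):
        print(reported, "supports CUDA 12:", driver_supports_cuda12(reported))
```

Under this assumption, both driver versions reported in this thread (470.82.01 and 470.256.02) print False.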

Here is my current nvidia-smi and nvcc -V output:

nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.256.02   Driver Version: 470.256.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA TITAN X ...  On   | 00000000:04:00.0 Off |                  N/A |
| 23%   27C    P8     8W / 250W |      1MiB / 12196MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_Oct_11_21:27:02_PDT_2021
Cuda compilation tools, release 11.4, V11.4.152
Build cuda_11.4.r11.4/compiler.30521435_0

Jerryisqx · Mar 30 '25 16:03
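The two outputs above answer different questions. The "CUDA Version: 11.4" in the nvidia-smi banner is the newest CUDA version the installed driver supports. `nvcc -V` reports the locally installed toolkit. Pip-installed JAX GPU wheels generally bring their own CUDA libraries and rely on the system only for the driver. The sketch below reads the banner and checks whether a given CUDA major version can run. The helper names are hypothetical, and the banner format is assumed to match the output above.

```python
import re
import subprocess
from typing import Optional, Tuple

# Assumed banner format, copied from the output above:
# | NVIDIA-SMI 470.256.02   Driver Version: 470.256.02   CUDA Version: 11.4     |
_BANNER = re.compile(
    r"Driver Version:\s*(?P<driver>[\d.]+)\s+CUDA Version:\s*(?P<cuda>\d+\.\d+)"
)


def parse_nvidia_smi_banner(text: str) -> Optional[Tuple[str, str]]:
    """Return (driver_version, max_supported_cuda) from nvidia-smi output, or None."""
    match = _BANNER.search(text)
    if match is None:
        return None
    return match["driver"], match["cuda"]


def driver_allows_cuda_major(text: str, required_major: int) -> bool:
    """True if the driver's maximum supported CUDA major version is >= required_major."""
    parsed = parse_nvidia_smi_banner(text)
    if parsed is None:
        return False
    _, max_cuda = parsed
    return int(max_cuda.split(".")[0]) >= required_major


if __name__ == "__main__":
    try:
        out = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False).stdout
    except FileNotFoundError:
        out = ""  # nvidia-smi is not on PATH (e.g. a CPU-only login node)
    print("driver / max CUDA:", parse_nvidia_smi_banner(out))
    print("CUDA 12 runtime possible:", driver_allows_cuda_major(out, 12))
```

The script parses the banner instead of the `nvcc` line on purpose. The toolkit version can be newer or older than what the driver allows, and only the driver limit matters for pip-installed wheels.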

Have you tried running with Docker? If it works, JAX is likely picking up the wrong CUDA dependencies.

uzhilinsky · Apr 10 '25 01:04
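Outside Docker, one common way for JAX to pick up the wrong CUDA dependencies is an LD_LIBRARY_PATH that points at a system CUDA install (here, CUDA 11.4). Such a path can interfere with the libraries bundled with the pip wheels. The sketch below looks for such entries and asks JAX to describe its environment. It assumes a Linux host. The substring heuristic and the helper name are illustrative only, while `jax.print_environment_info()` and `jax.default_backend()` are existing JAX functions.

```python
import os
from typing import List, Optional


def cuda_dirs_on_library_path(ld_library_path: Optional[str]) -> List[str]:
    """Return LD_LIBRARY_PATH entries that look like system CUDA or cuDNN installs.

    Heuristic only: an entry is flagged if its path contains 'cuda' or 'cudnn'.
    """
    if not ld_library_path:
        return []
    # LD_LIBRARY_PATH is colon-separated on Linux; skip empty entries from '::'.
    entries = [p for p in ld_library_path.split(":") if p]
    return [p for p in entries if "cuda" in p.lower() or "cudnn" in p.lower()]


if __name__ == "__main__":
    for entry in cuda_dirs_on_library_path(os.environ.get("LD_LIBRARY_PATH")):
        print("system CUDA path on LD_LIBRARY_PATH:", entry)
    try:
        import jax
    except ImportError:
        print("jax is not importable from this interpreter")
    else:
        # Prints jax/jaxlib versions, Python and platform info, and nvidia-smi output.
        jax.print_environment_info()
        print("default backend:", jax.default_backend())
```

Run it with the same interpreter that `uv run` uses (for example `uv run python check_env.py`, where the file name is hypothetical). This way the report reflects the project's virtual environment rather than a system Python.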