jax_verify
Having Trouble Installing Dependencies
I created a fresh conda environment and pip-installed requirements.txt, only to realize that this had not installed a GPU-compatible jax. After a little searching, I installed jax with conda install jax cuda-nvcc -c conda-forge -c nvidia, as recommended by this page. Then I got the following warnings telling me that I was not using the GPU, which I'd like to use:
(jax_verify) chelseas@server:~/jax_verify$ python3 examples/run_boundprop.py --boundprop_method=backward_crown_bound_propagation
I1228 22:37:51.228809 140315511407680 xla_bridge.py:170] Remote TPU is not linked into jax; skipping remote TPU.
I1228 22:37:51.228915 140315511407680 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I1228 22:37:51.228991 140315511407680 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1228 22:37:51.229050 140315511407680 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1228 22:37:51.229235 140315511407680 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I1228 22:37:51.229310 140315511407680 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
W1228 22:37:51.229365 140315511407680 xla_bridge.py:362] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I initially opened an issue in the jax repo, but I thought someone maintaining this repo might also be able to help.
This seems to be a jax issue more than a jax_verify one.
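One way to confirm that the problem lies in JAX rather than jax_verify is to query JAX directly, without running any jax_verify code. The sketch below is a hypothetical helper (jax_backend_report is not part of either project). It uses only the standard library plus jax.default_backend() and jax.devices(). It reports the installed jax and jaxlib versions and the backend JAX actually picked. A missing GpuAllocatorConfig attribute, as in the log above, commonly indicates a CPU-only jaxlib build or a jax/jaxlib mismatch. Treat that reading as an assumption to verify against your own output.

```python
import importlib.metadata


def package_version(name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return None


def jax_backend_report():
    """Collect jax/jaxlib versions and the backend JAX selects (hypothetical helper)."""
    report = {
        "jax": package_version("jax"),
        "jaxlib": package_version("jaxlib"),
        "backend": None,
        "devices": [],
        "import_error": None,
    }
    try:
        # Importing jax can fail if jaxlib is missing or incompatible.
        import jax
    except (ImportError, RuntimeError) as exc:
        report["import_error"] = str(exc)
        return report
    # "cpu" here means JAX fell back to CPU, as in the warning above.
    report["backend"] = jax.default_backend()
    report["devices"] = [str(d) for d in jax.devices()]
    return report


if __name__ == "__main__":
    for key, value in jax_backend_report().items():
        print(f"{key}: {value}")
```

If backend prints cpu even though a GPU is present, the fix belongs on the JAX installation side, not in jax_verify.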
Some things worth looking at:
- Is CUDA installed correctly on your machine? Do you have the same problem with other frameworks (like PyTorch or TensorFlow)? Can you find the /usr/local/cuda-11.4 directory on your machine?
- From your other issue, it seems that you are running CUDA 11.4. Have you made sure that you installed the correct version of jaxlib?
- Did you check that you activated the right conda environment?
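The checklist above can be partly automated. Below is a minimal sketch, assuming a Linux machine and that the toolkit lives at /usr/local/cuda-11.4 as in the maintainer's question. The function name environment_checks is hypothetical. The sketch checks for the CUDA directory and looks for nvcc and nvidia-smi on PATH. It also reads the CONDA_DEFAULT_ENV and CONDA_PREFIX variables that conda sets on activation, and prints sys.executable to show which interpreter python3 actually resolves to.

```python
import os
import shutil
import sys
from pathlib import Path


def environment_checks(cuda_dir="/usr/local/cuda-11.4"):
    """Answer the checklist questions quickly (hypothetical helper)."""
    return {
        # Does the CUDA toolkit directory asked about exist?
        "cuda_dir_exists": Path(cuda_dir).is_dir(),
        # Is nvcc on PATH (e.g. from the cuda-nvcc conda package)?
        "nvcc_path": shutil.which("nvcc"),
        # Is the NVIDIA driver utility on PATH?
        "nvidia_smi_path": shutil.which("nvidia-smi"),
        # Which conda environment is active, if any?
        "conda_env": os.environ.get("CONDA_DEFAULT_ENV"),
        "conda_prefix": os.environ.get("CONDA_PREFIX"),
        # Which Python is running? It should live inside the conda prefix.
        "python_executable": sys.executable,
    }


if __name__ == "__main__":
    for key, value in environment_checks().items():
        print(f"{key}: {value}")
```

If python_executable does not sit under conda_prefix, the python3 on PATH is likely not the environment's interpreter, and packages installed into the environment will not be the ones imported.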