
Having Trouble Installing Dependencies

Open: chelseas opened this issue 3 years ago • 1 comment

I created a fresh conda environment and ran pip install ... on the requirements.txt, only to realize that this had not installed a GPU-compatible jax. After a little searching, I installed jax using conda install jax cuda-nvcc -c conda-forge -c nvidia, as recommended by this page, but then I got the following warnings telling me that jax was not using the GPU, which I'd like to use:

(jax_verify) chelseas@server:~/jax_verify$ python3 examples/run_boundprop.py --boundprop_method=backward_crown_bound_propagation
I1228 22:37:51.228809 140315511407680 xla_bridge.py:170] Remote TPU is not linked into jax; skipping remote TPU.
I1228 22:37:51.228915 140315511407680 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I1228 22:37:51.228991 140315511407680 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1228 22:37:51.229050 140315511407680 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1228 22:37:51.229235 140315511407680 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I1228 22:37:51.229310 140315511407680 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
W1228 22:37:51.229365 140315511407680 xla_bridge.py:362] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

I opened an issue in the jax repo initially but I was thinking that someone maintaining this repo might also be able to help.

chelseas, Dec 29 '22 06:12

This seems to be a jax issue more than a jax_verify one.

Some things worth looking at:

  • Is CUDA installed correctly on your machine? Do you have the same problem with other frameworks (like PyTorch or TensorFlow)? Can you find the /usr/local/cuda-11.4 directory on your machine?
  • From your other issue, it seems that you are running CUDA 11.4. Have you made sure that you installed a jaxlib version built for that CUDA version?
  • Did you check that you activated the right conda environment?

bunelr, Jan 03 '23 09:01
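
The checks listed in the reply above could be gathered in one go with a small script along these lines. This is only a sketch, not something from the jax_verify repo: the helper names package_version and collect_diagnostics are made up for illustration, the default CUDA path is simply the one mentioned in the reply, and CONDA_DEFAULT_ENV is assumed to be set by conda activate. It relies on jax.default_backend() and jax.devices() to report what jax actually picked up.

```python
import importlib.util
import os
import sys
from importlib import metadata


def package_version(name):
    """Return the installed version of a package, or None if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def collect_diagnostics(cuda_dir="/usr/local/cuda-11.4"):
    """Collect the facts asked about above (hypothetical helper)."""
    report = {
        "python": sys.executable,                           # interpreter in use
        "conda_env": os.environ.get("CONDA_DEFAULT_ENV"),   # active conda env, if any
        "cuda_dir_exists": os.path.isdir(cuda_dir),         # is the CUDA directory there?
        "jax": package_version("jax"),
        "jaxlib": package_version("jaxlib"),
        "backend": None,
        "devices": [],
    }
    # Only query jax if it is importable in this environment.
    if importlib.util.find_spec("jax") is not None:
        try:
            import jax
            report["backend"] = jax.default_backend()       # e.g. "cpu" or "gpu"
            report["devices"] = [str(d) for d in jax.devices()]
        except Exception as err:  # a broken install can fail at import or init
            report["backend"] = f"error: {err}"
    return report


if __name__ == "__main__":
    for key, value in collect_diagnostics().items():
        print(f"{key}: {value}")
```

If backend comes back as cpu while the CUDA directory exists, the installed jaxlib is the most likely suspect, which would be consistent with the 'GpuAllocatorConfig' messages in the log.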