Warn if a GPU is detected but it is too old to be supported by XLA
Dear jax team,
I'm struggling with installing jax with GPU support. I'm running Ubuntu 18.04 with CUDA 10.0 and CUDNN 7.6.1. The GPU is a Quadro K4000 with 410.48 drivers (manually installed, no conda).
I tried installing jax both from pip and from the repo. For jaxlib, I tried the pip wheels (following the guide in README.md) and also compiling from source with
python3 build/build.py --enable_cuda --cuda_path /usr/local/cuda-10.0/ --cudnn_path /usr/local/cuda-10.0/ --enable_march_native
but nothing worked: jax keeps falling back to CPU. I tried every combination of --user installs and global sudo installs, with reboots in between.
I checked jax.__file__ and jaxlib.__file__ every time to make sure I was using the intended jax / jaxlib install. I'm out of ideas now. Is there a known problem with Quadro cards? Could you point me toward somewhere I haven't looked yet? Thank you very much!
I've noticed that the K4000 GPU is pretty old and only supports CUDA compute capability 3.0. Could this be the problem? I saw no errors while compiling jaxlib. If jax needs a CUDA compute capability above 3.0 (as TensorFlow apparently does), there should be a corresponding message when it falls back to CPU.
Yes, I think your GPU is too old. JAX uses XLA for its GPU support, which has the same GPU requirements as TensorFlow: https://www.tensorflow.org/install/gpu#hardware_requirements
I agree that a helpful warning might be a good thing.
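A rough sketch of what such a warning could look like, assuming the minimum compute capability is 3.5 (the figure TensorFlow's hardware requirements page gave at the time; XLA's actual cutoff may differ). The names `check_gpus`, `parse_compute_capability`, and `MIN_COMPUTE_CAPABILITY` are hypothetical and not part of JAX or jaxlib, and how the (name, compute capability) pairs are obtained from the CUDA driver is left out.

```python
import warnings
from typing import Iterable, List, Tuple

# Assumption: minimum supported CUDA compute capability, taken from
# TensorFlow's published hardware requirements (3.5). Not an XLA constant.
MIN_COMPUTE_CAPABILITY: Tuple[int, int] = (3, 5)


def parse_compute_capability(text: str) -> Tuple[int, int]:
    """Parse a string like "3.0" or "7.5" into a (major, minor) tuple."""
    major, minor = text.strip().split(".")
    return int(major), int(minor)


def check_gpus(
    gpus: Iterable[Tuple[str, str]],
    minimum: Tuple[int, int] = MIN_COMPUTE_CAPABILITY,
) -> List[str]:
    """Warn about each GPU whose compute capability is below `minimum`.

    `gpus` is a sequence of (device name, compute capability string) pairs.
    Returns the names of the GPUs judged too old.
    """
    too_old = []
    for name, cc_text in gpus:
        cc = parse_compute_capability(cc_text)
        # Tuples compare lexicographically, so (3, 0) < (3, 5) < (5, 2).
        if cc < minimum:
            warnings.warn(
                f"GPU {name!r} has CUDA compute capability {cc[0]}.{cc[1]}, "
                f"below the assumed minimum {minimum[0]}.{minimum[1]}; "
                "it will not be used and JAX may fall back to CPU.",
                RuntimeWarning,
            )
            too_old.append(name)
    return too_old
```

For the card in this issue, `check_gpus([("Quadro K4000", "3.0")])` would emit one `RuntimeWarning` and return `["Quadro K4000"]`. Comparing `(major, minor)` tuples rather than floats avoids surprises with versions such as 3.10 versus 3.5.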
Thanks for the info @hawkinsp. Not great for me right now, but it might save others some time in the future :+1: