Skye Wanderman-Milne
Hi @aizzaac , let's discuss in #4487 instead of this thread.
Can you try setting the env var `XLA_PYTHON_CLIENT_PREALLOCATE=false`? See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for more information and similar options. I think the issue might be that the first jax process is pre-allocating most...
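For anyone landing here later, a minimal sketch of setting that flag from Python rather than the shell. The variable must be set before `jax` is first imported in the process, or XLA won't pick it up:

```python
import os

# Tell XLA not to pre-allocate ~90% of GPU memory at startup.
# This must run before the first `import jax` in the process.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
```

Equivalently, run `export XLA_PYTHON_CLIENT_PREALLOCATE=false` in the shell before launching the process.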
In your example, you're passing `devices` to `jit`, not `device` (no 's'). Passing `devices` should be an error. You need to pass `device=jax.devices()[1]` to use the second GPU.
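A quick sketch of targeting a specific device. Besides the `device=` argument to `jit`, you can use `jax.device_put` to place an input on a device, and jit-compiled computations then run where their committed inputs live. The fallback to device 0 is just so this snippet also runs on a single-device machine:

```python
import jax
import jax.numpy as jnp

# Use the second device if there is one (e.g. the second GPU),
# otherwise fall back to the only available device.
devs = jax.devices()
target = devs[1] if len(devs) > 1 else devs[0]

# Committing the input to `target` makes the jitted computation run there.
x = jax.device_put(jnp.arange(3), target)
double = jax.jit(lambda v: v * 2)
y = double(x)
```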
This might be an issue with jaxlib+cuda112 wheel (I admit I only tested the cuda 11.0 version!). I can try it out later, but to any passersby on cuda 11.2,...
I just tried your example with CUDA 11.2 and jaxlib 0.1.62 and it works for me. What kind of GPU do you have? I'll also ask the XLA:GPU team if...
Hi, sorry for the delay. Is this still an issue for anyone?
For my understanding, what other GPU libraries would you like to run with JAX? One way I can think to implement this would be to provide a way to pass...
If you just import jax and not TF, does jax see the GPUs?
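One way to check, assuming jax is installed and TF is not imported first:

```python
import jax

# Lists the devices JAX's backend can see: GPU devices on a working
# CUDA install, or a single CPU device otherwise.
print(jax.devices())
```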
Can you tell which line of code is causing the long compile? For the GPU issue, how are you installing jax + jaxlib?
Glad to hear you resolved the jaxlib issue! Unfortunately we cannot support GPUs by default since we need different jaxlibs for different CUDA versions. https://github.com/google/jax/pull/4065 should make this a little...