JAX reports inaccurate error when trying to acquire already-owned TPU
Description
JAX doesn't share TPUs nicely with other frameworks (e.g. PyTorch/XLA, TensorFlow). That by itself is fine, but the error JAX reports in this situation is misleading. It would be preferable for JAX to actively detect that it cannot use the TPU because another library already holds it, and report that to the user.
Here's an example:
import torch
import torch_xla.core.xla_model as xm
t = torch.tensor([1.0, 2.0, 3.0], device=xm.xla_device()) # PyTorch/XLA acquires the TPU here.
import jax
print(jax.devices())
jax.devices() results in the following error:
RuntimeError: Unable to initialize backend 'tpu': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.23) and framework PJRT API version 0.40). (set JAX_PLATFORMS='' to automatically choose an available backend)
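As a workaround sketch (an assumption on my part, not an official JAX recommendation): since the error message itself suggests `JAX_PLATFORMS`, one can set that variable before importing `jax` so it falls back to CPU instead of failing when another framework holds the TPU:

```python
import os

# Workaround sketch (assumption): force JAX onto the CPU backend so it
# does not try to acquire a TPU already held by e.g. PyTorch/XLA.
# This must happen before `import jax`.
os.environ["JAX_PLATFORMS"] = "cpu"

try:
    import jax
    print(jax.devices())  # CPU devices, no TPU initialization attempted
except ImportError:
    # jax is not installed in this environment; the env var is the point.
    pass
```

This only sidesteps the failure; the report's actual request stands: the error should name the real cause (TPU already in use) rather than a PJRT version mismatch.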
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.23
jaxlib: 0.4.23
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1