JAX reports inaccurate error when trying to acquire already-owned TPU

Open sagelywizard opened this issue 1 year ago • 0 comments

Description

JAX does not share TPUs gracefully with other frameworks (e.g. PyTorch/XLA or TensorFlow). That limitation is fine, but the error JAX reports is misleading. It would be preferable for JAX to actively check whether it cannot use the TPU because another library already holds it, and to report that cause to the user.

Here's an example:

import torch
import torch_xla.core.xla_model as xm
t = torch.tensor([1.0, 2.0, 3.0], device=xm.xla_device()) # PyTorch/XLA acquires the TPU here.
import jax
print(jax.devices())

jax.devices() results in the following error:

RuntimeError: Unable to initialize backend 'tpu': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.23) and framework PJRT API version 0.40). (set JAX_PLATFORMS='' to automatically choose an available backend)
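
As a rough illustration of the kind of check requested here, the following user-side sketch wraps backend initialization and adds a hint when another TPU-capable framework has already been imported in the same process. It is not how JAX works internally: the helper names (`loaded_tpu_frameworks`, `init_with_hint`) and the module list are hypothetical, and "module is imported" is only a heuristic for "library may own the TPU".

```python
import sys

# Assumption: top-level module names of frameworks that may acquire the TPU.
# This list is illustrative, not exhaustive.
OTHER_TPU_FRAMEWORKS = ("torch_xla", "tensorflow")


def loaded_tpu_frameworks(modules=None):
    """Return the names of other TPU-capable frameworks already imported."""
    modules = sys.modules if modules is None else modules
    return [name for name in OTHER_TPU_FRAMEWORKS if name in modules]


def init_with_hint(init_fn, modules=None):
    """Call init_fn(); on RuntimeError, add a hint about possible TPU sharing.

    init_fn is any zero-argument callable that initializes the backend,
    for example jax.devices.
    """
    try:
        return init_fn()
    except RuntimeError as err:
        others = loaded_tpu_frameworks(modules)
        if not others:
            raise  # No other framework detected: keep the original error.
        hint = (
            "Hint: the TPU may already be in use by another library in "
            "this process: " + ", ".join(others)
        )
        raise RuntimeError(f"{err}\n{hint}") from err
```

With this in place, the last line of the example above could be written as `print(init_with_hint(jax.devices))`. The original error text is preserved and the hint is appended, so no information is lost if the heuristic is wrong.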

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.23
jaxlib: 0.4.23
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1

sagelywizard, Mar 08 '24 23:03