JAX error only when a TPU runtime is selected
I'm getting the following error, but only when my runtime has a TPU.
It works fine with no hardware accelerator or with a GPU.
RuntimeError                              Traceback (most recent call last)
7 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version)
     61     msg = (f'jaxlib is version {jaxlib_version}, but this version '
     62            f'of jax requires version >= {minimum_jaxlib_version}.')
---> 63     raise RuntimeError(msg)
     64 
     65   if _jaxlib_version > _jax_version:

RuntimeError: jaxlib is version 0.3.25, but this version of jax requires version >= 0.4.19.
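The traceback shows jax refusing to start because the installed jaxlib is older than the minimum it requires. A minimal sketch of that kind of check, assuming plain dotted numeric versions, is below. `parse_version` and `check_min_version` are hypothetical helpers for illustration, not JAX's actual implementation.

```python
def parse_version(v: str) -> tuple:
    """Turn "0.4.19" into (0, 4, 19); stops at the first non-numeric component."""
    parts = []
    for piece in v.split("."):
        # Keep only the leading digits of each component, e.g. "20rc1" -> 20
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # e.g. "dev20231010" ends the numeric part
        parts.append(int(digits))
    return tuple(parts)


def check_min_version(installed: str, minimum: str) -> None:
    """Raise if `installed` is older than `minimum` (hypothetical helper)."""
    if parse_version(installed) < parse_version(minimum):
        raise RuntimeError(
            f"jaxlib is version {installed}, but this version "
            f"of jax requires version >= {minimum}."
        )
```

With the versions from the error, `check_min_version("0.3.25", "0.4.19")` raises the same message. Versions are compared as integer tuples rather than strings because, as strings, `"0.10.0" < "0.9.0"`. This sketch ignores pre-release semantics (it treats `0.4.20rc1` as `0.4.20`).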
Just update your jaxlib version.
Hello,
Running the following in a cell should fix this issue:
!pip install --upgrade jaxlib
!pip install --upgrade chex
It appears the version of chex installed in the runtime is deprecated, so it needs upgrading as well.
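To confirm what is actually installed after upgrading, a sketch like the following reads package versions using only the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical. If jax was already imported in the session, the runtime generally needs a restart before the new versions are picked up.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "chex")) -> dict[str, str | None]:
    """Map each package name to its installed version, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # package is not installed in this environment
    return result


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```

After a successful upgrade, the printed jaxlib version should be at least the minimum named in the error (0.4.19 here).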
I hope this helps!