Jax Error only when TPU-enabled runtime selected

Open Prindle19 opened this issue 2 years ago • 2 comments

I'm getting the following error, but only when I have a TPU in my runtime.

It works fine without a TPU or with a GPU hardware accelerator.

RuntimeError                              Traceback (most recent call last)
in <cell line: 12>()
     10 import cartopy.crs as ccrs
     11 from google.cloud import storage
---> 12 from graphcast import autoregressive
     13 from graphcast import casting
     14 from graphcast import checkpoint

7 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version)
     61     msg = (f'jaxlib is version {jaxlib_version}, but this version '
     62            f'of jax requires version >= {minimum_jaxlib_version}.')
---> 63     raise RuntimeError(msg)
     64
     65   if _jaxlib_version > _jax_version:

RuntimeError: jaxlib is version 0.3.25, but this version of jax requires version >= 0.4.19.

Prindle19, Jan 02 '24 21:01

Just update your jaxlib version.

ihecha, Jan 03 '24 03:01

Hello,

Running the following in a cell should fix this issue:

!pip install --upgrade jaxlib
!pip install --upgrade chex

It appears the version of chex being pointed to is deprecated...
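To confirm the upgrade took effect on the TPU runtime, something like the following can be run in a fresh cell. This is only a rough sketch: it reads the installed versions from package metadata instead of importing jax (importing is what raises the error), and compares jaxlib against the 0.4.19 minimum quoted in the error message. The helper names `parse_version`, `meets_minimum` and `report` are made up for this example and are not part of jax or graphcast.

```python
from importlib import metadata


def parse_version(text):
    """Turn '0.4.19' into (0, 4, 19).

    Simplification for this sketch: parsing stops at the first
    non-numeric component, so '0.4.19.dev0' also gives (0, 4, 19).
    """
    parts = []
    for piece in text.split("+")[0].split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True if the installed version is at least the minimum."""
    return parse_version(installed) >= parse_version(minimum)


def report(packages=("jax", "jaxlib", "chex"), jaxlib_minimum="0.4.19"):
    """Print installed versions without importing the packages."""
    for name in packages:
        try:
            version = metadata.version(name)
        except metadata.PackageNotFoundError:
            print(f"{name}: not installed")
            continue
        note = ""
        # The minimum below is taken from the error message in this thread;
        # a different jax release may require a different jaxlib.
        if name == "jaxlib" and not meets_minimum(version, jaxlib_minimum):
            note = f"  (older than {jaxlib_minimum}, upgrade needed)"
        print(f"{name}: {version}{note}")


report()
```

If jaxlib still shows an old version after the pip commands, restarting the runtime before re-running the imports may be needed so the new packages are picked up.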

I hope this helps!

andrewlkd, Jan 09 '24 15:01