Tests failing due to incompatibility between TFP and JAX 0.7
Our tests are failing because of a compatibility issue with TFP and the newest version of JAX. See https://github.com/tensorflow/probability/issues/2009.
Thanks for sharing the pain... Indeed, I have been struggling for several days to run the two Jupyter notebooks lgssm_parallel_inference.ipynb and lgssm_hmc.ipynb. The culprit in the parallel notebook is the following line: `from dynamax.linear_gaussian_ssm import lgssm_smoother, parallel_lgssm_smoother`. The error message is: `AttributeError: jax.interpreters.xla.pytype_aval_mappings was deprecated in JAX v0.5.0 and removed in JAX v0.7.0. jax.core.pytype_aval_mappings can be used as a replacement in most cases.`
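Until an upstream fix lands, one hypothetical stopgap is to re-expose the removed alias before importing TFP or dynamax, since the error message itself names `jax.core.pytype_aval_mappings` as the replacement. This shim is an assumption on my part, not an official fix, and it may not cover every symbol an old TFP release touches:

```python
# Hypothetical compatibility shim for JAX >= 0.7: restore the alias that
# older TFP releases still import. Run this BEFORE importing
# tensorflow_probability or dynamax.
try:
    import jax
    from jax.interpreters import xla

    if not hasattr(xla, "pytype_aval_mappings"):
        # Re-expose the removed attribute under its old name, as the
        # deprecation message itself suggests.
        xla.pytype_aval_mappings = jax.core.pytype_aval_mappings
except (ImportError, AttributeError):
    pass  # JAX not installed, or its internals moved again; nothing to patch
```

This only papers over the one missing attribute; the nightly-wheel and downgrade workarounds below are more thorough.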
I would like to implement a parallelized Kalman filter with MCMC inference for predicting physiological time series on a short-term horizon. I have tried tfp.sts.fit_with_hmc in TensorFlow Probability, but unfortunately it does not work (the variational-inference version does).
I have tried the combination of Python packages suggested here, but it still throws an error.
Could you kindly share a list of Python packages (and their specific versions) that works smoothly with dynamax? A bullet-proof requirements.txt would be a great workaround for now.
I attach mine (the one which throws an error) below.
I am replying to my own comment above. I have taken inspiration from a requirements.txt found here.
After some trial and error, I have found a combination of Python packages that WORKS with the previously mentioned dynamax notebooks. Unfortunately, it runs only on CPUs... Anyhow, better than nothing! I have uploaded the WORKING requirements.txt below, since it might be useful for someone facing the same problem.
I am using Ubuntu 24.04.2 LTS on an ASUS ROG Strix with x86_64 architecture and Python 3.12.3.
I also had this problem. I was able to temporarily get around it by installing tfp-nightly (as detailed in https://github.com/tensorflow/probability/issues/2009).
Thanks for the feedback. I confirm that installing nightly build wheels for both TensorFlow and TensorFlow Probability works with GPUs as well!
For newbies: you can download the nightly build wheels from the repositories here and here for TF and TFP, respectively. Please choose them according to your operating system and Python version.
Here are the steps I used in a Linux terminal (Ubuntu 24.04.2) with Python 3.12:
# create virtual environment called "benchmarking"
python3 -m venv benchmarking
# enter the virtual environment
source benchmarking/bin/activate
# install dynamax, matplotlib, pandas and numpy
pip install dynamax matplotlib pandas numpy
# uninstall versions installed by default by dynamax
pip uninstall tensorflow-probability jax jaxlib
# re-install both TF and TFP wheels from nightly builds repositories
pip install "/home/paolo/miscellaneous/downloads/tf_nightly-2.21.0.dev20250801-cp312-cp312-manylinux_2_27_x86_64.whl"
pip install "/home/paolo/miscellaneous/downloads/tfp_nightly-0.26.0.dev20250808-py2.py3-none-any.whl"
# re-install JAX with GPU support
pip install -U "jax[cuda12]"
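As a quick sanity check after the steps above (this snippet is hypothetical, not part of dynamax), you could verify inside the new environment that the import which originally failed now works and that JAX sees the GPU:

```python
# Smoke test for the freshly built environment: the import that originally
# failed should now succeed, and jax.devices() should list a GPU when the
# cuda12 extra was installed. Wrapped in try/except so a broken environment
# reports the problem instead of crashing.
try:
    from dynamax.linear_gaussian_ssm import lgssm_smoother, parallel_lgssm_smoother
    import jax

    env_ok = True
    print("devices:", jax.devices())
except ImportError as err:
    env_ok = False
    print("environment not ready:", err)
```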
I would like to provide the skeleton for a .sh file with the newest wheels. Unfortunately, I could not upload the .sh file for unknown reasons... Anyway, create a .sh file with the following commands and run it in a terminal on Ubuntu 24.04. It should do everything for you.
## script for installing "dynamax" on Ubuntu 24.04.2 LTS operating system on a ASUS ROG Strix with x86_64 architecture and Python 3.12.3
# create virtual environment called "dynamax_gpu"
python3 -m venv dynamax_gpu
# enter the virtual environment
source dynamax_gpu/bin/activate
# install dynamax, matplotlib, pandas and numpy
pip install dynamax matplotlib pandas numpy
# uninstall versions installed by default by dynamax
pip uninstall -y tensorflow-probability jax jaxlib
# download TF and TFP wheels from the nightly build repositories (WARNING: wget saves them to the current working directory)
wget https://files.pythonhosted.org/packages/89/90/fd68bd8d2eb2c0aa3a60b82a8ff26556f3d19815019ff25aa266325f7f6c/tf_nightly-2.21.0.dev20250819-cp312-cp312-manylinux_2_27_x86_64.whl
wget https://files.pythonhosted.org/packages/0c/b1/492b60fbed49141ce97c8b22e6d1ca9457b6593b25056537b0a71fe6890a/tfp_nightly-0.26.0.dev20250825-py2.py3-none-any.whl
# re-install both TF and TFP nightly wheels
pip install "tf_nightly-2.21.0.dev20250819-cp312-cp312-manylinux_2_27_x86_64.whl"
pip install "tfp_nightly-0.26.0.dev20250825-py2.py3-none-any.whl"
# re-install JAX with GPU support
pip install -U "jax[cuda12]"
## if you need just CPU support for JAX, comment out the line above and uncomment the line below
# pip install -U "jax"
If you want just the CPU version of JAX, delete `[cuda12]` from the last command, `pip install -U "jax[cuda12]"`.
I also added a requirements_gpu.txt in case somebody would prefer building the GPU-compatible environment via pip install -r requirements_gpu.txt.
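For reference, a hypothetical sketch of what such a requirements_gpu.txt might contain, based on the nightly wheel names used above (the exact dev pins are assumptions and go stale quickly; note also that pip may still pull the stable tensorflow-probability as a dynamax dependency, so the uninstall/reinstall dance above may remain necessary):

```
# requirements_gpu.txt (illustrative sketch, not the attached file)
dynamax
matplotlib
pandas
numpy
tf-nightly==2.21.0.dev20250819
tfp-nightly==0.26.0.dev20250825
jax[cuda12]
```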
Environment tested on the following tutorials:
To add to the helpful workaround suggestions by @PaoloRanzi81 and @bantin above, maybe the simplest fix is to just downgrade jax and chex. After installing Dynamax, I ran:
pip install "jax<0.7.0" "chex<0.1.91"
After this, the environment was consistent per pip check, the problematic imports found by @PaoloRanzi81 succeeded, and all tests passed.
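As a stdlib-only sanity check (the `version_tuple` helper below is hypothetical, not part of any package), one can confirm that the versions involved compare the way the pins intend:

```python
def version_tuple(v: str) -> tuple:
    """Turn a version string like "0.6.2" into (0, 6, 2) for comparison.

    Deliberately ignores pre-release suffixes such as ".dev20250825";
    use packaging.version for anything serious.
    """
    parts = []
    for piece in v.split("."):
        digits = "".join(ch for ch in piece if ch.isdigit())
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# The pins "jax<0.7.0" and "chex<0.1.91" admit these resolved versions:
assert version_tuple("0.6.2") < version_tuple("0.7.0")
assert version_tuple("0.1.90") < version_tuple("0.1.91")
print("pinned version ranges look consistent")
```

Tuple comparison is used rather than string comparison because, e.g., "0.10.0" < "0.7.0" lexicographically but not numerically.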
Thanks @nkbranigan, and congratulations: you found a simpler workaround! Although it uses jax==0.6.2 instead of jax==0.7.1, the installation procedure is quicker. I have tested both tutorials above: they run without bugs!
For posterity, I attach below the commands to run at the command line (Linux systems):
## script for installing "dynamax" on Ubuntu 24.04.2 LTS operating system on a ASUS ROG Strix with x86_64 architecture and Python 3.12.3
# create virtual environment called "benchmarking"
python3 -m venv benchmarking
# enter the virtual environment
source benchmarking/bin/activate
# install dynamax, matplotlib, pandas, numpy, jax (by downgrading) and chex (by downgrading) with CPU support for JAX
pip install dynamax matplotlib pandas numpy "jax<0.7.0" "chex<0.1.91"
## with GPU support for JAX
#pip install dynamax matplotlib pandas numpy "jax[cuda12]<0.7.0" "chex<0.1.91"
EDIT: after further testing, I realized that downgrading jax and chex makes it possible to run the two tutorials above on CPU only. When using "jax[cuda12]<0.7.0" instead, the Kalman filter + Hamiltonian Monte Carlo tutorial breaks with the following error: XlaRuntimeError: INTERNAL: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error.
I am afraid the only way to run both tutorials on GPU for now is the cumbersome workaround I built a few weeks ago.
Hi @slinderman, would it be possible to create a new release that replaces the `tensorflow_probability` dependency with `tfp-nightly`? I think this would fix the issue for now and make `pip install dynamax` work
> Hi @slinderman, would it be possible to create a new release that replaces the `tensorflow_probability` dependency with `tfp-nightly`? I think this would fix the issue for now and make `pip install dynamax` work
I second the sentiment 😊