
Type Error in Neural ODE Example

[Open] moritz-laber opened this issue 9 months ago • 11 comments

Hi,

I've been using jax 0.4.35 and CUDA 11.2 with diffrax 0.6.0 and equinox 0.11.8 to train various types of neural ODEs. This has worked very well so far. Thanks for the great packages.

However, after upgrading to jax 0.6.0 and CUDA 12.8 with diffrax 0.7.0 and equinox 0.12.1, I can no longer compute gradients through the ODE solvers. The forward pass works, but the gradient computation with eqx.filter_value_and_grad raises a TypeError. To be precise, I receive: TypeError: Argument 'Zero(ShapedArray(float0[32]))' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type. I have tried both the Tsit5() and Dopri5() solvers, and the shape of the array in the TypeError matches the batch size.
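For reference, a minimal sketch of the failing pattern (the MLP vector field, sizes, and loss here are stand-ins, not my actual model):

```python
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

# Stand-in neural vector field; any learned dynamics hits the same code path.
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, key):
        self.mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=1, key=key)

    def __call__(self, t, y, args):
        return self.mlp(y)

def solve(func, y0):
    # Single trajectory; the default SaveAt keeps only the final state.
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(func), diffrax.Tsit5(), t0=0.0, t1=1.0, dt0=0.1, y0=y0
    )
    return sol.ys

@eqx.filter_value_and_grad
def loss(func, y0_batch):
    ys = jax.vmap(lambda y0: solve(func, y0))(y0_batch)
    return jnp.mean(ys**2)

func = Func(jr.PRNGKey(0))
y0_batch = jnp.ones((32, 2))  # batch of 32, matching float0[32] in the error
loss(func, y0_batch)  # forward pass is fine; the gradient raises the TypeError
```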

This occurs not only in my own models, but also when I try to run the neural ODE tutorial from https://github.com/patrick-kidger/diffrax/blob/main/examples/neural_ode.ipynb.

Any advice/ideas on how to resolve this would be great! Many thanks in advance.

moritz-laber · Apr 26 '25

This sounds like a JAX bug we recently discovered: https://github.com/patrick-kidger/optimistix/issues/129

This is being fixed. I'll check the example tomorrow to be sure, but it is likely the same grad-of-vmap-of-custom_jvp combination.
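Schematically, the failing structure looks like this (illustrative only, not a guaranteed standalone reproducer; diffrax's solvers carry custom derivative rules like this internally, which is why a batched training step hits the same path):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, step):
    # `step` is integer-valued, so it is not differentiable.
    return jnp.sin(x) * step

@f.defjvp
def f_jvp(primals, tangents):
    x, step = primals
    dx, _ = tangents  # the tangent for `step` has dtype float0 (a "Zero")
    return f(x, step), jnp.cos(x) * step * dx

def loss(xs, steps):
    return jax.vmap(f)(xs, steps).sum()

# grad-of-vmap-of-custom_jvp, with a float0 tangent in the mix
jax.grad(loss)(jnp.ones(32), jnp.ones(32, dtype=jnp.int32))
```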

If it's the same bug, then I suggest you just wait until the next release of JAX to upgrade.

johannahaffner · Apr 26 '25

We're indeed looking at the same thing.

johannahaffner · Apr 27 '25

Great! Thanks for looking into this. I'll wait for the next JAX release then.

moritz-laber · Apr 27 '25

I am getting the same error while running the demo neural_cde.ipynb: TypeError: Argument 'Zero(ShapedArray(float0[32]))' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type. Did you find a fix for this?

mbelalsh · May 15 '25

This is a known issue with JAX 0.6.0. It's already fixed upstream, so the next JAX release should include the fix.

https://github.com/jax-ml/jax/issues/28144

patrick-kidger · May 16 '25

@patrick-kidger I have two questions:

  1. Is there any way to use a version of diffrax with a compatible version of JAX until the new JAX release is out?
  2. I am working with torchcde, but it is very slow. Do you think diffrax would reduce the training time, and by how much?

mbelalsh · May 19 '25

Hi @mbelalsh,

You can install the specific versions you want to use, preferably in a virtual environment (e.g. with pip install jax==0.5.2). That way you have a working configuration for a specific project and can decide when to upgrade.
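For example (versions taken from this thread; adjust the CUDA extra to your setup):

```bash
# Fresh virtual environment pinned to a known-working combination
python -m venv .venv
source .venv/bin/activate
pip install "jax[cuda12]==0.5.2" diffrax==0.7.0
```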

For torchcde: yes, you can expect this to be much faster.

johannahaffner · May 19 '25

@johannahaffner Thanks. pip install "jax[cuda12]==0.5.2" works with diffrax 0.4.0. The example neural_cde.ipynb is working now.

mbelalsh · May 19 '25

Nice! It should also work with recent diffrax (we're at 0.7.0).

For your second question: as per our benchmarks, diffrax is multiple times faster than torchdiffeq, and it is reasonable to expect a similar speedup relative to torchcde.

johannahaffner · May 19 '25

@johannahaffner Thanks. I was getting a warning with diffrax 0.4.0; after upgrading to diffrax 0.7.0, I no longer see it.

mbelalsh · May 19 '25

For me, upgrading to jax 0.6.1 solved the problem.

moritz-laber · May 22 '25