diffrax

Warning when solving the example ODE from the docs

Status: Open. MaAl13 opened this issue 1 year ago • 3 comments

The following example code from the documentation produces a warning. How can I get rid of it?

```python
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5


# Lotka-Volterra predator-prey dynamics; y and the return value are tuples (PyTrees).
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = d_prey, d_predator
    return d_y


term = ODETerm(vector_field)
solver = Tsit5()
t0 = 0
t1 = 140
dt0 = 0.1
y0 = (10.0, 10.0)
args = (0.1, 0.02, 0.4, 0.02)  # (α, β, γ, δ)
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))  # save at 1000 evenly spaced times
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat)
```

Warning:

```
'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
```

MaAl13, Jun 20 '24 14:06

This is an upstream issue we spotted in recent versions of JAX; see https://github.com/google/jax/issues/21824.

This has been worked around on Diffrax's `main` branch as of https://github.com/patrick-kidger/diffrax/pull/440, so the warning will go away in our next release.

In the meantime, you can filter out this warning with `warnings.simplefilter`.
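Note that `warnings.simplefilter("ignore")` silences every warning, not just this one. A narrower option, sketched here under the assumption that the message text contains "Attempting to hash a tracer" (as in the output quoted in the issue), is `warnings.filterwarnings` with a `message` regex, applied either process-wide or scoped with `warnings.catch_warnings()`. The helper names `ignore_tracer_hash_warning` and `call_without_tracer_hash_warning` are hypothetical, not part of Diffrax or JAX.

```python
import warnings

# `message` is a regex matched case-insensitively against the *start* of the
# warning text, so the leading ".*" skips whatever prefix precedes the phrase.
_TRACER_HASH_MSG = r".*Attempting to hash a tracer"


def ignore_tracer_hash_warning() -> None:
    """Ignore only the tracer-hashing warning for the rest of the process."""
    warnings.filterwarnings("ignore", message=_TRACER_HASH_MSG)


def call_without_tracer_hash_warning(fn, *args, **kwargs):
    """Call fn with the tracer-hashing warning suppressed.

    The previous warning filters are restored on exit, and warnings with
    other messages are still shown.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=_TRACER_HASH_MSG)
        return fn(*args, **kwargs)
```

With the example above, this would be `sol = call_without_tracer_hash_warning(diffeqsolve, term, solver, t0, t1, dt0, y0, args=args, saveat=saveat)`. The scoped form is usually preferable to a global filter, because it cannot hide unrelated warnings raised elsewhere in the program.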

patrick-kidger, Jun 20 '24 18:06

Hi @patrick-kidger, should this be fixed in diffrax 0.6.0, or will the fix land in a later release? I still see the warning with jax==0.4.30 and diffrax==0.6.0.

gautierronan, Jul 03 '24 10:07

Right, so there were two different sources of this warning: one in Diffrax (fixed as above) and another in Equinox (fixed on HEAD but not yet released). For now it should be safe to ignore the warning. I'll do another release of Equinox soon, and then that one will be squashed too :)

patrick-kidger, Jul 05 '24 06:07
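Whether the warning still appears depends on which JAX, Diffrax and Equinox releases are installed (the report above was on jax 0.4.30 and diffrax 0.6.0). A minimal sketch for listing the installed versions, e.g. to include in a follow-up report, uses the standard-library `importlib.metadata`. The helper name `installed_versions` is hypothetical, and this thread does not say which Equinox version carries the fix.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "equinox", "diffrax")):
    """Return {package name: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```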