Warning when solving the example ODE from the docs
The following example code from the docs produces a warning. How can I get rid of it?
```python
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5

def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = d_prey, d_predator
    return d_y

term = ODETerm(vector_field)
solver = Tsit5()
t0 = 0
t1 = 140
dt0 = 0.1
y0 = (10.0, 10.0)
args = (0.1, 0.02, 0.4, 0.02)
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat)
```
Warning:
```
'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
  warnings.warn(
```
This is an upstream issue we spotted in recent versions of JAX: see https://github.com/google/jax/issues/21824.
This is worked around on Diffrax's main branch as of https://github.com/patrick-kidger/diffrax/pull/440, so the warning will go away in our next release.
In the meantime you can filter out this warning with the standard-library `warnings` module (e.g. `warnings.simplefilter`).
Hi @patrick-kidger, is this fixed in diffrax==0.6.0, or will it be in a later release? I still see the warning with jax==0.4.30 and diffrax==0.6.0.
Right, so there were two different sources of this warning: one coming from Diffrax (fixed as above), and another from Equinox (fixed on HEAD but not yet released). For now it should be safe to ignore the warning. At some point soon I'll do another release of Equinox, and then that one will be squashed too :)
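Since the warning can come from either Diffrax or Equinox, it helps to know exactly which versions are installed when reporting it or checking whether an upgrade has silenced it. A small sketch using the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical, and it does not know which releases contain the fixes, it only reports what is installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "diffrax", "equinox")):
    """Return {package name: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Package isn't installed in the current environment.
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```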