
Intermediate saved values are sometimes `inf`

Open · dkweiss31 opened this issue 1 year ago · 3 comments

Hi Patrick! I've run into an issue (see https://github.com/dynamiqs/dynamiqs/issues/666) when `t0 == t1` and I try to save intermediate values: the values saved via `SaveAt(ts=...)` come back as `inf`. This seems to be independent of the `stepsize_controller` I use (adaptive or constant). Here is a minimal example using constant steps:

```python
import diffrax as dx
import jax.numpy as jnp

term = dx.ODETerm(lambda t, y, _: y)
y0 = jnp.array([1.0])
ts = jnp.array([0.0, 0.0])
saveat = dx.SaveAt(subs=[dx.SubSaveAt(ts=ts), dx.SubSaveAt(t1=True)])

solution = dx.diffeqsolve(
    term,
    dx.Tsit5(),
    ts[0],
    ts[-1],
    0.1,
    y0,
    saveat=saveat,
)
print(solution.ys[0])  # [[inf] [inf]]
print(solution.ys[1])  # [[1.]]
```

dkweiss31 · Aug 16 '24 13:08

Ah! Good catch. It seems that we don't fill in `SaveAt(ts=...)` in the `t0 == t1` case. In that case we never enter our integration loop, so none of the code that saves `SaveAt(ts=...)` is ever triggered:

https://github.com/patrick-kidger/diffrax/blob/7384bfa1cd3222c2c5d8c705907e60bbf71587ec/diffrax/_integrate.py#L408-L456
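To see why the buffer stays at `inf`, here is a minimal toy model of the mechanism. It is not diffrax's actual implementation: it is a `lax.while_loop` whose condition is `t < t1`, plus an output buffer pre-filled with `inf` (an assumption that matches the `inf` output reported above). `toy_solve` is a hypothetical name; it takes explicit Euler steps for dy/dt = y with a 1-D state and fills the save times by linear interpolation. With `t0 == t1` the loop body never runs, so nothing ever overwrites the buffer.

```python
import jax.numpy as jnp
from jax import lax


def toy_solve(t0, t1, ts, y0, dt):
    """Hypothetical toy solver (NOT diffrax's code) for dy/dt = y.

    Takes explicit Euler steps inside a `lax.while_loop` and fills an
    output buffer at the requested save times `ts` by linear interpolation.
    Assumes a 1-D state `y0`.
    """
    dtype = y0.dtype
    t0 = jnp.asarray(t0, dtype=dtype)
    t1 = jnp.asarray(t1, dtype=dtype)
    dt = jnp.asarray(dt, dtype=dtype)
    ts = jnp.asarray(ts, dtype=dtype)
    # Assumption for this sketch: unsaved entries are marked with inf.
    ys = jnp.full(ts.shape + y0.shape, jnp.inf, dtype=dtype)

    def cond_fun(carry):
        t, _, _ = carry
        return t < t1  # False straight away when t0 == t1.

    def body_fun(carry):
        t, y, ys = carry
        t_next = t + jnp.minimum(dt, t1 - t)  # never step past t1
        y_next = y + (t_next - t) * y  # explicit Euler step for dy/dt = y
        # Fill every save time lying in [t, t_next] by linear interpolation.
        in_step = (ts >= t) & (ts <= t_next)
        frac = (ts - t) / (t_next - t)
        interp = y[None, :] + frac[:, None] * (y_next - y)[None, :]
        ys = jnp.where(in_step[:, None], interp, ys)
        return t_next, y_next, ys

    # With t0 == t1 the body is skipped entirely and ys is returned untouched.
    _, _, ys = lax.while_loop(cond_fun, body_fun, (t0, y0, ys))
    return ys
```

For example, `toy_solve(0.0, 0.0, jnp.array([0.0, 0.0]), jnp.array([1.0]), 0.1)` returns a `(2, 1)` array of `inf`, mirroring the output in the issue, whereas the same call with `t1 = 1.0` and save times inside `[0, 1]` returns finite values.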

I suppose we should add a special case to our integration loop that explicitly handles this. Something like `ys = lax.cond(t0 == t1, lambda: jnp.full(ts.shape + y0.shape, y0), ...)`, perhaps?
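A minimal sketch of that `lax.cond` idea as a standalone helper, assuming the saved values live in an array `ys` of shape `ts.shape + y0.shape`. `fill_if_empty_interval` is a hypothetical name, not a diffrax function. Because `jnp.full` takes a shape as its first argument, the fill is written as `jnp.full(ys.shape, y0)`; when `t0 == t1` every save time equals `t0`, so `y0` is the correct value at all of them.

```python
import jax.numpy as jnp
from jax import lax


def fill_if_empty_interval(t0, t1, y0, ys):
    """Hypothetical helper (not part of diffrax).

    When t0 == t1 the integration loop never runs, so every requested save
    time equals t0 and the correct saved value is simply y0. `ys` is assumed
    to have shape `ts.shape + y0.shape`.
    """
    return lax.cond(
        t0 == t1,
        # Empty interval: broadcast y0 across every save time.
        lambda ys: jnp.full(ys.shape, y0, dtype=ys.dtype),
        # Otherwise keep whatever the loop wrote.
        lambda ys: ys,
        ys,
    )
```

Since `lax.cond` traces both branches, this stays compatible with `jax.jit` even when `t0` and `t1` are traced values.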

I'd be happy to take a PR on this.

patrick-kidger · Aug 16 '24 16:08

Gotcha! Nice, I'd be happy to give this a go. I've been meaning to learn about while loops and buffers, so there's no time like the present!

dkweiss31 · Aug 16 '24 21:08

Great! And the good news is that I think this can happen before or after the loop (I misspoke a little above), so there won't be any need to change the integration loop itself.
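Since the fix can live outside the loop, one way to sketch it is as a post-processing step on the loop's output. This minimal sketch assumes the saved values come back as a PyTree `ys` mirroring the structure of `y0`, with each leaf of shape `ts.shape + leaf.shape`. `patch_saves_after_loop` is a hypothetical name, and `jnp.where` is used here instead of `lax.cond` purely as an illustration; the loop itself is left untouched.

```python
import jax
import jax.numpy as jnp


def patch_saves_after_loop(t0, t1, y0, ys):
    """Hypothetical post-loop fix-up (not diffrax's actual code).

    Works on PyTree states: each leaf of `ys` has shape `ts.shape + leaf.shape`
    for the matching leaf of `y0`. If t0 == t1, every leaf is replaced by its
    initial value broadcast over the save times; otherwise `ys` is unchanged.
    """
    empty_interval = t0 == t1
    return jax.tree_util.tree_map(
        # jnp.where broadcasts y0_leaf against the leading save-time axis.
        lambda y0_leaf, ys_leaf: jnp.where(empty_interval, y0_leaf, ys_leaf),
        y0,
        ys,
    )
```

Because `jnp.where` broadcasts on trailing dimensions, the same one-liner handles scalar and array leaves alike.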

patrick-kidger · Aug 17 '24 06:08