Intermediate saved values are sometimes `inf`
Hi Patrick! I've run into an issue over in https://github.com/dynamiqs/dynamiqs/issues/666 when t0=t1 and I try to save intermediate values. It seems to be independent of the stepsize_controller that I use (adaptive or constant). Here is a minimal example using constant steps.
import diffrax as dx
import jax.numpy as jnp
term = dx.ODETerm(lambda t, y, _: y)
y0 = jnp.array([1.0])
ts = jnp.array([0.0, 0.0])
saveat = dx.SaveAt(subs=[dx.SubSaveAt(ts=ts), dx.SubSaveAt(t1=True)])
solution = dx.diffeqsolve(
    term,
    dx.Tsit5(),
    ts[0],
    ts[-1],
    0.1,
    y0,
    saveat=saveat,
)
print(solution.ys[0]) # [[inf] [inf]]
print(solution.ys[1]) # [[1.]]
Ah! Good catch. It seems that we don't try to fill in SaveAt(ts=...) in the t0=t1 case. This case means that we never enter our integration loop and so we never trigger any of the code for saving SaveAt(ts=...):
https://github.com/patrick-kidger/diffrax/blob/7384bfa1cd3222c2c5d8c705907e60bbf71587ec/diffrax/_integrate.py#L408-L456
I suppose we should add a special case in our integration loop that explicitly handles this case. (Something like `ys = lax.cond(t0 == t1, lambda: jnp.full(ts, y0), ...)`?)
I'd be happy to take a PR on this.
Gotcha! Nice, I'd be happy to give this a go. I've been meaning to learn about while loops and buffers, so no time like the present!
Great! And the good news is that I think this gets to happen before/after the loop (I misspoke a little above), so there won't be any need to change the integration loop itself.
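For anyone following along, here is a rough sketch of the kind of special case discussed above, written with plain JAX rather than diffrax internals. The helper name `fill_saved_ys` and the `unfilled` stand-in buffer are hypothetical and only for illustration; the real fix would live in `_integrate.py` and may look quite different. It uses `jnp.broadcast_to` to produce one copy of `y0` per requested save time, which is presumably what the `jnp.full(ts, y0)` shorthand was getting at.

```python
import jax.numpy as jnp
from jax import lax


def fill_saved_ys(t0, t1, ts, y0, ys_from_loop):
    """Hypothetical helper: pick the values to report for SaveAt(ts=...)."""
    # One copy of y0 per requested save time: shape (len(ts), *y0.shape).
    ys_at_t0 = jnp.broadcast_to(y0, (ts.shape[0],) + y0.shape)
    # Both branches must return arrays with the same shape and dtype.
    return lax.cond(t0 == t1, lambda: ys_at_t0, lambda: ys_from_loop)


y0 = jnp.array([1.0])
ts = jnp.array([0.0, 0.0])
# Stand-in for the save buffer that stays at inf when the loop never runs.
unfilled = jnp.full((ts.shape[0],) + y0.shape, jnp.inf, dtype=y0.dtype)
ys = fill_saved_ys(ts[0], ts[-1], ts, y0, unfilled)
print(ys)
```

Because the check happens outside the loop, the loop's own output is passed through untouched whenever `t0 != t1`.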