fix for t0==t1 when initializing ys, as opposed to inside the loop
Eliminates the code introduced in https://github.com/patrick-kidger/diffrax/pull/494 whereby ys was updated to y0 inside the loop in the case of t0==t1. Now ys is instead initialized with y0 in that case. This is in response to the performance issues observed in https://github.com/patrick-kidger/diffrax/issues/606.
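For illustration, a minimal sketch of the idea of choosing the fill value once at initialization rather than patching ys inside the loop. This is not the actual Diffrax code: `init_ys` and its arguments are hypothetical names, and the `jnp.inf` placeholder is assumed from the buffer-allocation line quoted in the warning further down this thread.

```python
import jax.numpy as jnp


def init_ys(y0, t0, t1, max_steps):
    """Hypothetical helper: allocate the output buffer up front.

    If t0 == t1, every slot starts as y0; otherwise every slot starts as
    an inf placeholder that a solver loop would later overwrite.
    """
    # Pick the fill value once, outside of any loop.
    fill = jnp.where(t0 == t1, y0, jnp.inf)
    # Broadcast to (max_steps, *y0.shape) so each row holds the fill value.
    return jnp.broadcast_to(fill, (max_steps,) + y0.shape)


y0 = jnp.array([1.0, 2.0])
ys_equal = init_ys(y0, 0.0, 0.0, max_steps=3)  # t0 == t1: rows are y0
ys_normal = init_ys(y0, 0.0, 1.0, max_steps=3)  # t0 != t1: rows are inf
```

The point of the sketch is only that the t0==t1 branch is evaluated once, so the loop body no longer needs to carry the extra check.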
I will investigate these test failures. Patrick, can you confirm the adjoint failures are unrelated to this PR? I definitely need to address the event failures though.
Awesome, thank you for putting this together!
As for the test failures, I think this might be due to these changes being made on top of an old version of Diffrax main. In particular I see the following warning:
/home/runner/work/diffrax/diffrax/diffrax/_integrate.py:1258: DeprecationWarning: shape requires ndarray or scalar arguments, got <class 'jax._src.api.ShapeDtypeStruct'> at position 0. In a future JAX release this will be an error.
(max_steps,) + jnp.shape(x), jnp.inf, dtype=x.dtype
Which does not reflect the current state or line number of that line:
https://github.com/patrick-kidger/diffrax/blob/2fafbc7591506834341822a869d9eefbdfefce82/diffrax/_integrate.py#L1268
That might explain the failures? (Which I can't reproduce locally.)
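For reference, a minimal sketch of how a buffer like the one in the warning can be allocated from a `jax.ShapeDtypeStruct` without passing the struct to `jnp.shape`. This is an assumption about what a fix could look like, not the actual line on main; `alloc_buffer` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def alloc_buffer(x, max_steps):
    """Hypothetical helper: allocate an inf-filled buffer for `x`.

    `x` may be an array or a jax.ShapeDtypeStruct; both expose `.shape`
    and `.dtype`, so jnp.shape(x) is not needed.
    """
    return jnp.full((max_steps,) + tuple(x.shape), jnp.inf, dtype=x.dtype)


struct = jax.ShapeDtypeStruct((2,), jnp.float32)
buf = alloc_buffer(struct, max_steps=3)
```

Reading `.shape` directly avoids the deprecated path where `jnp.shape` receives something that is neither an ndarray nor a scalar.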
Hmm, it seems I was wrong: those test failures do result directly from my changes. I do think I based my changes on the most recent version of main: I see commit 2fafbc7591506834341822a869d9eefbdfefce82 as the most recent change before mine. You might be referring to the line below? https://github.com/patrick-kidger/diffrax/blob/36c1d3e4b05456332afa1923340cd924f76da134/diffrax/_integrate.py#L1257-L1259
In particular, if I comment out the section below (which is what actually makes the replacement with the correct value of y0 if t0==t1), the test_adjoint.py tests pass locally, whereas they fail locally if I leave it uncommented. I'm a little baffled by these errors, which don't seem to point to a particular line: maybe it's upset that we are asking about t0==t1 in where?
https://github.com/patrick-kidger/diffrax/blob/36c1d3e4b05456332afa1923340cd924f76da134/diffrax/_integrate.py#L1216-L1222
Closing in favour of #618.