diffrax

fix for t0==t1 when initializing ys, as opposed to inside the loop

Open dkweiss31 opened this issue 10 months ago • 3 comments

Removes the code introduced in https://github.com/patrick-kidger/diffrax/pull/494, which overwrote ys with y0 inside the loop whenever t0==t1. Instead, ys is now initialized with y0 from the start. This is in response to the performance issues reported in https://github.com/patrick-kidger/diffrax/issues/606.
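To illustrate the general idea of "initialize instead of patching inside the loop", here is a minimal JAX sketch. It is not diffrax's actual code. The helper `init_ys` is hypothetical, and the +inf sentinel for unwritten slots is an assumption modelled on the `jnp.inf` fill visible in the warning quoted later in this thread. The buffer is allocated once, and a single `jnp.where` on `t0 == t1` chooses between a y0-filled buffer and an inf-filled one, so no per-step conditional write is needed.

```python
import jax
import jax.numpy as jnp


def init_ys(y0, t0, t1, max_steps):
    """Hypothetical helper (not diffrax's actual code): allocate the saved-output
    buffer once, before the integration loop runs.

    Every slot normally starts as +inf, a "not yet written" sentinel. If
    t0 == t1 no step is ever taken, so the buffer is filled with y0 up front
    rather than being patched on every loop iteration.
    """
    degenerate = t0 == t1  # scalar boolean; may be a tracer under jit

    def _init_leaf(x):
        x = jnp.asarray(x)  # assumes a floating-point leaf (inf needs a float dtype)
        shape = (max_steps,) + x.shape
        inf_buffer = jnp.full(shape, jnp.inf, dtype=x.dtype)
        y0_buffer = jnp.broadcast_to(x, shape)
        # One select outside the loop instead of a conditional write per step.
        return jnp.where(degenerate, y0_buffer, inf_buffer)

    # Apply to every leaf so y0 can be an arbitrary PyTree.
    return jax.tree_util.tree_map(_init_leaf, y0)


# max_steps determines the buffer shape, so it must be static under jit.
init_ys_jit = jax.jit(init_ys, static_argnums=3)
ys = init_ys_jit(jnp.array([1.0, 2.0]), 0.0, 0.0, 3)
```

Because `jnp.where` evaluates both operands, this sketch pays for one extra buffer fill at setup time but adds no work inside the loop.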

dkweiss31, Apr 04 '25 00:04

I will investigate these test failures. Patrick, can you confirm that the adjoint failures are unrelated to this PR? I definitely need to address the event failures, though.

dkweiss31, Apr 04 '25 11:04

Awesome, thank you for putting this together!

As for the test failures, I think they might be due to these changes having been made on top of an old version of Diffrax's main branch. In particular, I see the following warning:

/home/runner/work/diffrax/diffrax/diffrax/_integrate.py:1258: DeprecationWarning: shape requires ndarray or scalar arguments, got <class 'jax._src.api.ShapeDtypeStruct'> at position 0. In a future JAX release this will be an error.
    (max_steps,) + jnp.shape(x), jnp.inf, dtype=x.dtype

That does not match the current contents or line number of that line:

https://github.com/patrick-kidger/diffrax/blob/2fafbc7591506834341822a869d9eefbdfefce82/diffrax/_integrate.py#L1268

Might that explain the failures? (I can't reproduce them locally.)
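The warning indicates that `jnp.shape(x)` was called with a `jax.ShapeDtypeStruct` rather than an array. A minimal sketch of the usual workaround is to read the `.shape` and `.dtype` attributes directly, since both concrete arrays and `ShapeDtypeStruct` provide them. The helper `make_buffer` is hypothetical, and this is not a claim about what the current diffrax line actually does.

```python
import jax
import jax.numpy as jnp


def make_buffer(x, max_steps):
    """Hypothetical helper: an inf-filled buffer with one slot per step.

    x may be a concrete array or a jax.ShapeDtypeStruct (for example the
    output of jax.eval_shape). Reading x.shape and x.dtype directly works for
    both, whereas jnp.shape(x) on a ShapeDtypeStruct triggers the
    DeprecationWarning quoted above.
    """
    return jnp.full((max_steps,) + tuple(x.shape), jnp.inf, dtype=x.dtype)


struct = jax.ShapeDtypeStruct((2,), jnp.float32)
buf = make_buffer(struct, 4)
```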

patrick-kidger, Apr 04 '25 13:04

Hmm, it seems I was wrong: those test failures do result directly from my changes. (I do think I based my changes on the most recent version of main; I see commit 2fafbc7591506834341822a869d9eefbdfefce82 as the most recent change before mine. Might you be referring to the line below?) https://github.com/patrick-kidger/diffrax/blob/36c1d3e4b05456332afa1923340cd924f76da134/diffrax/_integrate.py#L1257-L1259

In particular, if I comment out the section below (which performs the actual replacement with the correct value of y0 when t0==t1), the test_adjoint.py tests pass locally; if I leave that section in, they fail locally. I'm a little baffled by these errors, which don't seem to point to a particular line. Maybe it's unhappy that we are checking t0==t1 inside a `where`? https://github.com/patrick-kidger/diffrax/blob/36c1d3e4b05456332afa1923340cd924f76da134/diffrax/_integrate.py#L1216-L1222
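On the `where` question: a traced comparison such as t0==t1 is a valid `jnp.where` condition under both `jit` and autodiff. One general pitfall, documented in the JAX FAQ, is that `jnp.where` evaluates both branches. During reverse-mode autodiff, the zeroed cotangent of the unselected branch is still combined with that branch's local derivative, and if that derivative is infinite at the given input the result is NaN. Whether this is what breaks test_adjoint.py here is not established in this thread. The function names below are illustrative only.

```python
import jax
import jax.numpy as jnp


def unsafe(x):
    # Both branches are computed; at x == 0 the unused branch is log(0) = -inf,
    # and its derivative 1/x is infinite, which poisons the gradient with NaN.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def safe(x):
    # "Double where": first swap in a harmless input for the unused branch,
    # so log and its derivative are only evaluated at finite points.
    positive = x > 0
    x_safe = jnp.where(positive, x, 1.0)
    return jnp.where(positive, jnp.log(x_safe), 0.0)


grad_unsafe = jax.grad(unsafe)(0.0)
grad_safe = jax.grad(safe)(0.0)
```

The same "double where" pattern applies to any branch whose value or derivative blows up on inputs the condition is meant to exclude.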

dkweiss31, Apr 05 '25 16:04

Closing in favour of #618.

patrick-kidger, Jul 27 '25 14:07