diffrax

Save fix for `t0==t1`

dkweiss31 opened this pull request

Addresses the edge case raised in https://github.com/patrick-kidger/diffrax/issues/488, where `t0 == t1` and `saveat.ts` is not None. Additionally, if `saveat.t0` is True, those values were not updated either; this PR should address that too. I've also included a test for this case.
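A minimal sketch of the expected behavior, written in plain JAX rather than Diffrax's API (`expected_saves` is a hypothetical helper). Assuming save times must lie in `[t0, t1]`, every requested time equals `t0` when `t0 == t1`; the state never evolves, so every saved value should be a copy of `y0`, with one extra leading entry when `saveat.t0` is True. The sketch ignores `SaveAt(fn=...)` and assumes an array-valued (non-PyTree) state.

```python
import jax.numpy as jnp


def expected_saves(t0, y0, ts, save_t0):
    """Hypothetical reference for what should be saved when t0 == t1.

    ts: requested save times (all equal to t0 in this edge case).
    save_t0: static Python bool, mirroring saveat.t0.
    """
    if save_t0:
        # One extra leading slot for the initial time.
        out_ts = jnp.concatenate([jnp.asarray([t0], dtype=ts.dtype), ts])
    else:
        out_ts = ts
    # The state never moves, so every saved value is a copy of y0.
    out_ys = jnp.broadcast_to(y0, out_ts.shape + jnp.shape(y0))
    return out_ts, out_ys


ts = jnp.array([2.0, 2.0])
y0 = jnp.array([1.0, -1.0])
out_ts, out_ys = expected_saves(2.0, y0, ts, save_t0=True)
print(out_ys.shape)  # (3, 2)
```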

Regarding the implementation: a loop is not ideal, since everything could in principle be done in parallel, but the version below did not work for the `ts` part because of dynamic slicing errors. Let me know if there is a nicer workaround I could try :)

```python
if subsaveat.ts is not None:
    _ts = subsaveat.ts
    save_idx = save_state.save_index
    # `save_idx` is traced under JIT, so these Python slices are not static;
    # this is where the dynamic slicing errors come from.
    ts = save_state.ts.at[save_idx: save_idx + len(_ts)].set(_ts)
    _ys = [subsaveat.fn(t1, yfinal, args)] * len(_ts)
    ys = save_state.ys.at[save_idx: save_idx + len(_ts)].set(_ys)
    save_state = SaveState(
        saveat_ts_index=save_idx + len(_ts),
        ts=ts,
        ys=ys,
        save_index=save_idx + len(_ts),
    )
```
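One possible workaround for the slicing error, sketched under simplifying assumptions: `jax.lax.dynamic_update_slice` accepts traced start indices as long as the shape of the update is static, which it is here because `len(ts)` is known at trace time. The `SaveState` below is a hypothetical `NamedTuple` stand-in with array-valued `ys` (Diffrax's real `SaveState` and buffers may differ), and `y_final` stands in for `subsaveat.fn(t1, yfinal, args)`. The index bookkeeping is copied from the snippet above.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax


class SaveState(NamedTuple):
    # Hypothetical stand-in: array-valued ys only, no PyTrees or buffer types.
    saveat_ts_index: jax.Array
    ts: jax.Array
    ys: jax.Array
    save_index: jax.Array


@jax.jit
def write_ts_block(save_state, ts_to_save, y_final):
    n = ts_to_save.shape[0]  # static: shapes are known at trace time
    idx = save_state.save_index  # traced: cannot appear in a Python slice
    # One start index per dimension; traced values are allowed. Start indices
    # are clamped so the update stays in bounds, so the buffer must be big enough.
    ts = lax.dynamic_update_slice(save_state.ts, ts_to_save, (idx,))
    ys_block = jnp.broadcast_to(y_final, (n,) + y_final.shape)
    ys = lax.dynamic_update_slice(
        save_state.ys, ys_block, (idx,) + (0,) * y_final.ndim
    )
    return SaveState(
        saveat_ts_index=idx + n,
        ts=ts,
        ys=ys,
        save_index=idx + n,
    )


# Slot 0 already holds t0; three more slots are waiting to be filled.
state = SaveState(
    saveat_ts_index=jnp.array(1),
    ts=jnp.array([0.0, jnp.inf, jnp.inf, jnp.inf]),
    ys=jnp.zeros((4, 2)),
    save_index=jnp.array(1),
)
new_state = write_ts_block(state, jnp.array([0.0, 0.0, 0.0]), jnp.array([1.0, 2.0]))
```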

dkweiss31, Aug 18 '24

To address some failing tests involving reverse-mode differentiation, I converted the loop to a `while_loop`, but some tests are still failing. Converting this to a draft for now.
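Some background on why the choice of loop matters for reverse mode: in JAX, a plain `lax.while_loop` cannot be reverse-mode differentiated, whereas `lax.fori_loop` with trip counts known at trace time is lowered to `lax.scan` and can be. The thread does not say exactly which loop primitive was used here, so the sketch below, built around a hypothetical `fill_saves` helper, only illustrates the static-trip-count case.

```python
import jax
import jax.numpy as jnp
from jax import lax


def fill_saves(y_final, n_saves):
    """Hypothetical helper: write y_final into each of n_saves buffer slots."""
    y_final = jnp.asarray(y_final)
    buf = jnp.zeros((n_saves,) + y_final.shape, y_final.dtype)

    def body(i, b):
        # Integer (not slice) indexing is fine with a traced index.
        return b.at[i].set(y_final)

    # n_saves is a Python int, so the trip count is static and fori_loop
    # lowers to lax.scan, which supports reverse-mode autodiff.
    return lax.fori_loop(0, n_saves, body, buf)


def loss(y):
    return jnp.sum(fill_saves(y, 4))


print(jax.grad(loss)(jnp.array(2.0)))  # 4.0
```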

dkweiss31, Aug 21 '24

@patrick-kidger sorry for the long delay! I think the PR is ready for review now. All tests pass except one of the tqdm progress bar tests involving `jit`; I'm not at all sure what is going on there.

Additionally, I wanted to draw your attention to what I wrote at line 773:

```python
def _save_ts_impl(ts, fn, _save_state):
    def _cond_fun(__save_state):
        # Bounded by the output buffer length rather than len(ts).
        return __save_state.saveat_ts_index < len(_save_state.ts)
```

Here I had to use `_save_state.ts` instead of `ts` in the loop condition, because `saveat_ts_index` can already be 1 when t0 is also being saved (`saveat.t0` is True). If I used `ts`, the last entry would not be updated. This doesn't exactly mirror what happens on lines 421-427, so I just wanted to mention it briefly.
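The off-by-one can be reproduced with a simplified sketch. Assumptions: `fill_remaining` is a hypothetical helper, a single index stands in for the separate `saveat_ts_index`/`save_index` counters, and unwritten slots are marked with `inf`. With `saveat.t0` True the output buffer has `len(ts) + 1` slots, slot 0 already holds t0, and the loop starts at index 1. Stopping at `len(ts)` leaves the final slot unwritten; stopping at the buffer length fills it.

```python
import jax.numpy as jnp
from jax import lax


def fill_remaining(buf, start, value, stop):
    """Hypothetical helper: write `value` into buf[start:stop] with a while_loop."""

    def cond_fun(state):
        idx, _ = state
        return idx < stop

    def body_fun(state):
        idx, b = state
        return idx + 1, b.at[idx].set(value)

    _, out = lax.while_loop(cond_fun, body_fun, (start, buf))
    return out


t0 = 2.0
ts = jnp.array([2.0, 2.0, 2.0])  # t0 == t1, so every requested time equals t0
# One extra leading slot because t0 is saved; inf marks slots not yet written.
buf = jnp.array([t0, jnp.inf, jnp.inf, jnp.inf])
start = 1  # the index is already 1 because t0 was saved

short = fill_remaining(buf, start, t0, stop=len(ts))  # last slot stays inf
full = fill_remaining(buf, start, t0, stop=len(buf))  # every slot written
```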

dkweiss31, Nov 13 '24

Awesome! Can you rebase on top of `dev` (and make that the PR's target branch)? Then I'll do a review :) (I think this should also fix the tqdm/jit test.)

patrick-kidger, Nov 17 '24

Ok, done I think!

dkweiss31, Nov 18 '24

Hi @patrick-kidger! I tried again and, as you suspected, no loop is necessary. I went in a slightly different direction from what you suggested; let me know your thoughts :) I've also added the `vmap` test; let me know if this is what you had in mind or if you meant something different.
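One way to drop the loop entirely, shown as a sketch under simplifying assumptions (array-valued `ys` rather than PyTrees, a hypothetical `fill_tail` helper, and not necessarily the direction the PR took): compare a static `jnp.arange` against the traced start index and select with `jnp.where`. Each write becomes a single elementwise operation, so no dynamic slicing is involved and the result composes with `jit`, `grad`, and `vmap`.

```python
import jax
import jax.numpy as jnp


def fill_tail(ts_buf, ys_buf, start, t_value, y_value):
    """Hypothetical helper: overwrite every slot from `start` onwards, without a loop."""
    idx = jnp.arange(ts_buf.shape[0])  # static length
    mask = idx >= start  # `start` may be a traced integer
    new_ts = jnp.where(mask, t_value, ts_buf)
    # Add trailing singleton axes so the mask broadcasts against ys_buf.
    y_mask = mask.reshape(mask.shape + (1,) * (ys_buf.ndim - 1))
    new_ys = jnp.where(y_mask, y_value, ys_buf)
    return new_ts, new_ys


# Batch over different final states, in the spirit of a vmap test.
ts_buf = jnp.array([2.0, jnp.inf, jnp.inf])  # slot 0 already holds t0
ys_bufs = jnp.zeros((5, 3, 2))  # batch of 5, three save slots, state size 2
y_finals = jnp.arange(10.0).reshape(5, 2)

batched = jax.vmap(fill_tail, in_axes=(None, 0, None, None, 0))
new_ts, new_ys = batched(ts_buf, ys_bufs, 1, 2.0, y_finals)
print(new_ts.shape, new_ys.shape)  # (5, 3) (5, 3, 2)
```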

dkweiss31, Nov 27 '24

Awesome! This LGTM. Thank you for so carefully catching the edge-cases and the edge-cases-of-edge-cases. I'll be doing a new release of Diffrax shortly, which will include this fix :)

patrick-kidger, Dec 06 '24