Sri Hari Krishna Narayanan
I just want to add that JAX [supports](https://github.com/google/jax/blob/main/jax/experimental/ode.py) this. Indeed, their implementation is general:

```
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3, 4))
def _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args):
  ...
```
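To illustrate why marking `func` as non-differentiable keeps the rule general, here is a minimal sketch. It is not JAX's `odeint`. The linked implementation uses adaptive step-size control (hence `rtol`, `atol`, `mxstep`, `hmax`), while this toy uses fixed-step explicit Euler and a discrete adjoint in the backward pass. `euler_solve`, `h`, and `n_steps` are hypothetical names. As in `_odeint`, the callable and solver settings go in `nondiff_argnums`, so the custom VJP works for any user-supplied `func`. JAX passes those static arguments to the backward rule first, followed by the residuals and the cotangent.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _euler_step(func, h, t, y, theta):
    """One explicit Euler step of dy/dt = func(y, t, theta)."""
    return y + h * func(y, t, theta)


# Hypothetical toy solver: func, h and n_steps are static (non-differentiable).
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def euler_solve(func, h, n_steps, y0, theta):
    """Integrate from t=0 with n_steps fixed Euler steps of size h; return y(t_end)."""
    y = y0
    for k in range(n_steps):
        y = _euler_step(func, h, k * h, y, theta)
    return y


def euler_solve_fwd(func, h, n_steps, y0, theta):
    # Same signature as euler_solve; additionally store the trajectory as residuals.
    ys = [y0]
    for k in range(n_steps):
        ys.append(_euler_step(func, h, k * h, ys[-1], theta))
    return ys[-1], (ys, theta)


def euler_solve_bwd(func, h, n_steps, res, g):
    # Static args come first, then residuals, then the output cotangent g.
    ys, theta = res
    y_bar = g
    theta_bar = jax.tree_util.tree_map(jnp.zeros_like, theta)
    # Discrete adjoint: sweep the stored steps backwards, chaining step VJPs.
    for k in reversed(range(n_steps)):
        _, step_vjp = jax.vjp(
            lambda y, th: _euler_step(func, h, k * h, y, th), ys[k], theta
        )
        y_bar, dtheta = step_vjp(y_bar)
        theta_bar = jax.tree_util.tree_map(jnp.add, theta_bar, dtheta)
    # One cotangent per differentiable argument: (y0, theta).
    return y_bar, theta_bar


euler_solve.defvjp(euler_solve_fwd, euler_solve_bwd)


if __name__ == "__main__":
    def decay(y, t, theta):
        return -theta * y

    grads = jax.grad(euler_solve, argnums=(3, 4))(
        decay, 0.1, 10, jnp.array(2.0), jnp.array(0.5)
    )
    print(grads)
```

For the linear decay `dy/dt = -theta * y`, the printed gradients should agree with the closed forms `(1 - h*theta)**n` for `y0` and `-y0 * n * h * (1 - h*theta)**(n - 1)` for `theta`. Swapping in a different `decay` requires no change to the VJP rule.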
I think this is resolved by #219.
Do you have an MWE?
From a pedagogical standpoint, I am uneasy with this example, because `function forward_func(state, fld_old, fld_now, dt, M)` is differentiated using: `autodiff( forward_func, Duplicated(state_out, dout_old), Duplicated([Tbar; Sbar], din_now), Duplicated([Tbar; Sbar], din_old),...
Changing the names would certainly be helpful. Thanks @swilliamson7. I agree that having two separate implementations for two different purposes, based on the same underlying physical model, makes sense.
Hi. Thanks to @onurdanaci for asking the question and to @lockwo for pointing this question out to me. Apologies to @patrick-kidger for not posting the issue earlier (I am the...