First DDE version
New files:
- `discontinuity.py`, which does the root finding during integration steps.

Modified files:
- `integrate.py`: changed a bit of the code; it essentially looks the same but with more `if` statements. The discontinuity handling is also done before each integration step. Added 2 new arguments to `_State` (`discontinuities`, `discontinuities_save_index`).
- `constant.py`: does the discontinuity checking and returns the next integration step. But as said in the WIP, the `prevbefore` and `nextafter` are done in `loop()`.
Followed your suggestion regarding dropping y0_history and putting it in y0. However, by doing this we must now pass y0 to the loop function. Haven't done the PyTree handling of delays yet. Only works for the constant stepsize controller; doing adaptive now.

PS: I don't have the same saving format as you, so terms.py shows some deletions and additions for no reason...
Boilerplate code for a DDE:

```python
import diffrax
import jax.numpy as jnp
import matplotlib.pyplot as plt


def vector_field(t, y, args, *, history):
    return 1.8 * y * (1 - history[0])


delays = [lambda t, y, args: 1]
y0_history = lambda t: 1.2
discontinuity = (0.0,)
made_jump = discontinuity is None
t0, t1 = 0.0, 100.0
ts = jnp.linspace(t0, t1, 1000)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Dopri8(),
    t0=ts[0],
    t1=ts[-1],
    dt0=ts[1] - ts[0],
    y0=y0_history,
    max_steps=2**16,
    stepsize_controller=diffrax.ConstantStepSize(),
    saveat=diffrax.SaveAt(ts=ts, dense=True),
    delays=delays,
    discontinuity=discontinuity,
    made_jump=made_jump,
)
plt.plot(sol.ts, sol.ys)
plt.show()
```
Okay, so there's a lot of spurious changes here, mostly due to unneeded formatting changes. Take a look at CONTRIBUTING.md, and in particular the pre-commit hooks. These will autoformat etc. the code. I'll be able to do a proper review then.
Regarding passing y0_history into loop: I'm thinking what we should do is something like:
```python
def diffeqsolve(y0, delays, ...):
    if delays is None:
        y0_history = None
    else:
        y0_history = y0
        y0 = y0_history(t0)
    adjoint.loop(..., y0_history=y0_history)
```
so that internally we still disambiguate between y0 and y0_history.
Regarding the changes to constant step sizing: hmm, this seems strange to me. I don't think we should need to change any stepsize controller at all. I think the stepsize controller changes we need to make (due to discontinuities) should happen entirely within integrate.py, so that they can apply to all stepsize controllers. (e.g. even user-specified stepsize controllers, that don't have any special support)
Hello, I'm back with some updates!

1/ I used the pre-commit hooks but still have one small spurious change in terms.py :astonished:.
2/ y0 and y0_history are disambiguated!
3/ The controllers are untouched, which is better for modular code.
4/ Discontinuity handling is done in the loop file.
5/ One edge case was found; I haven't thought about it too much and did a "sloppy" fix for now (https://github.com/patrick-kidger/diffrax/pull/169/files#r991488410). Essentially it comes up when we integrate a step from tprev to tnext and the integration bound new_tprev = tnext is right next to a discontinuity, forcing us to redo the step with a new interval [tnext; tnext + epsilon = discontinuity_jump]. Combined with the code handling, this puts tprev > tnext and throws an error.
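For reference, one shape such a guard might take (a hedged sketch only; `clamp_to_discontinuity` and its arguments are illustrative, not the actual fix in the PR): if the proposed step end lands within `eps` of a known discontinuity, snap `tnext` onto the discontinuity itself, so the next step starts exactly at the jump and `tprev < tnext` is preserved.

```python
import jax.numpy as jnp


def clamp_to_discontinuity(tnext, disc, eps=1e-9):
    # If the proposed step end is within eps of a known discontinuity,
    # snap it onto the discontinuity instead of stepping just short of it.
    near = jnp.abs(tnext - disc) < eps
    return jnp.where(near, disc, tnext)
```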
Regardless, for most cases the solver seems to work.

Things not done:
1/ y0_history is still not a PyTree of callables; when that is done, I suppose one should use eqx.tree_at to change its structure?
Latest commit does what you suggested:
- Reverted back to not touching `state.tnext` and such.
- Integrated your bisection code into the discontinuity detection.
- Only search for discontinuities if the step is rejected.
- `delays` are PyTrees.
To do:
- Code cleanup, i.e. removing `discontinuity.py`, which is now obsolete (will do next commit when something stable is available).
- Implicit stepping when `state.tnext - state.tprev > min(delays)`. Regarding this, I checked the Julia paper (4.2); I understand the issue there, but I find the explanation of how to handle it rather opaque... (*). In this case, the use of the continuous extension (i.e. `dense_interp` from `dense_infos`) makes the overall method implicit even if the discrete method we are using is explicit (this is apparently called overlapping).
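The overlapping situation can be pictured as a fixed-point iteration: re-take the step with the previous iterate's local interpolant standing in for the unknown history over [tprev, tnext], until the iterates stop changing. A minimal sketch, where the names and the toy contraction `step_fn` are purely illustrative:

```python
def overlapping_fixed_point(step_fn, dense_info, n_iters=10):
    # Re-take the step, feeding the previous iterate's interpolant back in
    # as the history over [tprev, tnext]; for a contraction this converges
    # to the solution of the implicit problem.
    for _ in range(n_iters):
        dense_info = step_fn(dense_info)
    return dense_info


# Toy contraction standing in for "solver step using the current interpolant";
# its fixed point is 2.0.
approx = overlapping_fixed_point(lambda x: 0.5 * x + 1.0, 0.0)
```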
Comments:
- Regarding the first commit where lots of things were changed: I wasn't aware of the two points of view (philosophies) on how to integrate DDEs. The first is to track the discontinuities; the other relies on the error estimation of the solver's method, under which discontinuity tracking is somewhat useless. Julia therefore uses the second idea.
- Since the implicit stepping (cf. *) is not done, some integrations are faulty.
Latest commit has:
- Added a `discont_update` in the discontinuity checking part to make sure we correctly update `state.discontinuity`.
- Added a `discont_check` argument in `diffeqsolve` to check for breaking points at each step.
- Implemented the implicit step in `history_extrapolation_implicit`.
- Removed `discontinuity.py`, so everything is in `integrate.py`.
- Enhanced discontinuity tracking with another feature that also checks for roots in subintervals of `[tprev, tnext]`; this helps to not miss any potential jumps that would otherwise go undetected. Not sure, but I think the `vmap`ing here doesn't destroy the computational gain of the `unvmap`?
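The subinterval idea can be sketched as follows (hypothetical names; the real implementation would then bisect within flagged subintervals): split [tprev, tnext] into n pieces and flag those whose endpoints give the discontinuity function opposite signs, so a root can't straddle a whole step unnoticed.

```python
import jax
import jax.numpy as jnp


def flag_subintervals(f, tprev, tnext, n=8):
    # Evaluate f at n + 1 equally spaced points and mark subintervals whose
    # endpoints have opposite signs: candidates containing a root/jump.
    ts = jnp.linspace(tprev, tnext, n + 1)
    vals = jax.vmap(f)(ts)
    return jnp.sign(vals[:-1]) != jnp.sign(vals[1:])


# A root at t = 0.6 is caught in the third of four subintervals.
flags = flag_subintervals(lambda t: t - 0.6, 0.0, 1.0, n=4)
```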
Latest commit updates some remarks of code comment after latest review https://github.com/patrick-kidger/diffrax/pull/169#pullrequestreview-1159470301.
I've bundled together as you said the delays term together for an easier API.
```python
class _Delays(eqx.Module):
    delays: Optional[PyTree[Callable]]
    initial_discontinuities: Union[Array, Tuple]
    max_discontinuities: Int
    recurrent_checking: Bool
    rtol: float
    atol: float
    eps: float
```
delays is the regular delayed term definition; initial_discontinuities corresponds to the discontinuities that the user needs to give to get proper DDE integration (I refer you to my response https://github.com/patrick-kidger/diffrax/pull/169#discussion_r1007989927); discont_checking is now recurrent_checking; the atol, rtol are for the Newton solver in the implicit step; and eps is the error tolerance of the step.

What you suggested is great for later iterations, because we can just slap any new arguments into _Delays!
What needs to be done:
- https://github.com/patrick-kidger/diffrax/pull/169#discussion_r1007651902 Didn't transfer the code snippet because I wasn't sure which one it was.
- https://github.com/patrick-kidger/diffrax/pull/169#discussion_r1007661929 Finding a way to not redo the `solver.step` when we do implicit steps.
- Take a look into the coding of `dense_infos` and `dense_ts` and in-place updates.
Okay, let's focus on this bit before we move on to discussing the discontinuity handling below. I'll leave you to make the changes already discussed here. Let me know if any of them aren't clear.
Sounds good. I'll take care of the first bullet point later, and the second one should be OK on my side. However, I'd like to have your take on the third one with the in-place operations (for dense_ts, dense_infos), since this is some very sharp JAX bit :hocho:.
Sure thing. I'm suggesting that _HistoryVectorField should look something like this:
```python
class _HistoryVectorField(eqx.Module):
    ...
    tprev: float
    tnext: float
    dense_info: PyTree[Array]
    interpolation_cls: Type[AbstractLocalInterpolation]

    def __call__(self, t, y, args):
        ...
        if self.dense_interp is None:
            ...
        else:
            for delay in delays:
                delay_val = delay(t, y, args)
                alpha_val = t - delay_val
                is_before_t0 = alpha_val < self.t0
                is_before_tprev = alpha_val < self.tprev
                at_most_t0 = jnp.minimum(alpha_val, self.t0)
                t0_to_tprev = jnp.clip(alpha_val, self.t0, self.tprev)
                at_least_tprev = jnp.maximum(alpha_val, self.tprev)
                step_interpolation = self.interpolation_cls(
                    t0=self.tprev, t1=self.tnext, **self.dense_info
                )
                switch = jnp.where(is_before_t0, 0, jnp.where(is_before_tprev, 1, 2))
                history_val = lax.switch(
                    switch,
                    [
                        lambda: self.y0_history(at_most_t0),
                        lambda: self.dense_interp(t0_to_tprev),
                        lambda: step_interpolation.evaluate(at_least_tprev),
                    ],
                )
                history_vals.append(history_val)
        ...
        return ...
```
And then when it is called inside the implicit routine:
```python
def body_fun(val):
    dense_info, ... = val
    ...
    _HistoryVectorField(..., state.tprev, state.tnext, dense_info, solver.interpolation_cls)
    ...
    return new_dense_info, ...
```
The latest commit should handle all of the issues mentioned above:
- Avoid an extra step after the implicit step, by integrating the explicit step de facto into the `lax.for_loop` of `history_extrapolation_implicit` with the conditional you proposed (https://github.com/patrick-kidger/diffrax/pull/169#discussion_r1012032781).
- In-place updates handled with `_HistoryVectorField`.
Discussion/bottleneck for the implicit step

Regarding the implicit step, we have an issue when it comes to large steps, because a step_interpolation with only 2 points won't suffice. This hinges on whether the snippet below is indeed a 2-point interpolation:

```python
step_interpolation = self.interpolation_cls(t0=self.tprev, t1=self.tnext, **self.dense_info)
```

To elaborate a bit more: suppose we take an implicit step from state.tprev to state.tnext. Our associated history function for the equation y'(t) = f(t, y(t - tau)) will be known from state.tprev up to state.tprev + tau, and from state.tprev + tau to state.tnext we need its extrapolation. If the history function in the interval [state.tprev + tau, state.tnext] is non-monotonic (we could imagine half a period of a sine, for example), a 2-point extrapolation won't capture the function correctly; we would rather need, say, 10 points to get a good estimate. In this regard we would also need to change the condition of your implicit step from

```python
_pred = (((y - y_prev) / y) > delays.eps).any()
```

to something that checks the MSE of the extrapolated history function before and after the integration step. With this in mind, I'm not sure the _HistoryVectorField from https://github.com/patrick-kidger/diffrax/pull/169#issuecomment-1302691266 as-is will do the trick.

This also impacts the population of the ys in dense_ts, since the values are interpolated with the computed steps of y. (If we take an implicit step from state.tprev to state.tnext and have to save some points in between, the current procedure is to use a 2-point (yprev = state.y and ynext = y) interpolation, right?) With that being said, this discussion is only relevant if the time mesh we have (i.e. dense_ts) is precise enough.
> Our associated history function for the equation y'(t) = f(t, y(t-tau)) will be known from state.tprev up to state.tprev + tau

I don't think this is true. Anything after tprev hasn't been evaluated yet. The whole [tprev, tnext] region is initialised as an extrapolation from the previous step.
Regarding 2-point interpolations: this isn't the case for most solvers. Each solver evaluates various intermediate quantities during its step (e.g. the stages of an RK solver) and these also feed into the interpolation.
Even if it were, though, I don't think it matters: we just need to converge to a solution of the implicit problem.
OK, this makes sense, so I'll take back what I said in my "Discussion/bottleneck for the implicit step"; thanks for the clarification! The _pred condition should be on the dense_info then!

Edit: changed the _pred condition, and it seems to be working fine!

For this part of the code, I'd say it's ready for a review before moving on to discontinuity checking! (Unconstrained time stepping works well.)
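A hedged sketch of what such a `_pred` over `dense_info` might look like (illustrative only, not the PR's exact code): reduce the leaf-wise relative change between successive iterates to a single boolean.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def dense_info_pred(dense_info, dense_info_prev, eps=1e-8):
    # Maximum relative change across all leaves of the dense_info PyTree;
    # keep iterating while it exceeds eps.
    rel = jtu.tree_map(
        lambda a, b: jnp.max(jnp.abs(a - b) / (jnp.abs(b) + 1e-12)),
        dense_info,
        dense_info_prev,
    )
    return jnp.stack(jtu.tree_leaves(rel)).max() > eps
```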
Relevant changes are made:
- All the code is now managed with PyTree operations (I think).
- Moved `Delays` and the implicit step into a new file (but there is a circular import there... not sure if this is an issue).
- Removed the `eps` feature of the implicit stepping.
In order to get something backprop-compatible in history_extrapolation_implicit, I'll need to use your bounded_while_loop instead of lax.while_loop in the meantime, until your PR is merged?
Great news!

I'm not familiar with the wiring of implicit_jvp, so I'm up for some pointers.

Other than that, the relevant changes were made.
Alright, on to the next block of code!
As for implicit_jvp -- I actually have some in-progress work that may simplify this. If you'd like to be able to backpropagate through this code soon-ish then I can expand on what I mean here? (But if it's not a rush then I'm happy to put a pin in that for now.)
- Relevant changes were made for the last code block; everything is in `delays.py` now!
- Added 2 new attributes for `Delays`: `nb_sub_intervals` and `max_steps`.
- Lots of wrapping and deletion makes the code more readable!

As for implicit_jvp, backpropagating through this part would be great!

Looking into https://github.com/patrick-kidger/diffrax/pull/169#discussion_r1025901003 a bit deeper now.
Thanks for the review Patrick :). From what I understood, what I need to do is to create a new class DDEImplicitNonLinearSolve (for example) that inherits from AbstractNonlinearSolver in order to use implicit_jvp. Basically, I'll need to rewrite the _solve method and then, in history_extrapolation_implicit, call an instance of DDEImplicitNonLinearSolve.
```python
class DDEImplicitNonLinearSolve(AbstractNonlinearSolver):
    def _solve(
        self,
        fn: Callable,  # would be terms
        x: PyTree,  # here would be state.y
        nondiff_args: PyTree,  # all the other args from the current history_extrapolation_implicit(...)
        diff_args: PyTree,
    ):
        ...


def history_extrapolation_implicit(...):
    nonlinearsolver = DDEImplicitNonLinearSolve(...)
    results = nonlinearsolver(terms, y, args).root
    y, y_error, dense_info, solver_state, solver_result = results
    return y, y_error, dense_info, solver_state, solver_result
```
If that's the case, could you explain how you usually work with your nondiff_args and diff_args? I think it's more native to JAX's jvp/vjp, and they are later used with implicit_jvp and _rewrite, etc.
Right, something like what you're saying looks correct! (Although I'd suggest calling it e.g. FixedPointSolver instead.)
You should just use args = eqx.combine(diff_args, nondiff_args). The API here is a bit inelegant, and as above I have a planned rewrite for this, in a few months.
Gotcha. I might be running into an issue with regard to implicit_jvp. I think I need to rewrite the _rewrite(root, _, diff_args, closure) function, since it returns fn(root, args) and fn is a Callable.

As of now, the implicit/explicit stepping looks like this:
```python
def history_extrapolation_implicit(...):
    nonlinearsolver = FixedPointSolver()
    nonlinear_args = (
        implicit_step,
        dense_interp,
        solver,
        delays,
        t0,
        y0_history,
        state,
    )
    results = nonlinearsolver(terms, state.y, nonlinear_args).root
    y, y_error, dense_info, solver_state, solver_result = results
    return y, y_error, dense_info, solver_state, solver_result
```
This calls the solver:
```python
class FixedPointSolver(AbstractNonlinearSolver):
    max_steps: Optional[Int] = 10
    norm: Callable = rms_norm

    def __post_init__(self):
        if self.max_steps is not None and self.max_steps < 2:
            raise ValueError("max_steps must be at least 2.")

    def _solve(
        self,
        fn: AbstractTerm,
        x: PyTree,
        jac: Optional[LU_Jacobian],
        nondiff_args: PyTree,
        diff_args: PyTree,
    ) -> Tuple[PyTree, RESULTS]:
        args = eqx.combine(nondiff_args, diff_args)
        (
            implicit_step,
            dense_interp,
            solver,
            delays,
            t0,
            y0_history,
            state,
        ) = args

        def cond_fn(val):
            _, _, _, _, _, pred, step = val
            return (implicit_step & pred) | (
                jnp.invert(implicit_step) & (step == 0)
            )

        def body_fn(val):
            y_prev, _, dense_info, _, _, _, step = val
            terms_ = bind_history(...)
            (y, y_error, dense_info, solver_state, solver_result) = solver.step(...)
            ...
            return (y, y_error, dense_info, solver_state, solver_result, _pred, step + 1)

        _init_val = (...)
        (y, y_error, ...) = lax.while_loop(cond_fn, body_fn, _init_val)
        y_error = jtu.tree_map(
            lambda _y_error: jnp.where(final_step < self.max_steps, _y_error, jnp.inf),
            y_error,
        )
        root = (y, y_error, dense_info, solver_state, solver_result)
        result = jnp.where(
            final_step < self.max_steps,
            RESULTS.successful,
            RESULTS.implicit_nonconvergence,
        )
        return NonlinearSolution(root=root, num_steps=final_step, result=result)
```
On the one hand, if I use fn=terms, then the _rewrite function will be incompatible. On the other hand, if I only give its vector_field instead, then I lose all the structure associated with terms_ (I think reconstructing the VectorFieldWrapper term isn't the best computationally). Let me know what you think / if I'm not clear or what I wrote doesn't make sense!
You shouldn't have FixedPointSolver be an AbstractNonlinearSolver. (At least naively.) This uses implicit_jvp to backprop through the implicit problem f(x, args) = 0. But a fixed-point iteration solves the implicit problem f(x, args) = x. You either need to convert from one problem to the other, or to just use implicit_jvp directly, with an appropriate fn_rewrite.
In terms of what to pass as fn: you should be able to pass terms as part of args, and then have fn be some (global) function that unpacks args, and calls terms.
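The conversion mentioned above is just a change of residual: a fixed point of f(x, args) = x is a root of g(x, args) = f(x, args) - x. A tiny sketch:

```python
import jax.numpy as jnp


def to_root_problem(f):
    # Rewrite the fixed-point problem f(x, args) = x as g(x, args) = 0,
    # the form a root-finding nonlinear solver expects.
    def g(x, args):
        return f(x, args) - x

    return g


f = lambda x, args: 0.5 * x + 1.0  # fixed point at x = 2
g = to_root_problem(f)
```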
OK, so this is what I came up with, by transforming the fixed-point problem into a root-finding one.
```python
def history_extrapolation_implicit(
    terms, delays, dense_info, dense_interp,
    solver, direction, t0, tprev, tnext, y0_history,
):
    nonlinearsolver = FixedPointSolver()
    nonlinear_args = (
        terms, implicit_step, dense_interp, solver, delays, t0, y0_history, state, args
    )

    def fn(dense_info, args):
        (terms, _, dense_interp, solver, delays, t0, y0_history, state, vf_args) = args
        terms_ = bind_history(...)
        return dense_info[-1] - terms_.vf(state.tnext, dense_info[-1], vf_args)

    init_guess = jtu.tree_map(
        lambda x: x[state.dense_save_index - 1], state.dense_infos
    )
    results = nonlinearsolver(fn, init_guess, nonlinear_args).root
    y, y_error, dense_info, solver_state, solver_result = results
    return y, y_error, dense_info, solver_state, solver_result
```
When I run a quick backprop example, I see that the dense_info needed for the solve is traced in the backprop (with the primals and tangents), and therefore can't satisfy _HistoryVectorField's:

```python
step_interpolation = self.interpolation_cls(
    t0=self.tprev, t1=self.tnext, **self.dense_info
)
```

Not sure what to think of it.

Moreover, there's the fact that dense_info in fn is used for history binding, and then we actually return the "real" function fn used for implicit_jvp, which only takes dense_info["y1"].

With that being said, I don't know if I can currently fit implicit_jvp with the implicit stepping, or whether I should try to follow the custom_jvp procedure from http://implicit-layers-tutorial.org/implicit_functions/?
I'm afraid I'm not sure what you mean about dense_info not satisfying _HistoryVectorField? Are you getting an error message / can you elaborate?
In any case note that we would really like to compute a fixed point wrt just the dense info at the current step, not the dense info across the entire length of the solve so far. You need to arrange for the first argument to fn to be just the dense info for the current time step.
Likewise, I'm not sure I understand your second question?
With the current setup:

```python
def history_extrapolation_implicit(
    terms, delays, dense_info, dense_interp,
    solver, direction, t0, tprev, tnext, y0_history,
):
    nonlinearsolver = FixedPointSolver()
    nonlinear_args = (
        terms, implicit_step, dense_interp, solver, delays, t0, y0_history, state, args
    )

    def fn(dense_info, args):
        (terms, _, dense_interp, solver, delays, t0, y0_history, state, vf_args) = args
        terms_ = bind_history(...)
        return dense_info[-1] - terms_.vf(state.tnext, dense_info[-1], vf_args)

    init_guess = jtu.tree_map(
        lambda x: x[state.dense_save_index - 1], state.dense_infos
    )
    results = nonlinearsolver(fn, init_guess, nonlinear_args).root
    y, y_error, dense_info, solver_state, solver_result = results
    return y, y_error, dense_info, solver_state, solver_result
```

and with implicit_jvp plugged into the FixedPointSolver (via inheritance from the abstract solver class) from above, a quick dummy backprop gives an error, so I probably didn't do the right thing. The function fn gets dense_info as its first argument, and everything else needed for the computation in args.
The error yielded is:

```
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/monsel/Desktop/Project/diffrax/backprop_test.py", line 80, in <module>
    value, model, opt_state = make_step(ts, yi, model, opt_state, optim)
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/jit.py", line 96, in __call__
    return __self._fun_wrapper(False, args, kwargs)
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/jit.py", line 92, in _fun_wrapper
    dynamic_out, static_out = self._cached(dynamic, static)
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/jit.py", line 39, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/home/monsel/Desktop/Project/diffrax/backprop_test.py", line 39, in make_step
    loss, grads = grad_loss(model, ti, yi)
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/grad.py", line 31, in __call__
    return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/grad.py", line 28, in fun_value_and_grad
    return __self._fun(_x, *_args, **_kwargs)
  File "/home/monsel/Desktop/Project/diffrax/backprop_test.py", line 33, in grad_loss
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
  File "/home/monsel/Desktop/Project/diffrax/models.py", line 103, in __call__
    solution = diffrax.diffeqsolve(
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/jit.py", line 96, in __call__
    return __self._fun_wrapper(False, args, kwargs)
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/jit.py", line 92, in _fun_wrapper
    dynamic_out, static_out = self._cached(dynamic, static)
  File "/home/monsel/miniconda3/envs/dde/lib/python3.9/site-packages/equinox/jit.py", line 39, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/integrate.py", line 1021, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
  File "/home/monsel/Desktop/Project/diffrax/diffrax/adjoint.py", line 78, in loop
    return self._loop_fn(**kwargs, is_bounded=True)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/integrate.py", line 593, in loop
    final_state = bounded_while_loop(
  File "/home/monsel/Desktop/Project/diffrax/diffrax/misc/bounded_while_loop.py", line 136, in bounded_while_loop
    _, val, _ = _while_loop(_cond_fun, body_fun, init_data, rounded_max_steps, base)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/misc/bounded_while_loop.py", line 240, in _while_loop
    return lax.scan(_scan_fn, data, xs=None, length=base)[0]
  File "/home/monsel/Desktop/Project/diffrax/diffrax/misc/ad.py", line 103, in fn_jvp_wrapper
    return fn_jvp(*nondiff_args, diff_args, tang_diff_args)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/misc/ad.py", line 170, in _implicit_backprop_jvp
    jac_flat_root = jax.jacfwd(_for_jac)(flat_root)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/misc/ad.py", line 166, in _for_jac
    _out = fn_rewrite(_root, residual, args, closure)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/nonlinear_solver/base.py", line 34, in _rewrite
    return fn(root, args)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/delays.py", line 217, in fn
    return dense_info[-1] - terms_.vf(state.tnext, dense_info[-1], vf_args)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/term.py", line 378, in vf
    return self.term.vf(t, y, args)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/term.py", line 183, in vf
    return self.vector_field(t, y, args)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/term.py", line 159, in __call__
    return self.vector_field(t, y, args)
  File "/home/monsel/Desktop/Project/diffrax/diffrax/delays.py", line 106, in __call__
    step_interpolation = self.interpolation_cls(
TypeError: diffrax.solver.tsit5._Tsit5Interpolation() argument after ** must be a mapping, not tuple
```
During a DDE integration, dense_info (forward pass) has the following form, which works for interpolation:
```python
{'k': array([[-0.00445976, 0.01871458],
             [-0.00488277, 0.01874425],
             [-0.00509373, 0.01875141],
             [-0.00531244, 0.01887182],
             [-0.0053381 , 0.01889448],
             [-0.0053723 , 0.01892434],
             [-0.0053742 , 0.01892584]], dtype=float32),
 'y0': array([1.184424 , 1.0490271], dtype=float32),
 'y1': array([1.179288 , 1.0678301], dtype=float32)}
```
During the backward pass, if I'm not mistaken, in the current state of things dense_info gets transformed (for tracing?) into a tuple of the sort

```python
('y0': jnp.array(...), 'y1': jnp.array(...), 'k': jnp.array(...))
```

that won't fit with the interpolation.

Since I get this error message, I was probably either misusing implicit_jvp, or we need to define another custom_jvp.
Here is a small repro with your regular code:
```python
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import optax


class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, activation, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=2 * data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=activation,
            key=key,
        )

    def __call__(self, t, y, args, *, history):
        return self.mlp(jnp.hstack([y, history[0]]))


class NeuralDDE(eqx.Module):
    func: Func
    delays: Delays

    def __init__(self, data_size, width_size, depth, activation, delays, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, activation, key=key)
        self.delays = delays

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=lambda t: y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts, dense=True),
            delays=self.delays,
            made_jump=True,
        )
        return solution.ys


@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)


@eqx.filter_jit
def make_step(ti, yi, model, opt_state, optim):
    loss, grads = grad_loss(model, ti, yi)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state


ts = jnp.linspace(0.0, 10, 100)
delays = Delays(
    delays=[lambda t, y, args: 1],
    initial_discontinuities=jnp.array([0.0]),
    max_discontinuities=2,
    recurrent_checking=False,
    rtol=10e-3,
    atol=10e-6,
)
width, depth, activation = 32, 3, jnn.relu
data_keys = jrandom.split(data_key, dataset_size)
ys = jax.random.normal(key, (10, 100, 1))
_, length_size, data_size = ys.shape
model = NeuralDDE(data_size, width, depth, activation, delays, key=model_key)
optim = optax.adabelief(10e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
for (yi,) in dataloader((ys,), batch_size, key=key):
    value, model, opt_state = make_step(ts, yi, model, opt_state, optim)
```
I added some auxiliary stats for implicit and explicit step counting (https://github.com/patrick-kidger/diffrax/pull/169#pullrequestreview-1194907612). I also did some documentation and testing. The documentation will most likely change later on. Let me know if I need to change the location of the julia files.
Hi @patrick-kidger, I'm having another go at the DDEs, with all of the changes to Diffrax and Equinox.

I have rebased the code, and with it there is a new inner_while_loop/outer_while_loop in loop!

DenseInterpolation (in integrate.py l.238) is inside inner_while_loop/outer_while_loop, all the arguments of DenseInterpolation get transformed beforehand with Equinox's _Buffer object/module, and the DenseInterpolation.__post_init__ check fails. I was wondering if there's any way around it?

Thanks!
Ah, I'm guessing it's the ts=state.dense_ts and infos=state.dense_infos that are problematic. Try something like:

```python
ts = state.dense_ts[...]
infos = jtu.tree_map(lambda _, x: x[...], dense_info, state.dense_infos)
```

to unwrap them from buffers.
Note that unwrapping like this is an unsafe thing to do in general! You have to make the promise to only read from locations that have already been written to with an in-place update -- otherwise you'll get incorrect gradients. (I would definitely add some tests that finite differences match the results from jax.grad, just to be sure. FYI we have a finite_difference_jvp here that you could copy-paste.)
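A minimal version of that sanity check (a generic central-difference helper written for this comment, not diffrax's finite_difference_jvp):

```python
import jax
import jax.numpy as jnp


def finite_diff_grad(f, x, eps=1e-3):
    # Central finite differences for a scalar function of a 1D array,
    # used to cross-check the gradients returned by jax.grad.
    basis = jnp.eye(x.shape[0])
    return jnp.stack(
        [(f(x + eps * e) - f(x - eps * e)) / (2 * eps) for e in basis]
    )


f = lambda x: jnp.sum(x**2)
x = jnp.array([1.0, -2.0, 3.0])
ad = jax.grad(f)(x)       # autodiff gradient
fd = finite_diff_grad(f, x)  # finite-difference estimate
```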
Adding [...] did the trick for ts:

```python
ts = state.dense_ts[...]
```

Unfortunately, unwrapping the Buffer with other structures like DenseInfos didn't seem to work:

```python
infos = jtu.tree_map(lambda x: x[...], state.dense_infos)
```

The problem came from the _pred argument of the _Buffer.

However, I found another way:

```python
unwrapped_buffer = jtu.tree_leaves(
    eqx.filter(state.dense_infos, eqx.is_inexact_array),
    is_leaf=eqx.is_inexact_array,
)
unwrapped_dense_infos = dict(zip(state.dense_infos.keys(), unwrapped_buffer))
```
Notice how I pass in an additional argument in my previous snippet: jtu.tree_map(lambda _, x: x[...], dense_info, state.dense_infos).
This is because each _Buffer object is actually a PyTree, that's sort of masquerading as an array. So to handle this we actually iterate over a different pytree, that happens to have the correct structure.
The use of buffers is a pretty advanced/annoying detail. I'm pondering using something like Quax to create a safer API for this, but that's a long way down the to-do list.
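The trick generalises: jtu.tree_map flattens its *first* tree to decide what counts as a leaf, and only flattens the remaining trees up to that structure. A toy illustration (the dicts and tuples are stand-ins; pretend each tuple is a buffer-like pytree):

```python
import jax.tree_util as jtu

template = {"y0": 0, "y1": 0}           # structure we want to iterate over
wrapped = {"y0": (1, 2), "y1": (3, 4)}  # composite "leaves" we must not descend into

# Each tuple is passed to the function whole, because the template's
# leaves stop the recursion before descending into the tuples.
out = jtu.tree_map(lambda _, x: x, template, wrapped)
```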
Indeed you're right, I just realised that from the documentation yesterday; this makes sense now!
However, we only have access to dense_info after creating DenseInterpolation, i.e.:

```python
dense_interp = DenseInterpolation(
    ts=state.dense_ts[...],
    infos=jtu.tree_map(lambda _, x: x[...], dense_info, state.dense_infos),
    ...
)
(
    y,
    y_error,
    dense_info,
    solver_state,
    solver_result,
) = history_extrapolation_implicit(
    ...
)
```
So I would agree that this works if dense_info were available, which seems not to be the case; hence I'm not sure how to do it. Moreover, I did some preliminary testing, and doing this https://github.com/patrick-kidger/diffrax/pull/169#issuecomment-1680494510 yields wrong gradients. Nonetheless, the code makes the promise to only read from locations that have already been written to with an in-place update, so my unwrapping method is/seems erroneous.
Since I'm starting to see activity here again -- you can track my progress updating Diffrax in https://github.com/patrick-kidger/diffrax/issues/217 and https://github.com/patrick-kidger/diffrax/tree/big-refactor.
(This currently depends on the unreleased versions of Equinox and jaxtyping.)
Mostly there now!