diffrax
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
When trying to run a basic example:
```
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp
from jax.lax import cond

Array.set_default_backend('jax')

def f(t, y, args):
    return -y

term...
```
At the moment all of our SDE solvers are Levy-area-free. It should be relatively straightforward to add support for different kinds of Levy area, by extending the `evaluate` interface.
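To make "Levy area" concrete: for a 2-D Brownian motion, the Levy area over an interval is A_12 = 1/2 ∫ (W1 dW2 - W2 dW1), with W measured relative to its value at the start of the interval. The sketch below is a plain-Python illustration, not diffrax code: the class `SampledBrownianPath` and its `evaluate(t0, t1, return_levy_area=...)` signature are hypothetical, showing one way an `evaluate`-style method could optionally return the Levy area alongside the increment, for a path sampled on a fine grid.

```python
import random


def levy_area(w1, w2):
    """Levy area A_12 = 1/2 * integral of (W1 dW2 - W2 dW1) over the sampled interval.

    Computed for the piecewise-linear path through the samples, relative to the
    starting point; the left-point sum below is exact for such paths.
    """
    area = 0.0
    for k in range(len(w1) - 1):
        x, y = w1[k] - w1[0], w2[k] - w2[0]
        dx, dy = w1[k + 1] - w1[k], w2[k + 1] - w2[k]
        area += x * dy - y * dx
    return 0.5 * area


class SampledBrownianPath:
    """Hypothetical 2-D Brownian path sampled on a uniform grid of num_steps steps."""

    def __init__(self, t0, t1, num_steps, seed=0):
        rng = random.Random(seed)
        self.t0, self.dt = t0, (t1 - t0) / num_steps
        sd = self.dt ** 0.5  # each increment is N(0, dt)
        self.w1, self.w2 = [0.0], [0.0]
        for _ in range(num_steps):
            self.w1.append(self.w1[-1] + rng.gauss(0.0, sd))
            self.w2.append(self.w2[-1] + rng.gauss(0.0, sd))

    def _index(self, t):
        # Times are assumed to lie on the sampling grid.
        return round((t - self.t0) / self.dt)

    def evaluate(self, t0, t1, return_levy_area=False):
        """Return W(t1) - W(t0), optionally together with its Levy area."""
        i0, i1 = self._index(t0), self._index(t1)
        increment = (self.w1[i1] - self.w1[i0], self.w2[i1] - self.w2[i0])
        if not return_levy_area:
            return increment
        area = levy_area(self.w1[i0:i1 + 1], self.w2[i0:i1 + 1])
        return increment, area


path = SampledBrownianPath(0.0, 1.0, num_steps=1000, seed=0)
dw, area = path.evaluate(0.25, 0.75, return_levy_area=True)
```

A useful sanity check is Chen's relation: the area over [s, u] equals the areas over [s, t] and [t, u] plus 1/2 (ΔW1(s,t) ΔW2(t,u) - ΔW2(s,t) ΔW1(t,u)).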
SDE:
- [x] Milstein
- [x] Euler-Heun
- [x] SRKs (Additive)
- [ ] Talay (Ito commutative)

Implicit:
- [ ] BDF
- [ ] Rosenbrock methods
- [x] IMEX...
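As an illustration of the first item, a minimal plain-Python sketch of the Milstein method for a scalar Ito SDE dy = a(y) dt + b(y) dW. This is the textbook scheme, not diffrax's implementation; the helper names and the geometric Brownian motion test problem (parameters `mu`, `sigma`) are assumptions for the example. For scalar noise the Milstein correction raises the strong order from 0.5 (Euler-Maruyama) to 1; for general multidimensional noise it would need Levy area terms.

```python
import math
import random


def milstein_step(a, b, db, y, h, dw):
    """One Milstein step for the scalar Ito SDE dy = a(y) dt + b(y) dW.

    db is the derivative b'(y); the correction 0.5 * b * b' * (dW^2 - h)
    is what distinguishes Milstein from Euler-Maruyama.
    """
    return y + a(y) * h + b(y) * dw + 0.5 * b(y) * db(y) * (dw * dw - h)


def gbm_strong_error(num_steps, mu=0.5, sigma=0.2, y0=1.0, t1=1.0, seed=0):
    """Error at t1 against the exact geometric Brownian motion solution, on one path."""
    rng = random.Random(seed)
    h = t1 / num_steps
    y, w = y0, 0.0
    for _ in range(num_steps):
        dw = rng.gauss(0.0, math.sqrt(h))
        y = milstein_step(lambda u: mu * u, lambda u: sigma * u, lambda u: sigma, y, h, dw)
        w += dw
    # Exact solution of dy = mu y dt + sigma y dW driven by the same Brownian path.
    exact = y0 * math.exp((mu - 0.5 * sigma ** 2) * t1 + sigma * w)
    return abs(y - exact)
```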
These do some pretty deep magic to get good performance, and [some of it is probably quite fragile](https://github.com/patrick-kidger/diffrax/blob/15c2ab9145ab0eda69af22d98d27f8ed02c90977/diffrax/integrate.py#L248). (Especially given the recent JAX work on applying DCE to scans?) So...
Very cool work! It would be great to have an example of how to solve an ODE that is _directly_ forced by some signal `x(t)`, e.g. a forced mass-spring-damper `m...
When jit-compiling on the GPU: errors are currently only printed to stdout/stderr rather than properly raising. See https://github.com/google/jax/issues/9457, including the bare-bones of a possible workaround. In practice it'd be better...
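A generic pattern for this situation, sketched in plain Python: the compiled computation reports failure through a status code in its output (in the same spirit as the `result` attribute on a diffrax `Solution`), and the exception is raised only afterwards, on the host. This is not the workaround from the linked JAX issue; the functions and status constants below are hypothetical.

```python
SUCCESSFUL = 0
MAX_STEPS_REACHED = 1


def solve_with_status(f, t0, t1, y0, dt, max_steps):
    """Fixed-step Euler loop that reports failure via a status code instead of raising."""
    t, y = t0, y0
    for _ in range(max_steps):
        if t >= t1:
            break
        h = min(dt, t1 - t)  # do not step past t1
        y = y + h * f(t, y)
        t = t + h
    status = SUCCESSFUL if t >= t1 else MAX_STEPS_REACHED
    return y, status


def solve_or_raise(f, t0, t1, y0, dt, max_steps):
    """Host-side wrapper: the only place an exception is raised."""
    y, status = solve_with_status(f, t0, t1, y0, dt, max_steps)
    if status != SUCCESSFUL:
        raise RuntimeError(f"solver failed with status {status}")
    return y
```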
When integrating large systems, it might take a while before the solution is returned. It'd be nice if there was a `verbose=False` argument that would either generate a progress bar...
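A minimal plain-Python sketch of what such an option could look like, using a fixed-step Euler loop: a `verbose` flag prints a percentage, and an optional `progress_callback` receives the fraction completed. The function `euler_solve` and its `verbose`, `progress_callback` and `report_every` parameters are hypothetical. Inside a jitted JAX loop, the printing would instead have to go through a host callback such as `jax.debug.callback`.

```python
def euler_solve(f, t0, t1, y0, num_steps, verbose=False, progress_callback=None, report_every=10):
    """Fixed-step Euler solve with optional progress reporting every `report_every` steps."""
    h = (t1 - t0) / num_steps
    t, y = t0, y0
    for i in range(1, num_steps + 1):
        y = y + h * f(t, y)
        t = t0 + i * h
        if i % report_every == 0 or i == num_steps:
            fraction = i / num_steps
            if progress_callback is not None:
                progress_callback(fraction)
            if verbose:
                # Overwrite the same terminal line to act as a simple progress bar.
                print(f"\rprogress: {100 * fraction:5.1f}%", end="", flush=True)
    if verbose:
        print()
    return y


fractions = []
y_end = euler_solve(lambda t, y: -y, 0.0, 1.0, 1.0, 100, progress_callback=fractions.append)
```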
At the moment:
- the ODE order is specified
- an SDE order is specified, in principle for the most general type of noise that solver expects...
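For reference, "order p" means the global error scales like h^p as the step size h shrinks, so halving h should divide the error by about 2^p. The plain-Python sketch below estimates this empirically for Euler (order 1) and Heun (order 2) on dy/dt = -y; the helper names are illustrative. For SDEs the analogous (strong) order depends on the noise: Euler-Maruyama, for example, has strong order 0.5 in general but order 1 for additive noise, which is why a single SDE order has to be tied to a noise type.

```python
import math


def euler_step(f, t, y, h):
    return y + h * f(t, y)


def heun_step(f, t, y, h):
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)
    return y + 0.5 * h * (k1 + k2)


def global_error(step, num_steps, t1=1.0, y0=1.0):
    """Error at t1 for dy/dt = -y, whose exact solution is y0 * exp(-t)."""
    f = lambda t, y: -y
    h = t1 / num_steps
    t, y = 0.0, y0
    for _ in range(num_steps):
        y = step(f, t, y, h)
        t += h
    return abs(y - y0 * math.exp(-t1))


def observed_order(step, num_steps=100):
    """log2 of the error ratio when the step size is halved."""
    return math.log2(global_error(step, num_steps) / global_error(step, 2 * num_steps))
```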
The current JAX implementation of `searchsorted`, available here: https://github.com/google/jax/blob/f7df3ee9c4221a202959e67816d485c35eb98102/jax/_src/numpy/lax_numpy.py#L4219 has some flaws:
- It uses a for loop that always runs a fixed number of iterations, rather than early-stopping once the best value has...
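A plain-Python sketch of an early-stopping binary search, assuming the sorted array is *strictly* increasing (as a grid of times typically is). Under that assumption an exact match pins down the answer immediately for both `side="left"` and `side="right"`, so the loop can exit early instead of always running every halving. The function name and the returned step count are illustrative; this is not JAX's implementation.

```python
def searchsorted_early_stop(a, v, side="left"):
    """Binary search on a strictly increasing list; returns (index, iterations used)."""
    lo, hi = 0, len(a)
    steps = 0
    # Invariant: a[i] < v for i < lo, and a[i] > v for i >= hi.
    while lo < hi:
        steps += 1
        mid = (lo + hi) // 2
        if a[mid] == v:
            # Values are distinct, so the insertion point is known exactly.
            return (mid if side == "left" else mid + 1), steps
        if a[mid] < v:
            lo = mid + 1
        else:
            hi = mid
    # No exact match: left and right insertion points coincide.
    return lo, steps
```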
At the moment this is using the `dense_info` from the larger step but advancing the solution using the two smaller steps: https://github.com/patrick-kidger/diffrax/blob/1bb90b4b687b497c76c59b4817e4b6f70f476262/diffrax/solver/base.py#L241 But these two values needn't line up, so...
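A plain-Python sketch of step doubling in which the dense output is built from the same two half steps that advance the solution, so the interpolant agrees with the returned value at t + h. The RK4 base method, the Richardson error estimate (difference divided by 2^p - 1 for a method of order p), and piecewise-linear dense output are assumptions chosen for simplicity; the function names are hypothetical and this is not diffrax's code.

```python
def rk4_step(f, t, y, h):
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def step_doubling(f, t, y, h, order=4):
    """One step-doubling step: returns (y_next, error_estimate, dense_info)."""
    y_big = rk4_step(f, t, y, h)  # one step of size h, used only for the error estimate
    t_mid = t + h / 2
    y_mid = rk4_step(f, t, y, h / 2)  # two steps of size h / 2 ...
    y_next = rk4_step(f, t_mid, y_mid, h / 2)  # ... which advance the solution
    # Richardson estimate of the local error in y_next.
    error = abs(y_next - y_big) / (2 ** order - 1)
    # Dense info from the same two half steps, so it matches y_next at t + h.
    dense_info = ((t, y), (t_mid, y_mid), (t + h, y_next))
    return y_next, error, dense_info


def evaluate_dense(dense_info, s):
    """Piecewise-linear interpolation through the stored half-step points."""
    (t0, y0), (t1, y1), (t2, y2) = dense_info
    ta, ya, tb, yb = (t0, y0, t1, y1) if s <= t1 else (t1, y1, t2, y2)
    w = (s - ta) / (tb - ta)
    return (1 - w) * ya + w * yb
```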