Incompatibility of least_squares and custom_vjp
I'm running into some trouble applying optimistix.least_squares(fn, LevenbergMarquardt(...), x0) to certain problems. From the error message below, my understanding of the root cause is that forward-mode autodiff cannot be used on jax.custom_vjp. In my case I am using diffrax to solve an ODE within fn(...), which I think might be causing the problem.
Is my basic understanding correct? Are there specific constraints / assumptions that fn(...) must follow for optimistix.least_squares to work (e.g. cannot use jax.custom_vjp)? Is there any way around this?
The error I get is:
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
The full code to reproduce the error is below. By the way, I get the same error when using jaxopt.LevenbergMarquardt on this problem.
# === imports === #
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import diffrax
from diffrax import ODETerm, Dopri5, SaveAt
from tqdm import trange
import optimistix
from optimistix import LevenbergMarquardt
# === functions defining flow field and residuals === #
def geodesic_vector_field(P):
    jacP = jax.jacobian(P)
    def vector_field(t, state, args):
        x, v = state
        Pdx = jacP(x)
        q1 = 0.5 * jnp.einsum("jki,j,k->i", Pdx, v, v)
        q2 = jnp.einsum("ilp,l,p->i", Pdx, v, v)
        dxdt = v
        dvdt = jnp.linalg.solve(P(x), q1 + q2)
        return (dxdt, dvdt)
    return vector_field
def exponential_map(x0, v0, term, solver):
    return diffrax.diffeqsolve(
        term, solver, t0=0, t1=1, dt0=0.1, y0=(x0, v0),
        saveat=SaveAt(t0=False, t1=True)
    ).ys[0].ravel()
def shooting_method_resids(x0, x1, term, solver):
    return jax.jit(
        lambda v0, args: (x1 - exponential_map(x0, v0, term, solver)).ravel()
    )
# === try solving the boundary value problem === #
term = ODETerm(geodesic_vector_field(lambda x: jnp.eye(2)))
solver = Dopri5()
optimistix.least_squares(
    shooting_method_resids(jnp.zeros(2), jnp.ones(2), term, solver),
    LevenbergMarquardt(1e-3, 1e-3),
    -1 * jnp.ones(2)
)
Yup, you're completely correct in your diagnosis: Diffrax has a jax.custom_vjp for the autodifferentiation through diffeqsolve, and this doesn't support forward-mode autodiff, which is what is used by optx.LevenbergMarquardt to compute its Jacobians.
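For anyone who wants to see the failure in isolation, here is a minimal sketch in plain JAX. The `square` function and its rules are a hypothetical stand-in for diffeqsolve, not Diffrax's actual implementation; the only point is that a function with just a custom_vjp rule can be differentiated in reverse mode but not in forward mode.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def square(x):
    # Hypothetical toy function; only a reverse-mode rule is registered below.
    return x ** 2


def square_fwd(x):
    # Forward pass: return the output plus the residuals the backward pass needs.
    return square(x), x


def square_bwd(x, g):
    # Backward pass: cotangent g multiplied by d(x^2)/dx = 2x.
    return (2.0 * x * g,)


square.defvjp(square_fwd, square_bwd)

x = jnp.array([1.0, 2.0, 3.0])

# Reverse mode goes through the custom VJP rule, so this works.
jac_rev = jax.jacrev(square)(x)

# Forward mode has no JVP rule to call, so JAX raises a TypeError.
try:
    jax.jvp(square, (x,), (jnp.ones_like(x),))
    forward_error = None
except TypeError as err:
    forward_error = str(err)

print(jac_rev)
print(forward_error)
```

A solver that builds its Jacobian with forward-mode autodiff hits the same TypeError as soon as the residual function contains such a call.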
We have essentially two possible fixes: offer a way for Diffrax to use forward-mode autodifferentiation, or offer a way for Optimistix to use reverse-mode.
For now I've just added the latter, in #51. Try using Optimistix from that branch and see if it solves your problem! You'll need to pass optx.least_squares(..., options=dict(jac="bwd")).
(I'd like to add better forward-mode support for Diffrax, but the best way of doing this really depends on JAX directly adding support for jvp-of-custom_vjp. I have a draft of that here, but it still seems to be buggy, so I haven't gotten around to finishing it.)
Amazing, works as intended (at least for the simple example I've tried)!