
Including user-defined Jacobian

Justin-Tan opened this issue 2 years ago • 4 comments

Hi devs, this looks like a really nice library. I've been looking for a JAX-native root-finding method that supports vmap for some time. Currently I am using an external call to scipy.optimize.root together with the multiprocessing library, which is quite slow.

The runtime of root finding with the Newton method in this library is slower than the above approach, though. I suspect this is because the Jacobian needs to be recalculated at each iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?

For reference, this is my MWE in case I am not doing things efficiently:

from jax import random, vmap
import optimistix as optx

def fn(y, b):
    # Root at y = b. (Note this is a double root, so Newton converges slowly here.)
    return (y - b) ** 2

M = 1024
key = random.PRNGKey(42)
key, key_ = random.split(key, 2)

y = random.normal(key, (M,))
b = random.normal(key_, (M,))
solver = optx.Newton(rtol=1e-6, atol=1e-6)  # tolerances chosen arbitrarily
sol = optx.root_find(vmap(fn), solver, y, b)

Justin-Tan · Oct 19 '23 23:10

Okay, many things to respond to here!

Speed

With respect to the speed, for your JAX code are you:

  • JIT'ing everything;
  • excluding compile time;
  • including block_until_ready?

In practice this means writing things out something like this:

import timeit

import jax

@jax.jit
def run(y, b):
    sol = optx.root_find(vmap(fn), solver, y, b)
    return sol.value

run(y, b)  # run once, so compilation happens before timing
times = timeit.repeat(lambda: jax.block_until_ready(run(y, b)), number=1, repeat=10)
print(min(times))

Recalculating Jacobians

You commented on calculating the Jacobian afresh at every iteration. If you're using the standard Newton algorithm then this is expected (indeed desired) behaviour. But if you'd prefer a quasi-Newton algorithm like the chord method, which computes the Jacobian once at the initial point and then reuses it, then there is optx.Chord as well.
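
For instance, a minimal sketch of swapping it into the MWE above (the tolerances here are illustrative, not recommendations):

# Chord computes the Jacobian once at the initial point and reuses it,
# avoiding a fresh Jacobian evaluation at every iteration.
solver = optx.Chord(rtol=1e-6, atol=1e-6)
sol = optx.root_find(vmap(fn), solver, y, b)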

Analytical Jacobians

You commented on supplying an analytical Jacobian. This isn't necessary, as the analytical Jacobian is already derived from fn automatically using automatic differentiation. Unless the autodiff does something surprisingly inefficient, providing one manually wouldn't meaningfully improve things.
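
For intuition, here's a sketch of the derivative that autodiff derives from fn, using jax.jacfwd directly (the internal machinery Optimistix uses may differ, but the derivative information it sees is the same):

import jax

def fn(y, b):
    return (y - b) ** 2

# Forward-mode autodiff gives the exact analytic derivative 2 * (y - b).
dfn_dy = jax.jacfwd(fn)
print(dfn_dy(1.0, 0.5))  # 1.0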

Custom Jacobians

If despite everything you really do want to provide a custom Jacobian, then this can be done using jax.custom_jvp. By wrapping your fn in a jax.custom_jvp, you can override how JAX computes derivatives of your code. (This will then be picked up by the autodiff that Optimistix uses to calculate the Jacobian.)
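
As a sketch, here's the MWE's fn wrapped in jax.custom_jvp; the hand-written rule below just reproduces what autodiff would compute anyway, but you could substitute any cheaper analytic expression:

import jax

@jax.custom_jvp
def fn(y, b):
    return (y - b) ** 2

@fn.defjvp
def fn_jvp(primals, tangents):
    y, b = primals
    y_dot, b_dot = tangents
    # Analytic derivative: d/dy (y - b)**2 = 2 * (y - b); d/db = -2 * (y - b).
    tangent_out = 2 * (y - b) * (y_dot - b_dot)
    return fn(y, b), tangent_out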

Does the above help?

patrick-kidger · Oct 20 '23 00:10