Differentiating auxiliary variables
Hello @patrick-kidger, thank you very much for all your amazing libraries! I did not manage to find an issue related to this, but currently we can only differentiate through the solution of e.g. a root-finding problem, not through its auxiliary data. Consider this example:
import jax
import jax.numpy as jnp
import optimistix as optx
def compute_sqrt(x):
    def fn(y, args):
        def expensive_function(y):
            return y**2
        return expensive_function(y) - args, expensive_function(y)

    solver = optx.Newton(rtol=1e-5, atol=1e-5)
    y0 = jnp.array(1.0)
    sol = optx.root_find(fn, solver, y0, x, has_aux=True)
    sqrt_x = sol.value
    x_ = sol.aux
    return sqrt_x, x_
x = 2.0
print(compute_sqrt(x))
>>> (Array(1.4142135, dtype=float32), Array(2.000006, dtype=float32))
print(jax.jacobian(compute_sqrt)(x))
>>> (Array(0.35355338, dtype=float32), Array(0., dtype=float32))
where the second output of jax.jacobian is zero, because the auxiliary data is apparently treated as a fixed quantity.
Assuming that my residual function involves some expensive calculation, I would like to avoid having to re-evaluate x_ = expensive_function(sqrt_x) just to get its gradient.
Is there a cleverer way to do this?
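Concretely, the re-evaluation I would like to avoid would look something like this (a rough sketch of my example above, restructured so that expensive_function is re-evaluated at the root outside the solve and is therefore tracked by autodiff):

import jax
import jax.numpy as jnp
import optimistix as optx

def expensive_function(y):
    # Stands in for the costly part of the residual.
    return y**2

def compute_sqrt(x):
    def fn(y, args):
        return expensive_function(y) - args, expensive_function(y)

    solver = optx.Newton(rtol=1e-5, atol=1e-5)
    sol = optx.root_find(fn, solver, jnp.array(1.0), x, has_aux=True)
    sqrt_x = sol.value
    # Re-evaluate at the root, so jax.jacobian sees the dependence on x:
    x_ = expensive_function(sqrt_x)
    return sqrt_x, x_

print(jax.jacobian(compute_sqrt)(2.0))  # both outputs now have nonzero derivatives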
Hi!
Unfortunately, no. This is by design and follows the way JAX treats functions with auxiliary variables: when auxiliary data is returned, only the first output is differentiated, and the auxiliary output is passed through untouched (see https://docs.jax.dev/en/latest/_autosummary/jax.linearize.html). This means that we never compute a Jacobian of the auxiliary variables.
The important consequence of this is that it is always clear which variables are being optimised, in any library (not just Optimistix): the derivative of fn with respect to y is what the solver uses to find the root, what SGD uses to find the minimum, and so on.
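To make the JAX convention concrete, here is a minimal sketch of jax.linearize with has_aux=True: the first output is linearised, while the auxiliary output is carried through untouched by autodiff.

import jax
import jax.numpy as jnp

def f(y):
    # First output: the value that gets differentiated.
    # Second output: auxiliary data, returned as-is and never differentiated.
    return y**2, jnp.sin(y)

y = jnp.array(1.0)
primal_out, f_jvp, aux = jax.linearize(f, y, has_aux=True)
print(primal_out, aux)        # y**2 and sin(y)
print(f_jvp(jnp.array(1.0)))  # 2*y, the JVP of the first output only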
If you wanted to, you could implement your own wrapper for the root finders that computes the Jacobian of aux with respect to y, e.g. in a post-processing step, and returns it as part of the wrapped solver state. The extra computational expense would still be there, though.
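Something along these lines (just a sketch; the wrapper name and its return structure are made up, not part of the Optimistix API):

import jax
import optimistix as optx

def root_find_with_aux_jacobian(fn, solver, y0, args):
    # Solve as usual, keeping the auxiliary output.
    sol = optx.root_find(fn, solver, y0, args, has_aux=True)
    # Post-processing: differentiate only the aux output of `fn` with respect
    # to y, evaluated at the returned root. This re-runs `fn`, so the extra
    # computational expense mentioned above is still paid.
    aux_jacobian = jax.jacobian(lambda y: fn(y, args)[1])(sol.value)
    return sol, aux_jacobian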
I'll add that the computational expense is unfortunately pretty much inevitable. This is because the point that Optimistix returns may not actually be a point that it has evaluated during its solve. For a common machine-learning reference point, consider running SGD: we typically update our parameters and then halt, without a final re-evaluation of the loss at those final, updated parameters.
What do we mean by "evaluate" here? We do check for convergence based on the function value and the value of the optimisation variable. Whether we return the newest trial iterate or the last accepted iterate depends on the solver (and more generally on whether it uses a Search or not). The root finders (e.g. optx.Newton) return the newest trial iterate, but the quasi-Newton and least-squares solvers, as well as the gradient descent solvers (though not OptaxMinimiser), return the last accepted iterate rather than the newest trial iterate computed from it.
The difference should be negligible (within the specified tolerance), provided that we have converged to a (local) minimum.
Whether we return the newest trial iterate or the last accepted iterate depends on the solver
Exactly, whether this is the case or not is allowed to be solver-dependent. So it's not a guarantee that Optx offers as part of its API.
(As far as I can tell this is a detail that only matters for returning aux, otherwise it's invisible to the user.)