has_aux option in custom_root
Hi!
I'm trying to solve a bilevel problem involving a neural network using implicit diff, but I get an incompatible-shape error in solve_normal_cg because the linear operator applied by matvec is not square.
As explained in https://github.com/google/jaxopt/issues/183, this can be solved by giving an initialization to the solver with the correct shape, but I have no clue how to find the correct shape.
Is there an automatic way to infer the shape of the initialization? Thanks!
Tagging @froystig in case he has a solution involving jax.eval_shape
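To make that concrete, here is a toy sketch of what I have in mind (inner_solver below is just a stand-in for my inner training loop, not my actual code):

import jax
import jax.numpy as jnp

# Toy stand-in for the inner solver; mine returns Flax params, but any pytree works.
def inner_solver(init_params, x):
  return {"w": init_params["w"] - x, "b": init_params["b"] * 2.0}

params0 = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}

# Infer the solution's shapes/dtypes without running the solver...
sol_shapes = jax.eval_shape(inner_solver, params0, jnp.ones(()))
# ...and build the zero-filled pytree that solve_normal_cg expects as init
# when its linear operator is non-square.
init = jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), sol_shapes)

I can do this by hand, but I'd like the shape to be inferred for me inside the solver.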
Does that mean you're trying to implement implicit diff manually yourself? Why not use custom_root or custom_fixed_point? (https://jaxopt.github.io/stable/implicit_diff.html)
Hi @mblondel, I've used custom_root, but I got this exception.
I figured there must be something wrong in my code, because I don't see how I could end up with a non-square matrix there: matvec should be the vjp of my optimality function, if I understood the code correctly. Is that right?
I'll try to write a minimal example, since my code is quite messy at the moment.
It probably means that your optimality_fun is incorrect. If output = optimality_fun(params, *args, **kwargs), params and output should have the same shape. For instance, this is correct:
import jax
import jax.numpy as jnp

def fun(params, X, y, l2reg):
  residuals = jnp.dot(X, params) - y
  return jnp.sum(residuals ** 2) + l2reg * jnp.sum(params ** 2)

optimality_fun = jax.grad(fun, argnums=0)
because the gradient operator is a function from R^d to R^d. Hence the linear operator within implicit diff will be square. You can also use output, aux = optimality_fun(params, *args, **kwargs) but in this case you need to set has_aux=True in custom_root.
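As a minimal sketch of how this plugs into custom_root, continuing the example above (the closed-form ridge_solver and the toy data are just for illustration, any solver returning params of the same shape as init_params would do):

from jaxopt import implicit_diff

@implicit_diff.custom_root(optimality_fun)
def ridge_solver(init_params, X, y, l2reg):
  del init_params  # unused, the solution is computed in closed form
  # Solves grad(fun) = 0, i.e. (X^T X + l2reg I) params = X^T y.
  return jnp.linalg.solve(X.T @ X + l2reg * jnp.eye(X.shape[1]), X.T @ y)

X = jnp.arange(15.0).reshape(5, 3)
y = jnp.arange(5.0)
# The inner solution can now be differentiated w.r.t. l2reg via implicit diff:
dsol_dl2reg = jax.jacobian(ridge_solver, argnums=3)(jnp.zeros(3), X, y, 10.0)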
I am able to reproduce the same error with this code.
Do you see something wrong with the optimality function as it is? The unrolling works correctly, but evidently I'm missing something. Thanks
Traceback (most recent call last):
  File "/home/marco/exp/jaxopt/jaxopt/_src/linear_solve.py", line 180, in solve_normal_cg
    rmatvec = _make_rmatvec(matvec, example_x)
  File "/home/marco/exp/jaxopt/jaxopt/_src/linear_solve.py", line 144, in _make_rmatvec
    transpose = jax.linear_transpose(matvec, x)
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/_src/api.py", line 2518, in linear_transpose
    jaxpr, out_pvals, const = pe.trace_to_jaxpr_nounits(flat_fun, in_pvals,
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 608, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/marco/exp/jaxopt/jaxopt/_src/implicit_diff.py", line 63, in <lambda>
    matvec = lambda u: vjp_fun_sol(u)[0]
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/_src/tree_util.py", line 287, in __call__
    return self.fun(*args, **kw)
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/_src/api.py", line 2336, in _vjp_pullback_wrapper
    raise TypeError(f"Tree structure of cotangent input {in_tree}, does not match structure of "
TypeError: Tree structure of cotangent input PyTreeDef(CustomNode(<class 'flax.core.frozen_dict.FrozenDict'>[()], [{'BatchNorm_0': {'bias': *, 'scale': *}, 'BatchNorm_1': {'bias': *, 'scale': *}, 'Conv_0': {'bias': *, 'kernel': *}, 'Conv_1': {'bias': *, 'kernel': *}, 'Dense_0': {'bias': *, 'kernel': *}, 'Dense_1': {'bias': *, 'kernel': *}}])), does not match structure of primal output PyTreeDef((CustomNode(<class 'flax.core.frozen_dict.FrozenDict'>[()], [{'BatchNorm_0': {'bias': *, 'scale': *}, 'BatchNorm_1': {'bias': *, 'scale': *}, 'Conv_0': {'bias': *, 'kernel': *}, 'Conv_1': {'bias': *, 'kernel': *}, 'Dense_0': {'bias': *, 'kernel': *}, 'Dense_1': {'bias': *, 'kernel': *}}]), CustomNode(<class 'flax.core.frozen_dict.FrozenDict'>[()], [{'BatchNorm_0': {'mean': *, 'var': *}, 'BatchNorm_1': {'mean': *, 'var': *}}]))).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/marco/exp/jaxopt/examples/implicit_diff/nn_ho.py", line 223, in <module>
    l2reg, outer_state = gd_outer.update(
  File "/home/marco/exp/jaxopt/jaxopt/_src/optax_wrapper.py", line 122, in update
    (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
  File "/home/marco/exp/jaxopt/examples/implicit_diff/nn_ho.py", line 196, in outer_loss
    inner_sol, inner_state = inner_loop_solver(
  File "/home/marco/exp/jaxopt/jaxopt/_src/implicit_diff.py", line 251, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: The initialization init of solve_normal_cg is compulsory when matvec is nonsquare. It should have the same pytree structure as a solution. Typically, a pytree filled with zeros should work.
I think I figured out what the problem is.
It seems it's not enough to set has_aux=True in custom_root (which seems to be needed if the inner solver returns an aux).
Indeed, I think that the vjp here needs has_aux=True too, hence the exception.
For now I solved it by wrapping the optimality function so that it returns only the gradient, without the aux:
jax.grad(lambda *args: loss_fun(*args)[0])
Is this the intended behavior? I don't see the point of having an aux for the optimality function here, but it may be useful in other situations.
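For concreteness, a toy version of what I mean (loss_fun below is a stand-in for my actual inner loss, which returns a (value, aux) pair):

import jax
import jax.numpy as jnp

def loss_fun(params, x):
  residual = params - x
  # The real inner loss also returns an aux value alongside the scalar loss.
  return jnp.sum(residual ** 2), {"residual": residual}

# This returns (grad, aux), which is what seems to break the vjp inside implicit diff:
grad_and_aux_fun = jax.grad(loss_fun, has_aux=True)

# Workaround: drop the aux so the optimality function returns only the gradient,
# i.e. a pytree with the same structure as params.
optimality_fun = jax.grad(lambda *args: loss_fun(*args)[0])
grad_only = optimality_fun(jnp.zeros(3), jnp.ones(3))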
custom_root with has_aux is tested here: https://github.com/google/jaxopt/blob/main/tests/implicit_diff_test.py#L87
If you can reproduce your issue with a smaller script and using toy data, that would be useful.
I think I figured out what the confusion is. has_aux in custom_root concerns the decorated solver (https://jaxopt.github.io/stable/_autosummary/jaxopt.implicit_diff.custom_root.html#jaxopt.implicit_diff.custom_root), not optimality_fun, as is visible from https://github.com/google/jaxopt/blob/main/tests/implicit_diff_test.py#L87.
So at the moment you need to do:
def optimality_fun(params, ...):
  return jax.grad(fun, has_aux=True)(params, ...)[0]

@custom_root(optimality_fun)
def solver(init_params, ...):
  ...
  return sol
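For completeness, here is a small self-contained sketch of that pattern on a toy ridge problem (the closed-form solver, the aux value and the data are purely illustrative):

import jax
import jax.numpy as jnp
from jaxopt import implicit_diff

def fun(params, X, y, l2reg):
  residuals = jnp.dot(X, params) - y
  # The inner objective also returns an aux value, as in your setting.
  return jnp.sum(residuals ** 2) + l2reg * jnp.sum(params ** 2), residuals

def optimality_fun(params, X, y, l2reg):
  # Keep only the gradient: the output must have the same pytree structure as params.
  return jax.grad(fun, has_aux=True)(params, X, y, l2reg)[0]

# has_aux=True refers to the solver, which returns (sol, aux).
@implicit_diff.custom_root(optimality_fun, has_aux=True)
def solver(init_params, X, y, l2reg):
  del init_params  # unused, closed-form solve
  sol = jnp.linalg.solve(X.T @ X + l2reg * jnp.eye(X.shape[1]), X.T @ y)
  return sol, {"residuals": jnp.dot(X, sol) - y}

X = jnp.arange(15.0).reshape(5, 3)
y = jnp.arange(5.0)
# Implicit differentiation of the inner solution w.r.t. l2reg:
dsol_dl2reg = jax.jacobian(lambda l2reg: solver(jnp.zeros(3), X, y, l2reg)[0])(10.0)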
Maybe in the future we can add the options optimality_fun_has_aux and solver_has_aux to be more explicit... @froystig
Thanks @mblondel, that's exactly it. I was able to get my code working.
Introducing explicit options sounds good to me!