
has_aux option in custom_root

Open marcociccone opened this issue 3 years ago • 10 comments

Hi!

I'm trying to solve a bilevel problem involving a neural network using implicit diff, but I get an incompatible-shape error in solve_normal_cg because the linear operator applied by matvec is not square.

As explained in https://github.com/google/jaxopt/issues/183, this can be fixed by passing the solver an initialization with the correct shape, but I have no clue how to determine that shape.

Is there an automatic way to infer the shape of the initialization? Thanks!

marcociccone avatar Jun 08 '22 18:06 marcociccone

Tagging @froystig in case he has a solution involving jax.eval_shape
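Something along these lines might work, for instance (a rough sketch with toy stand-in names, not my actual code):

import jax
import jax.numpy as jnp

# Toy stand-in for the inner solver whose output structure we want to infer.
def toy_solver(data):
  return {"w": jnp.zeros((3, 2)) + data, "b": jnp.zeros(2)}

# eval_shape traces the function abstractly: no actual compute, just shapes/dtypes.
out_shapes = jax.eval_shape(toy_solver, 1.0)

# Zero-filled pytree with the same structure as the solution, usable as the init.
init = jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), out_shapes)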

marcociccone avatar Jun 08 '22 23:06 marcociccone

Does that mean you are trying to implement implicit diff manually yourself? Why not use custom_root or custom_fixed_point? (https://jaxopt.github.io/stable/implicit_diff.html)

mblondel avatar Jun 09 '22 14:06 mblondel

Hi @mblondel, I did use custom_root, but I got this exception. I figured there must be something wrong in my code, because I don't see how a non-square operator can arise there: matvec should be the vjp of my optimality function, if I understood the code correctly. Is that right?

I'll try to write a minimal example since my code is quite messy at the moment.

marcociccone avatar Jun 09 '22 14:06 marcociccone

It probably means that your optimality_fun is incorrect. If output = optimality_fun(params, *args, **kwargs), params and output should have the same shape. For instance, this is correct:

def fun(params, X, y, l2reg):
   residuals = jnp.dot(X, params) - y
   return jnp.sum(residuals ** 2) + l2reg * jnp.sum(params ** 2)

optimality_fun = jax.grad(fun, argnums=0)

because the gradient operator is a function from R^d to R^d. Hence the linear operator within implicit diff will be square. You can also use output, aux = optimality_fun(params, *args, **kwargs) but in this case you need to set has_aux=True in custom_root.
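For reference, a rough sketch (toy shapes and a closed-form solver, just for illustration) of how this optimality_fun plugs into custom_root when the operator is square:

import jax
import jax.numpy as jnp
from jaxopt import implicit_diff

@implicit_diff.custom_root(optimality_fun)  # optimality_fun defined above
def ridge_solver(init_params, X, y, l2reg):
  # Closed-form ridge solution; init_params only fixes the expected output structure.
  XtX = jnp.dot(X.T, X) + l2reg * jnp.eye(X.shape[1])
  return jnp.linalg.solve(XtX, jnp.dot(X.T, y))

X, y = jnp.ones((5, 3)), jnp.ones(5)
# Implicit differentiation w.r.t. l2reg now works without any extra init.
print(jax.grad(lambda reg: ridge_solver(jnp.zeros(3), X, y, reg).sum())(0.5))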

mblondel avatar Jun 09 '22 14:06 mblondel

I am able to reproduce the same error with this code.

Do you see anything wrong with the optimality function as written? The unrolling works correctly, but evidently I'm missing something. Thanks!

Traceback (most recent call last):
  File "/home/marco/exp/jaxopt/jaxopt/_src/linear_solve.py", line 180, in solve_normal_cg
    rmatvec = _make_rmatvec(matvec, example_x)
  File "/home/marco/exp/jaxopt/jaxopt/_src/linear_solve.py", line 144, in _make_rmatvec
    transpose = jax.linear_transpose(matvec, x)
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/_src/api.py", line 2518, in linear_transpose
    jaxpr, out_pvals, const = pe.trace_to_jaxpr_nounits(flat_fun, in_pvals,
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 608, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/marco/exp/jaxopt/jaxopt/_src/implicit_diff.py", line 63, in <lambda>
    matvec = lambda u: vjp_fun_sol(u)[0]
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/_src/tree_util.py", line 287, in __call__
    return self.fun(*args, **kw)
  File "/home/marco/.miniconda3/lib/python3.9/site-packages/jax/_src/api.py", line 2336, in _vjp_pullback_wrapper
    raise TypeError(f"Tree structure of cotangent input {in_tree}, does not match structure of "
TypeError: Tree structure of cotangent input PyTreeDef(CustomNode(<class 'flax.core.frozen_dict.FrozenDict'>[()], [{'BatchNorm_0': {'bias': *, 'scale': *}, 'BatchNorm_1': {'bias': *, 'scale': *}, 'Conv_0': {'bias': *, 'kernel': *}, 'Conv_1': {'bias': *, 'kernel': *}, 'Dense_0': {'bias': *, 'kernel': *}, 'Dense_1': {'bias': *, 'kernel': *}}])), does not match structure of primal output PyTreeDef((CustomNode(<class 'flax.core.frozen_dict.FrozenDict'>[()], [{'BatchNorm_0': {'bias': *, 'scale': *}, 'BatchNorm_1': {'bias': *, 'scale': *}, 'Conv_0': {'bias': *, 'kernel': *}, 'Conv_1': {'bias': *, 'kernel': *}, 'Dense_0': {'bias': *, 'kernel': *}, 'Dense_1': {'bias': *, 'kernel': *}}]), CustomNode(<class 'flax.core.frozen_dict.FrozenDict'>[()], [{'BatchNorm_0': {'mean': *, 'var': *}, 'BatchNorm_1': {'mean': *, 'var': *}}]))).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/marco/exp/jaxopt/examples/implicit_diff/nn_ho.py", line 223, in <module>
    l2reg, outer_state = gd_outer.update(
  File "/home/marco/exp/jaxopt/jaxopt/_src/optax_wrapper.py", line 122, in update
    (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
  File "/home/marco/exp/jaxopt/examples/implicit_diff/nn_ho.py", line 196, in outer_loss
    inner_sol, inner_state = inner_loop_solver(
  File "/home/marco/exp/jaxopt/jaxopt/_src/implicit_diff.py", line 251, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: The initialization init of solve_normal_cg is compulsory when matvec is nonsquare. It should have the same pytree structure as a solution. Typically, a pytree filled with zeros should work.

marcociccone avatar Jun 09 '22 20:06 marcociccone

I think I figured out what the problem is. It's not enough to set has_aux=True in custom_root (which seems to be needed when the inner solver returns an aux).

Indeed, I think that the vjp here needs has_aux=True too, hence the exception.

For now I worked around it by wrapping the loss function so that the optimality function returns only the gradient: jax.grad(lambda *args: loss_fun(*args)[0])

Is this the intended behavior? I don't see the point of having an aux for the optimality function here, but it may be useful in other situations.

marcociccone avatar Jun 09 '22 22:06 marcociccone

custom_root with has_aux is tested here: https://github.com/google/jaxopt/blob/main/tests/implicit_diff_test.py#L87

If you can reproduce your issue with a smaller script using toy data, that would be useful.

mblondel avatar Jun 10 '22 14:06 mblondel

I think I figured out what the confusion is. has_aux in custom_root concerns the decorated solver (https://jaxopt.github.io/stable/_autosummary/jaxopt.implicit_diff.custom_root.html#jaxopt.implicit_diff.custom_root), not optimality_fun, as can be seen from https://github.com/google/jaxopt/blob/main/tests/implicit_diff_test.py#L87.

So at the moment you need to do:

def optimality_fun(params, ...):
  return jax.grad(fun, has_aux=True)(params, ...)[0]

@custom_root(optimality_fun)
def solver(init_params, ...):
  ...
  return sol

Maybe in the future we can add the options optimality_fun_has_aux and solver_has_aux to be more explicit... @froystig
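Putting it together, a rough end-to-end sketch of the pattern above (toy problem and names, not the original script), with the solver returning an aux and the optimality function discarding its own aux:

import jax
import jax.numpy as jnp
from jaxopt import implicit_diff

def loss_fun(params, data):
  # Hypothetical inner loss that also returns an aux (e.g. updated batch stats).
  loss = jnp.sum((params - data) ** 2)
  return loss, {"loss": loss}

def optimality_fun(params, data):
  # Keep only the gradient so the output matches the pytree structure of params.
  return jax.grad(loss_fun, has_aux=True)(params, data)[0]

@implicit_diff.custom_root(optimality_fun, has_aux=True)
def solver(init_params, data):
  # Toy closed-form "inner solve"; a real solver would iterate.
  sol = data
  aux = {"final_loss": 0.0}
  return sol, aux

data = jnp.arange(3.0)
sol, aux = solver(jnp.zeros(3), data)
# Implicit differentiation of the solution w.r.t. data.
grads = jax.grad(lambda d: solver(jnp.zeros(3), d)[0].sum())(data)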

mblondel avatar Jun 10 '22 14:06 mblondel

Thanks @mblondel, that's exactly it. I was able to get my code working.

marcociccone avatar Jun 12 '22 19:06 marcociccone

Introducing explicit options sounds good to me!

froystig avatar Jun 15 '22 17:06 froystig