
Problems when jitting Adafactor with inject_hyperparams.

Open · mkunesch opened this issue 3 years ago · 0 comments

When wrapping optax.adafactor with optax.inject_hyperparams without specifying static_args

optax.inject_hyperparams(optax.adafactor)(learning_rate=0.1)

the init function of the resulting GradientTransformation cannot be jit-compiled. The reason is that by default inject_hyperparams treats all arguments as dynamic, but Adafactor's min_dim_size_to_factor argument is used in Python-level control flow (it is compared against parameter shapes to decide which dimensions to factor), so it has to be static to avoid a tracer error. A workaround is to specify the static argument explicitly:

optax.inject_hyperparams(optax.adafactor, static_args=("min_dim_size_to_factor",))(learning_rate=0.1)

However, this is not ideal since it requires the user to know which arguments should be static and which ones can be dynamic.

We should:

  • Add a test to check whether any other optimizers are affected.
  • Change the optimizer implementations so that all optimizers wrapped in inject_hyperparams can be jit-compiled without any arguments having to be declared static.

mkunesch · Sep 07 '22 12:09