Problems when jitting Adafactor with inject_hyperparams.
When wrapping optax.adafactor with optax.inject_hyperparams without specifying static_args, e.g.
optax.inject_hyperparams(optax.adafactor)(learning_rate=0.1)
the init function of the resulting GradientTransformation cannot be jit compiled. The reason is that by default inject_hyperparams treats all arguments as dynamic, but one of adafactor's arguments (min_dim_size_to_factor) has to be static to avoid a tracer error. A workaround is to specify the static argument:
optax.inject_hyperparams(optax.adafactor, static_args=("min_dim_size_to_factor",))(learning_rate=0.1)
However, this is not ideal since it requires the user to know which arguments should be static and which ones can be dynamic.
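A minimal self-contained sketch of the failure and the workaround described above (the parameter tree and shapes are illustrative):

```python
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((4, 4))}

# Default wrapping: all arguments are injected as dynamic hyperparameters,
# and jitting init fails with a tracer error, as described above.
opt = optax.inject_hyperparams(optax.adafactor)(learning_rate=0.1)
# jax.jit(opt.init)(params)  # raises

# Workaround: declare min_dim_size_to_factor as static so it is passed
# through as a plain Python int instead of an injected hyperparameter.
opt_static = optax.inject_hyperparams(
    optax.adafactor, static_args=("min_dim_size_to_factor",)
)(learning_rate=0.1)
state = jax.jit(opt_static.init)(params)
```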
We should:
- Add a test to check whether any other optimizers are affected (a rough sketch is included after this list).
- Change the implementations so that all optimizers wrapped in inject_hyperparams can be jit compiled without any arguments being specified as static.
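A rough sketch of what such a test could look like, assuming a small illustrative set of optimizers and only exercising init (the test name, class, and optimizer list are placeholders, not a proposal for the final test):

```python
import jax
import jax.numpy as jnp
import optax
from absl.testing import absltest, parameterized


class InjectHyperparamsJitTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ('sgd', optax.sgd),
      ('adam', optax.adam),
      ('adafactor', optax.adafactor),  # currently fails without static_args
  )
  def test_init_can_be_jitted(self, opt_factory):
    params = {'w': jnp.ones((4, 4))}
    opt = optax.inject_hyperparams(opt_factory)(learning_rate=0.1)
    # Should compile without declaring any argument as static; a similar
    # check could be added for opt.update.
    jax.jit(opt.init)(params)


if __name__ == '__main__':
  absltest.main()
```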