How to use rectification with other optimizers?
I want to use the rectification from RAdam with the AdaBelief optimizer, as suggested by the author of AdaBelief. How can I do that?
It's on my TODO list to think about ways to make the rectification composable with different transforms. Hopefully I can get to this in the next couple of weeks, but I am also happy to accept contributions/suggestions if someone is interested. :)
For instance, instead of having a monolithic scale_by_radam,
we could have a wrapper like this:
    import jax
    import jax.numpy as jnp
    from optax import GradientTransformation

    def rectify(inner: GradientTransformation,
                b1: float = 0.9,
                b2: float = 0.999,
                threshold: float = 5.0) -> GradientTransformation:
        """Rectify Adam-like updates."""

        def init_fn(params):
            return inner.init(params)

        def update_fn(updates, state, params=None):
            inner_updates, new_state = inner.update(updates, state, params)
            # The inner state stores the already-incremented step as `count`.
            count_inc = new_state.count
            b2t = b2**count_inc
            ro_inf = 2./(1 - b2) - 1
            ro = ro_inf - 2 * count_inc * b2t / (1 - b2t)
            r = jnp.sqrt((ro - 4)*(ro - 2)*ro_inf/((ro_inf - 4)*(ro_inf - 2)*ro))
            scaled_updates = jax.tree_util.tree_map(lambda t: r*t, inner_updates)
            # `_bias_correction` is optax's internal helper: mu / (1 - b1**count).
            unbiased_momentum = _bias_correction(new_state.mu, b1, count_inc)
            # Rectified update once ro reaches the threshold, else plain momentum.
            rectified_updates = jax.tree_util.tree_map(
                lambda x, y: jax.lax.select(ro >= threshold, x, y),
                scaled_updates, unbiased_momentum)
            return rectified_updates, new_state

        return GradientTransformation(init_fn, update_fn)
You could use this on any Adam-like inner transform,
so that the current scale_by_radam becomes:
    rectify(scale_by_adam(...), ...)
But you could also combine it with other Adam-like updates such as scale_by_belief.
WDYT? Do you think this would be helpful?
WARNING: I didn't check my math carefully; the code above is just meant to communicate the "idea".
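As a quick sanity check on the math in the sketch above, the rectification term can be evaluated in plain Python, without JAX. The helper name `rectification_term` is made up for this illustration and is not part of optax; it only reuses the `ro_inf`, `ro`, and `r` expressions from the sketch, with the same default `b2` and `threshold`.

```python
import math


def rectification_term(step, b2=0.999, threshold=5.0):
    """Return (ro, r) for a 1-based step; r is None while ro < threshold.

    Hypothetical helper for illustration only (not an optax function).
    """
    ro_inf = 2.0 / (1.0 - b2) - 1.0
    b2t = b2 ** step
    ro = ro_inf - 2.0 * step * b2t / (1.0 - b2t)
    if ro < threshold:
        # Below the threshold the wrapper would fall back to the
        # bias-corrected momentum instead of the scaled inner update.
        return ro, None
    r = math.sqrt((ro - 4.0) * (ro - 2.0) * ro_inf
                  / ((ro_inf - 4.0) * (ro_inf - 2.0) * ro))
    return ro, r


if __name__ == "__main__":
    # Print how ro and r evolve over training steps.
    for step in (1, 2, 5, 6, 10, 100, 1000, 10000):
        ro, r = rectification_term(step)
        print(step, round(ro, 3), "fallback" if r is None else round(r, 4))
```

For any `b2`, `ro` is 1 at the first step (the `b2` terms cancel), and it grows by roughly one per step early on. With `b2 = 0.999` and `threshold = 5.0` that means the first five steps use the momentum fallback and the rectified branch starts at step 6. As the step count grows, `ro` approaches `ro_inf` and `r` approaches 1, so the wrapper gradually stops changing the inner update.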
@mtthss I think that is a good idea. But could it be used in a chain like this?
    def radabelief(...):
        return chain(
            rectify(...),
            scale_by_belief(...))
I'm very new to optax...
The wrapped GradientTransformation would itself be a GradientTransformation.
So you would first wrap the transform and then use the wrapped transform in a chain,
e.g. you could do something like this:
    def radabelief(...):
        return chain(
            clip_by_global_norm(...),
            rectify(scale_by_belief(...), ...),
            scale(...)
        )