
How to use rectification with other optimizers?

Open DaniyarM opened this issue 4 years ago • 4 comments

I want to use the rectification from RAdam with the AdaBelief optimizer, as suggested by the author of AdaBelief. How can I do that?

DaniyarM avatar Feb 15 '21 01:02 DaniyarM

It's on my TODO list to think about ways for the rectification to be composable with different transforms. Hopefully I can get to this in the next couple of weeks, but I am also happy to accept contributions/suggestions if someone is interested? :)

mtthss avatar Feb 15 '21 09:02 mtthss

For instance, instead of having a monolithic scale_by_radam, we could have a wrapper like this:

import jax
import jax.numpy as jnp

from optax import GradientTransformation
# Optax's internal bias-correction helper, as used by scale_by_adam.
from optax._src.transform import _bias_correction


def rectify(inner: GradientTransformation,
            b1: float = 0.9,
            b2: float = 0.999,
            threshold: float = 5.0) -> GradientTransformation:
  """Rectify Adam-like updates."""

  def init_fn(params):
    return inner.init(params)

  def update_fn(updates, state, params=None):
    inner_updates, new_state = inner.update(updates, state, params)

    # The inner transform has already incremented the step count.
    t = new_state.count
    b2t = b2**t
    ro_inf = 2./(1 - b2) - 1
    ro = ro_inf - 2 * t * b2t / (1 - b2t)
    r = jnp.sqrt((ro - 4)*(ro - 2)*ro_inf/((ro_inf - 4)*(ro_inf - 2)*ro))

    scaled_updates = jax.tree_map(lambda u: r*u, inner_updates)
    unbiased_momentum = _bias_correction(new_state.mu, b1, t)

    # Use the rectified adaptive update once the variance of the adaptive
    # learning rate is tractable (ro >= threshold); otherwise fall back
    # to plain bias-corrected momentum.
    rectified_updates = jax.tree_multimap(
      lambda x, y: jax.lax.select(ro >= threshold, x, y),
      scaled_updates, unbiased_momentum)

    return rectified_updates, new_state

  return GradientTransformation(init_fn, update_fn)

This wrapper could then be applied to any Adam-like inner transform, so that the current scale_by_radam becomes:

rectify(scale_by_adam(...), ...)

But you could also combine it with other Adam-like updates such as scale_by_belief. Wdyt, do you think this would be helpful?

WARNING: I didn't check my math carefully; the code above is just meant to communicate the "idea".
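For intuition, the rectification term from the RAdam paper can be sanity-checked with plain Python, no jax required. This is only a sketch of the math used in the wrapper above, assuming the same defaults (b2 = 0.999, threshold = 5.0):

```python
import math

B2 = 0.999        # second-moment decay, matching the sketch above
THRESHOLD = 5.0   # minimum rho for taking the adaptive branch

def rho(t: int) -> float:
    """Approximated SMA length rho_t from the RAdam paper."""
    ro_inf = 2.0 / (1.0 - B2) - 1.0
    b2t = B2 ** t
    return ro_inf - 2.0 * t * b2t / (1.0 - b2t)

def rectification(t: int) -> float:
    """Variance rectification factor r_t (only real-valued when rho_t > 4)."""
    ro_inf = 2.0 / (1.0 - B2) - 1.0
    ro = rho(t)
    return math.sqrt(
        (ro - 4.0) * (ro - 2.0) * ro_inf
        / ((ro_inf - 4.0) * (ro_inf - 2.0) * ro))

# Early steps fall below the threshold, so the wrapper would emit plain
# bias-corrected momentum; later steps take the damped adaptive update.
print(rho(1))                 # about 1: momentum branch
print(rho(10))                # above 5: adaptive branch
print(rectification(10))      # well below 1: strongly damped early on
print(rectification(10_000))  # approaches 1 as t grows
```

This makes the role of `threshold` concrete: for the first few steps the variance of the adaptive learning rate is intractable, so the update degenerates to momentum, and afterwards the adaptive step is scaled by a factor that warms up towards 1.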

mtthss avatar Feb 15 '21 11:02 mtthss

@mtthss I think that is a good idea. But could it be used in a chain like this:

def radabelief(...):
    return chain(
        rectify(...),
        scale_by_belief(...))

? I'm still very new to optax...

DaniyarM avatar Feb 15 '21 12:02 DaniyarM

The wrapped GradientTransformation would itself be a GradientTransformation. So you would first wrap the inner transform, and then use the wrapped transform in a chain, e.g. you could do something like this:

def radabelief(...):
    return chain(
        clip_by_global_norm(...),
        rectify(scale_by_belief(...), ...),
        scale(...)
        )

mtthss avatar Feb 15 '21 12:02 mtthss