Add signSGD optimizer

Open wtong98 opened this issue 1 year ago • 1 comments

Huge fan of optax and the jax family of libraries! I've started using signSGD quite frequently in my research, and noticed there doesn't seem to be an existing implementation of the optimizer, nor is there an easy way to define it by chaining transformations. Given the rising importance of sign-based methods (the original signSGD paper has >1k citations) and its prevalence in modern optimizers (e.g. RMSProp, Adam, and Lion are all sign-based), it seems like a good idea to build more explicit support for sign-based methods in optax. To this end, I've implemented two simple additions:

  1. scale_by_sign, which simply computes the sign of each input gradient
  2. sign_sgd, which is a vanilla implementation of the signSGD algorithm

I hope this PR makes it easier to experiment with sign-based methods, and define new sign-based optimizers in optax! Looking forward to your comments.

wtong98, Jun 25 '24

Great, thanks for the notes! Made the suggested changes. Let me know if anything else is needed.

wtong98, Jun 26 '24

I believe this has now been merged. Thanks @wtong98 for the contribution!

fabianp, Sep 02 '24