Add signSGD optimizer
Huge fan of optax and the JAX family of libraries! I've started using signSGD quite frequently in my research, and noticed there doesn't seem to be an existing implementation of the optimizer, nor is there an easy way to define it by chaining transformations. Given the rising importance of sign-based methods (the original signSGD paper has >1k citations) and their prevalence in modern optimizers (e.g. RMSProp and Adam can be viewed as sign-based, and Lion applies the sign explicitly), it seems like a good idea to build more explicit support for sign-based methods in optax. To this end, I've implemented two simple additions:
- scale_by_sign, which simply computes the elementwise sign of the input gradients
- sign_sgd, a vanilla implementation of the signSGD algorithm
I hope this PR makes it easier to experiment with sign-based methods, and define new sign-based optimizers in optax! Looking forward to your comments.
Great, thanks for the notes! I've made the suggested changes. Let me know if there's anything else.
I believe this has now been merged. Thanks @wtong98 for the contribution!