Adan Optimizer
Interesting, I think that would be great! Thanks a lot!
Let us know if you'd like to discuss anything about the implementation as you write it. How much of the Adam code do you think can be reused?
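On the question of reuse, here is a minimal sketch of one Adan step. It is written from the update rule in the Adan paper, using the (0.98, 0.92, 0.99) beta convention of the reference PyTorch code. It is not the code from the PR or from optax. Bias correction, weight decay and pytree handling are deliberately left out, and `adan_init`/`adan_update` are hypothetical names. The moments `m` and `n` are Adam-style exponential moving averages, so that logic carries over almost unchanged. What is new is the gradient-difference moment `v` and the need to keep the previous gradient in the state. Epsilon is added after the square root, the placement discussed further down in this thread.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class AdanState(NamedTuple):
    m: jnp.ndarray          # EMA of gradients (same role as Adam's first moment)
    v: jnp.ndarray          # EMA of gradient differences g_k - g_{k-1} (new in Adan)
    n: jnp.ndarray          # EMA of squared "corrected" gradients (Adam-like second moment)
    prev_grad: jnp.ndarray  # previous gradient, needed for the difference term


def adan_init(params):
    zeros = jnp.zeros_like(params)
    # Assumption: prev_grad starts at zero, so the first difference equals the
    # first gradient; a real implementation may handle the first step differently.
    return AdanState(m=zeros, v=zeros, n=zeros, prev_grad=zeros)


def adan_update(grads, state, lr=1e-3, b1=0.98, b2=0.92, b3=0.99, eps=1e-8):
    diff = grads - state.prev_grad
    m = b1 * state.m + (1 - b1) * grads      # Adam-style first moment
    v = b2 * state.v + (1 - b2) * diff       # moment of gradient differences
    corrected = grads + b2 * diff            # Nesterov-style corrected gradient
    n = b3 * state.n + (1 - b3) * corrected**2
    # Epsilon outside the square root: sqrt(n) + eps.
    step = -lr * (m + b2 * v) / (jnp.sqrt(n) + eps)
    return step, AdanState(m=m, v=v, n=n, prev_grad=grads)


def loss(x):
    return jnp.sum(x**2)


x0 = jnp.array([1.0, -2.0])
x = x0
state = adan_init(x)
for _ in range(200):
    g = jax.grad(loss)(x)
    step, state = adan_update(g, state, lr=0.1)
    x = x + step
```

On this toy quadratic, the loop should drive the loss well below its starting value of 5.0.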
This looks very interesting, it would be an amazing contribution!
I see there is an issue with the reproducibility of the pull request. It appears there is another implementation in JAX for optax here, which might be worth looking at.
https://github.com/hr0nix/optax-adan
Thanks for the pointer @adam-hartshorne! Sadly, when I test that implementation in my Colab, its error is three orders of magnitude larger than mine (I'm not sure where the difference comes from; the only difference I can see is the epsilon placement).
Hi, author of optax-adan here. Could you share a Colab where you've compared both implementations? I'd like to figure out where the difference comes from.
Oh, found the link in the pull request, so no worries.
Yep, the difference between the implementations comes from epsilon placement. If I move it outside the sqrt, the results match.
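To see why epsilon placement alone can cause such a gap, here is a minimal sketch using plain `jax.numpy` rather than the code of either implementation. It compares the two denominators, `sqrt(n) + eps` and `sqrt(n + eps)`, with an assumed `eps=1e-8`. When the second-moment estimate `n` is large, the two agree. When `n` is tiny (for example, parameters whose gradients are near zero), they differ by orders of magnitude, and so do the resulting steps.

```python
import jax.numpy as jnp


def denom_outside(n, eps=1e-8):
    # Epsilon added after the square root: sqrt(n) + eps.
    return jnp.sqrt(n) + eps


def denom_inside(n, eps=1e-8):
    # Epsilon added under the square root: sqrt(n + eps).
    return jnp.sqrt(n + eps)


n = jnp.array([1.0, 1e-4, 1e-16])  # second-moment estimates of decreasing size
m = jnp.ones_like(n)               # unit first moment, so each step is 1 / denominator

step_outside = m / denom_outside(n)
step_inside = m / denom_inside(n)
ratio = step_outside / step_inside
```

For `n = 1.0` the two steps are essentially identical. For `n = 1e-16`, `sqrt(n) + eps` is about `2e-8` while `sqrt(n + eps)` is about `1e-4`, so the steps differ by a factor of roughly 5000.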
After this change is merged and released, I'll add a note to the README.md of optax-adan saying there is no need to use the package, as Adan is now implemented in optax.