Muon for 4D parameters
The reference implementation of Muon supports 4D parameters (like conv filters) through flattening here.
As far as I can see, the current optax implementation does not, because of the partitioning function.
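For context, here is a minimal NumPy sketch of what flattening a 4D parameter before orthogonalization could look like. It is not the optax or reference code: `orthogonalize_update` and `newton_schulz` are hypothetical names, and keeping the leading axis as the row dimension assumes a PyTorch-style `(out, in, kH, kW)` filter layout. Other layouts (for example kernels stored as `(kH, kW, in, out)`) would need a different choice of axes.

```python
import numpy as np

# Quintic Newton-Schulz coefficients as used in the reference Muon code.
NS_COEFFS = (3.4445, -4.7750, 2.0315)


def newton_schulz(g, steps=5, eps=1e-7):
    """Approximately orthogonalize a 2D matrix with Newton-Schulz iterations."""
    a, b, c = NS_COEFFS
    # Work on the wide orientation so the Gram matrix x @ x.T stays small.
    transposed = g.shape[0] > g.shape[1]
    x = g.T if transposed else g
    # Normalize so that all singular values are at most 1.
    x = x / (np.linalg.norm(x) + eps)
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


def orthogonalize_update(g):
    """Hypothetical helper: flatten 4D to 2D, orthogonalize, restore the shape."""
    if g.ndim == 4:
        # Assumption: axis 0 is the output-channel axis (PyTorch conv layout).
        flat = g.reshape(g.shape[0], -1)
        return newton_schulz(flat).reshape(g.shape)
    return newton_schulz(g)


# Example: a conv filter with 8 output channels, 3 input channels, 3x3 kernel.
rng = np.random.default_rng(0)
conv_grad = rng.standard_normal((8, 3, 3, 3))
conv_update = orthogonalize_update(conv_grad)
```

The update keeps the original 4D shape, so it can be applied to the parameter directly; only the orthogonalization step sees the 2D view.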
Thanks for pointing this out! Would you be willing to contribute and fix this?
I can make a PR, but only in a few days, if that's OK.
Totally fine, we really appreciate the effort!
@vroulet I also noticed that the implementation completely ignores this warning. Should we update the docs accordingly?
Yes, that would be great!
This should now be supported after the recent series of changes:
- https://github.com/google-deepmind/optax/pull/1435
- https://github.com/google-deepmind/optax/pull/1407