
Muon for 4D parameters

[Open] ZagButNoZig opened this issue 9 months ago • 5 comments

The reference implementation of Muon supports 4D parameters (such as conv filters) by flattening them here.

As far as I can see, the current optax implementation does not, because of its partitioning function.
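To make the flattening idea concrete, here is a minimal JAX sketch. It assumes the PyTorch-style kernel layout `(out_ch, in_ch, kh, kw)`, which is collapsed to a 2D `(out_ch, in_ch * kh * kw)` matrix, orthogonalized with a quintic Newton-Schulz iteration, and reshaped back. `newton_schulz` and `orthogonalize_update` are hypothetical helper names, not optax's API. The iteration coefficients are the ones used in the reference Muon implementation.

```python
import jax
import jax.numpy as jnp


def newton_schulz(g: jax.Array, steps: int = 5, eps: float = 1e-7) -> jax.Array:
    """Approximately orthogonalize a 2D matrix with a quintic Newton-Schulz iteration.

    Coefficients follow the reference Muon implementation. The result is roughly
    U @ V.T, with singular values pushed into a band around 1 rather than exactly 1.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    # Frobenius normalization guarantees every singular value starts <= 1.
    x = g / (jnp.linalg.norm(g) + eps)
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        # Work in the wide orientation so the Gram matrix x @ x.T is the smaller one.
        x = x.T
    for _ in range(steps):
        s = x @ x.T
        poly = b * s + c * (s @ s)
        x = a * x + poly @ x
    return x.T if transposed else x


def orthogonalize_update(g: jax.Array, steps: int = 5) -> jax.Array:
    """Hypothetical helper: orthogonalize a 2D or 4D gradient, preserving its shape.

    A 4D kernel is ASSUMED to be laid out (out_ch, in_ch, kh, kw), as in PyTorch.
    It is flattened to (out_ch, in_ch * kh * kw), orthogonalized, then reshaped back.
    """
    if g.ndim == 2:
        return newton_schulz(g, steps)
    if g.ndim == 4:
        flat = g.reshape(g.shape[0], -1)
        return newton_schulz(flat, steps).reshape(g.shape)
    raise ValueError(f"expected a 2D or 4D array, got ndim={g.ndim}")


kernel_grad = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 3, 3))
update = orthogonalize_update(kernel_grad)
print(update.shape)  # (8, 4, 3, 3)
```

Which axis counts as the "output" dimension is a layout choice. A Flax `nn.Conv` kernel is stored as `(kh, kw, in_ch, out_ch)`, so under that layout the analogous flattening would be `g.reshape(-1, g.shape[-1])`.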

ZagButNoZig · Apr 18 '25 11:04

Thanks for pointing this out! Would you be willing to contribute and fix this?

vroulet · Apr 18 '25 15:04

I can make a PR, but only in a few days, if that's OK.

ZagButNoZig · Apr 18 '25 16:04

Totally fine, we really appreciate the effort!

vroulet · Apr 18 '25 16:04

@vroulet I also noticed that the implementation completely ignores this warning. Should we update the docs accordingly?

ZagButNoZig · Apr 18 '25 19:04

Yes, that would be great!

vroulet · Apr 18 '25 19:04

This should now be supported after the recent series of changes:

  • https://github.com/google-deepmind/optax/pull/1435
  • https://github.com/google-deepmind/optax/pull/1407

rdyro · Oct 24 '25 16:10