
integration with Flax?

Open nestordemeure opened this issue 3 years ago • 4 comments

Is there any interest in integrating this work with Flax?

They already have an `init` function, decoupling parameter initialization from model definition, which could make introducing muP fairly plug-and-play.

Plus, they rely on optax for their optimizers. Since that library has a focus on composability, you might be able to introduce a transformation that takes an optimizer and makes it muP-compatible.
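For instance, such a transformation might look like the sketch below. This is plain Python mimicking optax's `(init, update)` pattern rather than real optax or mup code; `scale_by_mup`, `sgd`, and the width multipliers are all hypothetical names used for illustration.

```python
# Hypothetical sketch: optax-style optimizers are (init, update) pairs, so a
# muP wrapper could rescale each tensor's update by a width-dependent
# multiplier. Nothing here is actual optax or mup API.

def sgd(lr):
    """Minimal optax-like SGD: returns an (init_fn, update_fn) pair."""
    def init_fn(params):
        return {}  # plain SGD keeps no optimizer state
    def update_fn(grads, state, params=None):
        updates = {name: -lr * g for name, g in grads.items()}
        return updates, state
    return init_fn, update_fn

def scale_by_mup(base_opt, width_mults):
    """Wrap an optimizer so each tensor's update is scaled per-tensor."""
    base_init, base_update = base_opt
    def update_fn(grads, state, params=None):
        updates, state = base_update(grads, state, params)
        updates = {name: width_mults.get(name, 1.0) * u
                   for name, u in updates.items()}
        return updates, state
    return base_init, update_fn

# Usage: give hidden weights a 1/width multiplier (the multipliers below are
# placeholders, not the exact muP prescription).
init_fn, update_fn = scale_by_mup(sgd(0.1), {"hidden": 1.0 / 128})
grads = {"hidden": 2.0, "readout": 2.0}
updates, _ = update_fn(grads, init_fn(grads))
# hidden update is -0.2 / 128; readout update stays -0.2
```

The appeal of this design is that the wrapper composes with any base optimizer, which is exactly the kind of composability optax is built around.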

Overall, I believe the Flax ecosystem could make muP more easily accessible to people.

nestordemeure avatar Apr 13 '22 21:04 nestordemeure

Integration with Flax would be fantastic, but neither I nor @edwardjhu are familiar with it. If someone from the Flax team can work with us, we can definitely advise the integration process.

thegregyang avatar Apr 15 '22 19:04 thegregyang

@nestordemeure In case you're interested, I have a first draft of a port to JAX/Haiku here. If you're not attached to Flax in particular, you could use this. You could also probably adapt this design to Flax if you wanted, since Flax and Haiku are more similar to each other than either is to PyTorch.

Edit: @thegregyang By the way, can you take a look at the plots in the README there? The optimal learning rate stabilizes with width, but it does look like I sometimes see better training loss for SP. Is that indicative of a bug? My coord checks look good: nothing grows with width, and the output norm (at init) decays with width.

davisyoshida avatar Apr 25 '22 18:04 davisyoshida

Hey @davisyoshida your repo looks great so far!

For your plot, you'd get better results if you tune the input, output, and hidden learning rates for your small model and scale up from there, sweeping a global lr multiplier on the x-axis (ideally, you tune (lr, init) for all parameter tensors, but these 3 learning rates should be a good practical approximation). In particular, for a fair comparison, the curves for your small model in both SP and muP plots should be the same. Your current plots are just looking at a slice of the HP space (of (lr, init) for all parameter tensors) away from the true optimum.

thegregyang avatar Apr 27 '22 12:04 thegregyang

Ah that makes perfect sense, I'll generate new versions of the figures. Thanks!

davisyoshida avatar Apr 27 '22 17:04 davisyoshida