Add Optimization Cookbook

Open samanklesaria opened this issue 2 months ago • 1 comments

What does this PR do?

This PR adds a guide that shows some common techniques for working with Flax models during optimization. These include:

  • Calculation of Exponential Moving Averages
  • Optimizing only a low-rank addition to certain weights (LoRA)
  • Using different learning rates for different parameters to implement the maximal update parameterization
  • Using second-order optimizers like L-BFGS
  • Specifying sharding for optimization state that differs from that of parameter state
  • Gradient accumulation
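
As an illustration of the first item, here is a minimal pure-JAX sketch of an exponential moving average over a parameter pytree. The `ema_update` helper, the `decay` value, and the toy `params` dictionary are hypothetical names chosen for this example and are not taken from the guide itself.

```python
import jax
import jax.numpy as jnp


def ema_update(ema_params, params, decay=0.9):
    """Blend the current parameters into the running average, leaf by leaf."""
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )


# Hypothetical parameter pytree standing in for a model's weights.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}

# Start the average at zero (an assumption; initializing from params also works).
ema_params = jax.tree_util.tree_map(jnp.zeros_like, params)

# One EMA step: each leaf becomes 0.9 * old_average + 0.1 * current_value.
ema_params = ema_update(ema_params, params, decay=0.9)
```

Because the update is a plain function over pytrees, it can be called after each optimizer step without touching the model definition.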

This is a work in progress; the guide will be expanded over time.

This document emphasizes a style as close to pure JAX as possible. To that end, it shows that the Flax version of each technique requires only minor deviations from the often more intuitive pure-JAX version.
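
To give a sense of what the pure-JAX style looks like, here is a rough sketch of gradient accumulation written without any Flax machinery. The linear model, `loss_fn`, `accumulated_grads`, and the choice of equal-sized micro-batches are all assumptions made for this example; the guide's own version may differ.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a toy linear model (hypothetical stand-in)."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, x, y, num_micro_batches):
    """Average gradients over equal-sized micro-batches using lax.scan."""
    # Split the leading batch axis into (num_micro_batches, micro_batch_size).
    xs = x.reshape(num_micro_batches, -1, x.shape[-1])
    ys = y.reshape(num_micro_batches, -1)

    def step(acc, batch):
        grads = jax.grad(loss_fn)(params, *batch)
        # Add this micro-batch's gradients to the running sum.
        return jax.tree_util.tree_map(jnp.add, acc, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, (xs, ys))
    # Divide by the number of micro-batches to get the mean gradient.
    return jax.tree_util.tree_map(lambda g: g / num_micro_batches, total)


kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (8, 3))
y = jax.random.normal(ky, (8,))
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

acc = accumulated_grads(params, x, y, num_micro_batches=2)
full = jax.grad(loss_fn)(params, x, y)
```

With a mean-reduced loss and equal-sized micro-batches, `acc` should match the full-batch gradient `full` up to floating-point error, which makes a convenient sanity check.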

samanklesaria avatar Nov 28 '25 20:11 samanklesaria

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.