Add Optimization Cookbook
What does this PR do?
This PR adds a guide that shows some common techniques for working with Flax models during optimization. These include:
- Computing exponential moving averages (EMA) of parameters
- Optimizing only a low-rank addition to certain weights (LoRA)
- Using different learning rates for different parameters to implement the maximal update parametrization (µP)
- Using second-order optimizers such as L-BFGS
- Specifying sharding for optimizer state that differs from that of parameter state
- Gradient accumulation
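As a taste of the first technique in the list above, here is a minimal pure-JAX sketch of a parameter EMA (the names `ema_update` and the example parameter tree are illustrative, not from the guide):

```python
import jax
import jax.numpy as jnp

def ema_update(ema_params, new_params, decay=0.999):
    """One EMA step: move each EMA leaf toward the new parameter leaf."""
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, new_params
    )

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.0)}
ema = jax.tree_util.tree_map(jnp.copy, params)  # initialize EMA at current params
new_params = {"w": jnp.array([3.0, 4.0]), "b": jnp.array(1.0)}
ema = ema_update(ema, new_params, decay=0.5)
# With decay=0.5 the EMA moves halfway toward the new parameters:
# ema["w"] -> [2., 3.], ema["b"] -> 0.5
```

Because the EMA is just another pytree of the same structure as the parameters, the same `tree_map` pattern carries over directly to Flax parameter trees.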
This is a work in progress: the guide will be fleshed out much further over time.
This document emphasizes a style as close to pure JAX as possible: to that end, it shows how the Flax version of each technique requires only minor deviations from the often more intuitive pure-JAX version.
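To illustrate that pure-JAX-first style, the gradient-accumulation item above can be sketched without any Flax-specific machinery at all (the `loss_fn` and microbatch shapes here are hypothetical placeholders, not taken from the guide):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    """Toy mean-squared-error loss on a linear model."""
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)

def accumulate_grads(params, microbatches):
    """Average gradients over several microbatches before a single update."""
    acc = jax.tree_util.tree_map(jnp.zeros_like, params)
    for x, y in microbatches:
        g = grad_fn(params, x, y)
        acc = jax.tree_util.tree_map(jnp.add, acc, g)
    return jax.tree_util.tree_map(lambda g: g / len(microbatches), acc)

params = {"w": jnp.array([1.0, 1.0])}
microbatches = [
    (jnp.array([[1.0, 0.0]]), jnp.array([0.0])),
    (jnp.array([[0.0, 1.0]]), jnp.array([2.0])),
]
g = accumulate_grads(params, microbatches)
# For equally sized microbatches this equals the full-batch gradient:
# g["w"] -> [1., -1.]
```

The accumulated gradient is again just a pytree, so swapping the toy model for a Flax module changes only `loss_fn`, not the accumulation logic.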