Gradient Accumulation in Axlearn
Gradient accumulation allows training with higher batch sizes without scaling out.
A new learner type was added: learner.klass: 'axlearn.common.learner.AccumulatedLearner'.
At a high level, the optimization does the following (a sketch follows the list):
- The input batch is split into equal-sized microbatches.
- A buffer is created for gradients and metrics.
- A forward and backward pass is run for each microbatch in a loop, summing the gradients and aggregating the metrics.
- The gradients are averaged across microbatches and the metrics are normalized.
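Below is a minimal JAX sketch of this accumulation loop, assuming a hypothetical loss_fn(params, microbatch) that returns a scalar loss. It illustrates the idea only and is not axlearn's actual AccumulatedLearner implementation; metric aggregation is omitted for brevity.

```python
import jax
import jax.numpy as jnp

def accumulated_grads(loss_fn, params, batch, num_microbatches):
    """Splits `batch` into microbatches, sums per-microbatch gradients, and averages them."""
    # Split the global batch into equal-sized microbatches along the leading axis.
    microbatches = jax.tree_util.tree_map(
        lambda x: x.reshape((num_microbatches, -1) + x.shape[1:]), batch
    )
    grad_fn = jax.grad(loss_fn)

    def step(grad_buffer, i):
        mb = jax.tree_util.tree_map(lambda x: x[i], microbatches)
        # Forward and backward pass for this microbatch.
        grads = grad_fn(params, mb)
        # Sum gradients into the accumulation buffer.
        grad_buffer = jax.tree_util.tree_map(jnp.add, grad_buffer, grads)
        return grad_buffer, None

    # Zero-initialized gradient buffer with the same structure as the parameters.
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    summed, _ = jax.lax.scan(step, zeros, jnp.arange(num_microbatches))
    # Average the summed gradients across microbatches.
    return jax.tree_util.tree_map(lambda g: g / num_microbatches, summed)
```

The averaged gradients can then be passed to the optimizer update exactly as if they had come from a single large-batch backward pass.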
Configuration changes:
- The number of microbatches is specified during configuration through the option accumulation_microbatches in the trainer and microbatches in the learner (see the sketch below).
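A hypothetical configuration sketch follows. The option names accumulation_microbatches and microbatches and the learner class come from the description above, but the exact axlearn config structure and the value used here are assumptions, not a verified configuration.

```python
# Hypothetical config sketch; exact axlearn config fields/APIs may differ.
trainer_config = {
    "accumulation_microbatches": 4,  # microbatches per input batch in the trainer (assumed value)
    "learner": {
        "klass": "axlearn.common.learner.AccumulatedLearner",
        "microbatches": 4,  # should match the trainer's accumulation_microbatches
    },
}
```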