maxtext
Set param_scan_axis=0 and change all checkpoint creation files
Description
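Setting param_scan_axis=0 improves performance and reduces the memory required by the optimizer state. Because this changes the axis along which layer parameters are laid out, all checkpoint creation files are updated to match.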
FIXES: #1382
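For reviewers less familiar with the scan axis, here is a minimal sketch of what the setting changes in a stacked parameter, assuming the checkpoint creation scripts stack same-shaped per-layer weights along `param_scan_axis`. The helper `stack_layers` and the sizes are hypothetical and are not taken from the changed files.

```python
import numpy as np


def stack_layers(per_layer_weights, param_scan_axis=0):
    """Stack same-shaped per-layer weights along the scan axis (hypothetical helper)."""
    return np.stack(per_layer_weights, axis=param_scan_axis)


# Hypothetical sizes, chosen only for illustration.
num_layers, embed_dim, mlp_dim = 4, 8, 16
per_layer = [
    np.full((embed_dim, mlp_dim), i, dtype=np.float32) for i in range(num_layers)
]

stacked_axis0 = stack_layers(per_layer, param_scan_axis=0)
stacked_axis1 = stack_layers(per_layer, param_scan_axis=1)

print(stacked_axis0.shape)  # (4, 8, 16): layers on the leading axis
print(stacked_axis1.shape)  # (8, 4, 16): layers on the second axis

# Converting between the two layouts is a pure axis move; values are unchanged.
assert np.array_equal(np.moveaxis(stacked_axis1, 1, 0), stacked_axis0)
```

A checkpoint written with one layout cannot be read correctly under the other without this kind of axis move, which is why the creation files change together with the setting.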
Tests
Since this touches all the checkpoint creation files, I will ask for help creating new checkpoints and comparing their logits against the existing ones.
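The logit comparison could look roughly like the sketch below. It assumes logits from the old and new checkpoints have already been dumped as `[num_tokens][vocab_size]` nested lists for the same prompt. The function name `compare_logits` and the tolerance `atol=1e-2` are placeholders, not values agreed for this PR.

```python
def compare_logits(reference, candidate, atol=1e-2):
    """Compare two [num_tokens][vocab_size] logit tables given as nested lists."""
    if not reference or len(reference) != len(candidate):
        raise ValueError("logit tables must be non-empty and have equal token counts")

    max_abs_diff = 0.0
    argmax_matches = 0
    for ref_row, cand_row in zip(reference, candidate):
        if len(ref_row) != len(cand_row):
            raise ValueError("vocab sizes differ")
        # Largest element-wise deviation seen so far.
        row_diff = max(abs(r - c) for r, c in zip(ref_row, cand_row))
        max_abs_diff = max(max_abs_diff, row_diff)
        # Check whether both checkpoints predict the same top token.
        ref_top = max(range(len(ref_row)), key=ref_row.__getitem__)
        cand_top = max(range(len(cand_row)), key=cand_row.__getitem__)
        argmax_matches += ref_top == cand_top

    return {
        "max_abs_diff": max_abs_diff,
        "argmax_agreement": argmax_matches / len(reference),
        "within_tolerance": max_abs_diff <= atol,
    }


# Toy values for illustration only; real logits would come from the two checkpoints.
old_logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3]]
new_logits = [[0.1, 2.001, -1.0], [1.5, 0.2, 0.301]]
report = compare_logits(old_logits, new_logits)
print(report["within_tolerance"], report["argmax_agreement"])
```

Reporting both the maximum absolute difference and the top-token agreement helps separate small numerical drift from a layout bug, which would typically change the predicted tokens.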
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [ ] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed.