[FSDP2] Context parallel
Building on top of #3585, this PR enables context parallelism together with FSDP2 (the branch stems from that PR, hence the large diff). ~~Possibly, this will be moved to the global namespace rather than the FSDP-local one, though that requires a bit more experimentation on my side to see what it works with.~~ We probably want to always use this together with FSDP/ZeRO: from my profiling, the communication overlaps perfectly with compute, and you get model sharding for free. For ZeRO I'll still need to look into Ulysses for context parallelism.
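For context, a minimal sketch of the underlying PyTorch mechanism this builds on (not the exact Accelerate-facing API from this PR): FSDP2's `fully_shard` over the same ranks that context parallel splits the sequence across, so parameter sharding comes "for free" with the CP group. The model attributes (`model.layers`, the `labels=` kwarg, `.loss`) are assumptions about an HF-style transformer, purely for illustration.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point (torch >= 2.6)
from torch.distributed.tensor.experimental import context_parallel


def setup(model, world_size):
    # One mesh dim shared by FSDP2 and context parallel: every rank holds
    # 1/world_size of the parameters and 1/world_size of the sequence.
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("cp",))
    for block in model.layers:  # assumes the transformer blocks live in `model.layers`
        fully_shard(block, mesh=mesh)
    fully_shard(model, mesh=mesh)
    return model, mesh


def train_step(model, batch, labels, cp_mesh):
    # `context_parallel` shards the listed buffers along their sequence dim
    # across the CP group and swaps SDPA for a ring-attention variant, so each
    # rank only materializes attention for its local chunk of the sequence.
    with context_parallel(
        cp_mesh,
        buffers=[batch, labels],  # tensors to split along the sequence
        buffer_seq_dims=[1, 1],   # sequence dim of each buffer
    ):
        loss = model(batch, labels=labels).loss  # assumes an HF-style forward
        loss.backward()
    return loss
```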
This has been verified with benchmarks against torchtitan, where our implementation reaches roughly equal performance and memory usage (compared against my own reruns of torchtitan, not their reported numbers).
I managed to scale training to a ~500k sequence length on 32x H100s with Llama 8B. torchtitan reports a ~1M context length at this scale, but rerunning their repo at this scale gave me the same results as ours. Since each rank only holds its local chunk of the sequence, doubling the sequence length requires doubling the GPU count, so I stopped at ~500k; benchmarking at 64-GPU scale didn't feel productive to me yet.
TODO:
- [x] Tests
- [ ] Proper benchmarks
- [x] Add to accelerate config
- [x] Concept guide
- [x] #3585 needs to get merged