docs: fix the `init_module` example for DeepSpeed
What does this PR do?
Fixes the `init_module` docs, which currently suggest doing the following:
```python
# Recommended for FSDP, TP and DeepSpeed
with fabric.init_module(empty_init=True):
    model = GPT3()  # parameters are placed on the meta-device

model = fabric.setup(model)  # parameters get sharded and initialized at once

# Make sure to create the optimizer only after the model has been set up
optimizer = torch.optim.Adam(model.parameters())
optimizer = fabric.setup_optimizers(optimizer)
```
However, setting up the model and the optimizer separately breaks with DeepSpeed because of the following check: https://github.com/Lightning-AI/pytorch-lightning/blob/cf24a190ce52d65ae7316fceec584145f1e1f006/src/lightning/fabric/fabric.py#L1031-L1037
Changed the docs to reflect the correct syntax.
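For reference, a minimal sketch of the corrected pattern (assuming an existing `fabric` instance and using `GPT3` as the placeholder model class from the original snippet):

```python
import torch

# Recommended for FSDP, TP and DeepSpeed
with fabric.init_module(empty_init=True):
    model = GPT3()  # parameters are placed on the meta-device

optimizer = torch.optim.Adam(model.parameters())

# DeepSpeed needs to configure the model and optimizer together, so both are
# passed to `fabric.setup` in a single call instead of calling
# `fabric.setup_optimizers` separately (which raises for this strategy).
model, optimizer = fabric.setup(model, optimizer)
```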
Furthermore, added a discussion of why this is necessary with DeepSpeed Stage 3, as stated here: https://github.com/Lightning-AI/pytorch-lightning/issues/17792#issuecomment-1641144583
~~I only discussed DeepSpeed Stage 3, as I cannot find a statement on whether this needs to be done for Stage 2 to work correctly. From personal experience, this led to an improvement by allowing a slightly larger batch (2 -> 4) for me. I am personally confused as to why this helps, since Stage 2 does not shard the parameters, but my understanding of these strategies is highly limited. Perhaps @awaelchli or someone else could enlighten me on if/why it does? Happy to change this to include Stage 2 (or also 1?).~~
Only discussed for Stage 3, as that is where this is necessary. As an update - I believe I traced what improved my performance down to a larger batch size, and it wasn't the inclusion of `init_module`.
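For context, a minimal sketch of how Stage 3 is enabled in Fabric so that the pattern above applies (the device count is illustrative; `"deepspeed_stage_3"` is the strategy name from Lightning's registry):

```python
from lightning.fabric import Fabric

# Stage 3 shards the parameters themselves, so creating them under
# init_module(empty_init=True) avoids materializing the full model on every rank.
fabric = Fabric(accelerator="cuda", devices=4, strategy="deepspeed_stage_3")
fabric.launch()
```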
Just as a note - I read the docs-editing README, which says building locally is required. I had some issues building the docs (problems with my local machine), but I triple-checked that the change follows the .rst format, and this is a simple change.
Fixes: No issue, as this is a (likely straightforward) documentation fix.
Before submitting
- Was this discussed/agreed via a GitHub issue? (not for typos and docs)
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
- Did you make sure to update the documentation with your changes? (if necessary)
- Did you write any new necessary tests? (not for typos and docs)
- [ ] Did you verify new and existing tests pass locally with your changes?
- Did you list all the breaking changes introduced by this pull request?
- Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)
PR review
Anyone in the community is welcome to review the PR. Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
- [ ] Is this pull request ready for review? (if not, please submit in draft mode)
- [ ] Check that all items from Before submitting are resolved
- [ ] Make sure the title is self-explanatory and the description concisely explains the PR
- [ ] Add labels and milestones (and optionally projects) to the PR so it can be classified
Did you have fun?
Make sure you had fun coding 🙃
[x] Yes! Debugging why DeepSpeed wasn't working wasn't as fun, though. This could potentially benefit from an article in the docs, even though it is an experimental feature in Fabric?
📚 Documentation preview 📚: https://pytorch-lightning--20175.org.readthedocs.build/en/20175/
Just as a note - I was only able to test that this is the correct code to `.setup` with Stage 2 ~~(where it helped)~~. I cannot run Stage 3 due to the old ~~non-deepspeed~~ non-fabric checkpoints issue.
@alyakin314 could you please update the PR based on Adrian's comments?
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://lightning.ai/docs/pytorch/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Discord. Thank you for your contributions.