Speed up model loading for generate
This has not been extensively tested (only Mistral 7B) and is more of a proposal!
This change does the following (see the sketch below the list):
- Create the model on the meta device
- Load the state dict with assign=True, which preserves the properties of the checkpoint tensors (mmap-ed CPU tensors in this case)
- Initialize the non-persistent buffers that are still on the meta device
- Move the finalized model to the requested device/dtype
This makes the model loading almost instant on my machine.
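For concreteness, here's a minimal sketch of that flow. This is not the actual diff: `build_model()` and the `model.pt` path are hypothetical placeholders, and the buffer re-initialization is module-specific and elided.

```python
import torch

# 1. Create the model on the meta device: no real memory is allocated.
with torch.device("meta"):
    model = build_model()  # hypothetical constructor

# 2. Load the checkpoint memory-mapped on CPU and adopt its tensors directly.
#    assign=True (PyTorch >= 2.1) replaces the meta parameters with the
#    checkpoint tensors instead of copying into them, so they stay mmap-ed.
state_dict = torch.load("model.pt", mmap=True, map_location="cpu")
model.load_state_dict(state_dict, assign=True)

# 3. Non-persistent buffers are not in the state dict, so they are still on
#    the meta device; give them real storage and re-run their initialization.
for module in model.modules():
    for name, buf in module.named_buffers(recurse=False):
        if buf.is_meta:
            module.register_buffer(
                name, torch.empty_like(buf, device="cpu"), persistent=False
            )
            # ...module-specific buffer initialization goes here

# 4. Move the finalized model to the requested device/dtype.
model = model.to(device="cuda", dtype=torch.bfloat16)
```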
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/709
:x: 1 New Failure
As of commit add24c6ad1f193a30196453c2a02c01f77dd3053 with merge base ada52240514bb9fa07f91a50ad0a31063f13834c:
NEW FAILURE - The following job has failed:
- Lint / lint (3.10) (gh): Process completed with exit code 1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Is there any way I can tell from the CI logs what my lint mistake is so I can fix it?
@albanD thanks so much for putting this up! I'll take a more detailed look tomorrow, but to answer your lint question - you can do the following:
```bash
pre-commit install
pre-commit run --all-files
```
This will fix all of the issues for you.
Thanks for the PR @albanD! Tbh we have already had a fraught relationship with meta device initialization 😅 (see e.g. #317, #418, #514). Our latest status is that we deliberately sacrifice a bit on time-to-first-batch for the sake of keeping code in the model components agnostic to meta device. But generation is an interesting case since the total runtime is much lower than on a finetune with FSDP (which is what we were focusing on previously). Out of curiosity, what is the speedup of meta device vs just initializing directly on GPU in this case?
I would need to check once I'm back on the machine in question. The more important bit, tbh, is that the CPU model was fully using the mmap-ed tensors from the checkpoint and so was not blowing up my scarce RAM :D
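For illustration, a rough way to observe this effect, assuming `psutil` is installed and using a hypothetical `model.pt` checkpoint path:

```python
import os

import psutil
import torch

proc = psutil.Process(os.getpid())

def rss_mb() -> float:
    """Resident set size of this process, in MB."""
    return proc.memory_info().rss / 1e6

before = rss_mb()
# mmap=True keeps checkpoint tensors backed by the file on disk; pages are
# faulted in lazily, so resident memory barely grows at load time.
sd = torch.load("model.pt", mmap=True, map_location="cpu")
print(f"RSS growth with mmap=True: {rss_mb() - before:.0f} MB")
```

With the default `mmap=False`, RSS instead grows by roughly the full checkpoint size, since every tensor is materialized in RAM up front.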
@kartikayk I saw that but I don't have pre-commit in my environment :p