Optimize GPU checkpoint loading by moving models to the device before load_state_dict in the build method
While reviewing the code in this repository, I noticed a few places where efficiency could be improved. This PR changes how the models are moved onto the GPU before their checkpoints are applied, which I expect to speed up checkpoint loading.
Thanks to everyone who contributed to this repo; I really appreciate all the hard work that went into it.
Summary
This PR ensures that both prefill_model and decode_model are moved to the
target device (e.g., GPU) before invoking load_state_dict.
Motivation
Previously, if the models were still on CPU when loading checkpoints, PyTorch
would perform an additional transfer of tensors, causing unnecessary overhead.
By explicitly moving the models to the correct device first, we avoid redundant
transfers and improve checkpoint loading efficiency.
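The reason the order matters: by default, load_state_dict copies checkpoint values into the module's existing parameter tensors in place rather than replacing them, so the parameters stay on whatever device the module already occupies. The small sketch below checks that behavior. It uses a toy nn.Linear as a stand-in for prefill_model/decode_model, and the function name load_preserves_placement is hypothetical. It picks "cuda" when available and falls back to CPU so it runs anywhere.

```python
import torch
import torch.nn as nn


def load_preserves_placement(device: torch.device) -> bool:
    """Check that load_state_dict writes into the existing, device-resident parameters."""
    model = nn.Linear(4, 2).to(device)  # move first, as this PR does
    weight_before = model.weight
    ptr_before = model.weight.data_ptr()

    # A checkpoint whose tensors live on the CPU (e.g. loaded with map_location="cpu").
    checkpoint = {k: torch.randn_like(v, device="cpu") for k, v in model.state_dict().items()}
    model.load_state_dict(checkpoint)

    same_object = model.weight is weight_before           # Parameter object not replaced
    same_storage = model.weight.data_ptr() == ptr_before  # copied in place
    on_device = model.weight.device.type == device.type   # still on the target device
    values_match = torch.equal(model.weight.detach().cpu(), checkpoint["weight"])
    return same_object and same_storage and on_device and values_match


if __name__ == "__main__":
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(load_preserves_placement(dev))
```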
Changes
- Added prefill_model.to(device) before loading its state dict.
- Added decode_model.to(device) before loading its state dict.
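For reviewers, here is a minimal sketch of the ordering these two changes introduce. The function signature, the checkpoint-path parameters, and the use of torch.load with map_location are assumptions for illustration; the repository's actual build method may be structured differently.

```python
import torch
import torch.nn as nn


def build(
    prefill_model: nn.Module,
    decode_model: nn.Module,
    prefill_ckpt: str,  # hypothetical checkpoint path
    decode_ckpt: str,   # hypothetical checkpoint path
    device: torch.device,
) -> tuple[nn.Module, nn.Module]:
    # Move each model to the target device *before* loading its checkpoint,
    # so load_state_dict copies straight into device-resident parameters.
    prefill_model.to(device)
    prefill_model.load_state_dict(torch.load(prefill_ckpt, map_location=device))

    decode_model.to(device)
    decode_model.load_state_dict(torch.load(decode_ckpt, map_location=device))
    return prefill_model, decode_model
```

Passing map_location=device also keeps torch.load from first materializing the checkpoint on whatever device it was saved from.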
Impact
This reduces unnecessary GPU/CPU transfers during checkpoint loading, which should result in faster and more efficient model initialization.
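This PR does not include measurements, so here is one way a reviewer could check the impact: time both orderings on the same model and checkpoint. This is a sketch, not a benchmark from the repository. The helper names and the toy model are hypothetical, and the difference is likely to depend on where the checkpoint tensors live (for example, the map_location passed to torch.load), so it may be worth timing both placements.

```python
import statistics
import time

import torch
import torch.nn as nn


def _sync(device: torch.device) -> None:
    # CUDA work is asynchronous; wait for it before reading the clock.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def time_load(make_model, state_dict, device: torch.device,
              move_first: bool, repeats: int = 5) -> float:
    """Median seconds to end up with a device-resident model holding the checkpoint."""
    timings = []
    for _ in range(repeats):
        model = make_model()  # fresh CPU model each run (construction is not timed)
        _sync(device)
        start = time.perf_counter()
        if move_first:  # ordering introduced by this PR
            model.to(device)
            model.load_state_dict(state_dict)
        else:           # previous ordering
            model.load_state_dict(state_dict)
            model.to(device)
        _sync(device)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


if __name__ == "__main__":
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def make():
        return nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)])

    # Checkpoint placed on the target device, as map_location=device would do.
    sd = {k: v.to(dev) for k, v in make().state_dict().items()}
    print("move first:", time_load(make, sd, dev, move_first=True))
    print("load first:", time_load(make, sd, dev, move_first=False))
```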
Let's go, it's really good 👍
Can you review this PR? @lfoppiano @jasondavies @corygehr
@nimanikoo Can I ask where you got my name to review this PR? I don't own this repository and a cursory glance doesn't yield any results implying I should be a required reviewer.
I'm curious because this is not the only repository where this keeps happening. I do have (limited) oversight, but not to the degree where I should be approving PRs.
Thanks.
Thanks for the heads-up!
I mentioned you because my PR has been pending for a few months without any review,
and since you’re listed as a member/contributor, I thought you might be someone who could take a look
or redirect it to the right reviewer.
If you’re not the right person, no worries at all; I just wanted to get the PR unblocked.
@corygehr