Fixes for training models with bf16 + freshly initialized optimizer via `load_module_only`
This PR fixes the case where we want to resume training from a DeepSpeed ZeRO checkpoint and initialize a new optimizer, without using the old optimizer state in the checkpoint or relying on its existence at all.
In this situation, despite passing `load_module_only=True` and `load_optimizer_states=False` to `load_checkpoint()`, the previous behavior was that:
- `self._load_zero_checkpoint` would still be called, which attempts to load from the (in this case, nonexistent) checkpoint files. This PR stops this function from being called if using `load_module_only=True` and `load_optimizer_states=False`. Alternatively, calling this function may be alright if `"load_from_fp32_weights": true` is set in the DeepSpeed ZeRO config (reference: https://github.com/microsoft/DeepSpeed/blob/ff7d5275f2aa916cb5f320e0d817154e96f9cdb6/deepspeed/runtime/engine.py#L733), but this parameter does not seem to be documented in the docs for ZeRO config dicts.
- in `_load_checkpoint`, the following code block:
```python
if self.optimizer is not None and self.fp16_enabled():
    self.optimizer.refresh_fp32_params()
```
results in `self.optimizer.refresh_fp32_params()` being called only if using FP16. As a result, when training in bf16 the FP32 optimizer state is never initialized from the 16-bit model weights. This PR removes the fp16-specific condition (see the sketch after this list).
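For reference, a simplified sketch of the load behavior these two fixes aim for (illustrative pseudocode only, not the actual DeepSpeed source; the function `checkpoint_load_plan` and its arguments are made up for illustration):

```python
# Illustrative pseudocode of the intended load behavior after this PR.
# Not DeepSpeed source; the function name and arguments are hypothetical.
def checkpoint_load_plan(load_module_only: bool,
                         load_optimizer_states: bool,
                         zero_enabled: bool,
                         has_optimizer: bool) -> list:
    steps = ["load 16-bit module weights"]
    # Fix 1: skip _load_zero_checkpoint entirely when only module weights are
    # requested and a fresh optimizer will be used, so missing optimizer shard
    # files are never touched.
    if zero_enabled and not (load_module_only and not load_optimizer_states):
        steps.append("_load_zero_checkpoint (partitioned optimizer state)")
    # Fix 2: refresh fp32 master weights from the loaded 16-bit weights whenever
    # an optimizer exists, not only when fp16 is enabled (bf16 needs this too).
    if has_optimizer:
        steps.append("optimizer.refresh_fp32_params()")
    return steps

# bf16 + weights-only checkpoint + fresh optimizer:
print(checkpoint_load_plan(load_module_only=True, load_optimizer_states=False,
                           zero_enabled=True, has_optimizer=True))
# -> ['load 16-bit module weights', 'optimizer.refresh_fp32_params()']
```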
Previously reported in: https://github.com/EleutherAI/gpt-neox/issues/947 https://github.com/EleutherAI/gpt-neox/issues/843
Should also close: https://github.com/microsoft/DeepSpeed/issues/4017
This caused problems for a freshly-converted Llama checkpoint, which did not contain optimizer states, when trying to train with this model as the initialization. I have confirmed that these fixes prevent this behavior.
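For context, the intended usage looks roughly like this (a minimal sketch; the checkpoint path, config values, and toy model are placeholders, not from the actual run):

```python
# Minimal sketch: resume from a weights-only ZeRO checkpoint with a fresh
# optimizer in bf16. Paths, config values, and the toy model are placeholders.
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
}

model = torch.nn.Linear(1024, 1024)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

# Load only the module weights; the optimizer starts fresh, so the checkpoint
# does not need to contain any optimizer state files.
engine.load_checkpoint(
    "/path/to/checkpoint_dir",
    load_module_only=True,
    load_optimizer_states=False,
    load_lr_scheduler_states=False,
)
```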
cc @Quentin-Anthony @zhangir-azerbayev
My model was trained in bf16 mode; when loading a checkpoint with `load_optimizer_states=False`, it still tries to load the optimizer states. I avoid that with the following:
```python
# Workaround: temporarily disable bf16 so the engine skips the optimizer-state
# loading path, then restore the flag and rebuild the fp32 params from the
# loaded bf16 weights.
engine._config.bfloat16_enabled = False
_, ckpt_config = engine.load_checkpoint("check", load_module_only=True, load_optimizer_states=False)
engine._config.bfloat16_enabled = True
engine.optimizer._restore_from_bit16_weights()
```
@tjruwase and @jeffra -- Want any more detail or testing from our side? This fix resolved a lot of issues on our end, and we suspect non-neox users may face it too.
@haileyschoelkopf and @Quentin-Anthony, apologies for dropping the ball on this. The PR LGTM. Please resolve the conflict so we can merge. Thanks for the contribution!
Thanks @tjruwase , merge conflicts resolved!
How come this PR was never merged? It fixed the finetuning bug I experienced myself, which cost me about 3-4 days...
Specifically, the extra `fp16_enabled()` condition is not needed (and it hurts bf16, for example), so it should be removed.
before:
```python
if self.optimizer is not None and self.fp16_enabled():
```
after:
```python
if self.optimizer is not None:
```
This should do the trick.
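To illustrate why that one-line change matters for bf16 (a toy example, not DeepSpeed code): the optimizer steps on fp32 master copies of the weights, and those copies have to be re-synced from the freshly loaded 16-bit module weights, which is what `refresh_fp32_params()` does.

```python
# Toy illustration (not DeepSpeed code): the optimizer holds fp32 master copies
# of bf16 weights, and after loading new bf16 module weights those copies must
# be refreshed or the optimizer keeps stepping on stale values.
import torch

model = torch.nn.Linear(4, 4).to(torch.bfloat16)            # bf16 module weights
master_params = [p.detach().float().clone().requires_grad_(True)
                 for p in model.parameters()]                # fp32 master copies
optimizer = torch.optim.AdamW(master_params, lr=1e-4)        # steps on fp32 copies

def refresh_fp32_params():
    # Analogous to optimizer.refresh_fp32_params(): copy the current 16-bit
    # module weights into the fp32 master copies the optimizer actually updates.
    with torch.no_grad():
        for master, p in zip(master_params, model.parameters()):
            master.copy_(p.float())

# ... load new bf16 weights into `model` from a checkpoint ...
refresh_fp32_params()  # without this, master_params still hold the old values
```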