
Fix: checkpoint load bug (TE)

Open · okoge-kaz opened this issue 1 year ago · 1 comment

Issue

When using TransformerEngine with Megatron-LM for training, I encountered an issue where the loss curve would change significantly after loading a checkpoint. This problem did not occur when TransformerEngine was not used.

With TransformerEngine: [W&B loss chart]

Without TransformerEngine: [W&B loss chart]

Overview

While investigating this issue, I noticed that load_state_dict is called with strict=False. This appears to be a workaround for the fact that the parameter keys in transformer_engine.pytorch.TransformerLayer differ from those in the Megatron-LM implementation, so loading with strict=True would raise an error. However, suppressing the error with strict=False means the mismatched weights are silently skipped, and the checkpoint is not actually loaded correctly.
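For context, here is a minimal sketch (not Megatron-LM's actual loading code) of why strict=False hides the problem: PyTorch's load_state_dict does not raise on mismatched keys in this mode, it only reports them in its return value, so the affected parameters silently keep their freshly initialized values. The module names below are hypothetical stand-ins for the two naming schemes:

```python
import torch.nn as nn

# Checkpoint was saved from a model using one key name...
saved_model = nn.Sequential()
saved_model.add_module("dense", nn.Linear(8, 8))

# ...but the live (TE-backed) model uses a different key name.
live_model = nn.Sequential()
live_model.add_module("proj", nn.Linear(8, 8))

state_dict = saved_model.state_dict()  # keys: dense.weight, dense.bias

# strict=False does not raise; mismatched keys are only *reported*,
# and live_model keeps its random initialization for them.
result = live_model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # ['proj.weight', 'proj.bias']
print(result.unexpected_keys)  # ['dense.weight', 'dense.bias']
```

In a training run this failure mode is silent: training resumes without error, but part of the network restarts from random weights, which matches the sudden shift in the loss curve above.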

To correct this, I made the necessary changes so that checkpoints load correctly with TransformerEngine enabled; see the sketch below for the general idea.
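As an illustration only (the actual patch is in the linked diff), one way to handle the naming mismatch is to remap the checkpoint keys to the live model's naming and then load with strict=True, so any remaining mismatch fails loudly instead of being skipped. The key pair and module names here are hypothetical:

```python
import torch.nn as nn

# Hypothetical mapping from the checkpoint's key fragments to the
# live model's naming; the real pairs depend on the model config.
KEY_MAP = {"dense": "proj"}

def remap_state_dict(state_dict, key_map):
    """Rename checkpoint keys so they match the live model's naming."""
    remapped = {}
    for key, tensor in state_dict.items():
        for old, new in key_map.items():
            key = key.replace(old, new)
        remapped[key] = tensor
    return remapped

saved_model = nn.Sequential()
saved_model.add_module("dense", nn.Linear(8, 8))
live_model = nn.Sequential()
live_model.add_module("proj", nn.Linear(8, 8))

# strict=True now surfaces any key that still fails to match,
# instead of silently leaving weights uninitialized.
live_model.load_state_dict(
    remap_state_dict(saved_model.state_dict(), KEY_MAP), strict=True
)
```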

Result

After changes (with TransformerEngine)

[Loss charts after the fix]

okoge-kaz · Feb 26 '24 07:02

Marking as stale. No activity in 60 days.

github-actions[bot] · Apr 26 '24 18:04