Fix: checkpoint load bug (TE)
Issue
When using TransformerEngine with Megatron-LM for training, I encountered an issue where the loss curve changed significantly after loading a checkpoint. This problem did not occur when TransformerEngine was not used.
With TransformerEngine (loss curve)
Without TransformerEngine (loss curve)
Overview
While investigating this issue, I noticed that load_state_dict is called with strict=False. This appears to be a workaround for the parameter key names of transformer_engine.pytorch.TransformerLayer differing from those of the Megatron-LM implementation. However, suppressing the error with strict=False means the mismatched keys are silently skipped, so the checkpoint is not actually loaded correctly.
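As a minimal illustration (not the actual Megatron-LM code path), the sketch below shows how PyTorch's load_state_dict with strict=False silently skips mismatched keys, leaving the affected parameters at their freshly initialized values:

```python
import torch

# A tiny stand-in model; its state dict keys are 'weight' and 'bias'.
model = torch.nn.Linear(4, 4)

# Pretend the checkpoint was saved under a different key prefix, as happens
# when TransformerEngine layer names differ from the Megatron-LM ones.
checkpoint = {"layer.weight": torch.randn(4, 4), "layer.bias": torch.randn(4)}

# strict=False does not raise; it just reports the mismatches and loads nothing.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['weight', 'bias'] -- these parameters were never loaded
print(result.unexpected_keys)  # ['layer.weight', 'layer.bias']

# With strict=True the same call raises a RuntimeError instead of silently
# continuing with untrained weights, which is why the loss curve jumps.
```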
To correct this, I made the necessary changes.
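For reference, here is a minimal sketch of the general direction of such a fix, assuming a hypothetical rename_key() that maps checkpoint key names to the corresponding TransformerEngine layer names; the actual change in this PR may differ:

```python
# Hypothetical helper: remap checkpoint keys so they match the TE module
# layout, then load strictly so any remaining mismatch fails loudly.
def remap_checkpoint_keys(state_dict, rename_key):
    """Return a new state dict whose keys match the target module's layout."""
    return {rename_key(key): value for key, value in state_dict.items()}

# Usage sketch:
# remapped = remap_checkpoint_keys(checkpoint_state_dict, rename_key)
# model.load_state_dict(remapped, strict=True)
```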
Result
After the changes (with TransformerEngine): loss curve