[BUG] Checkpoint combiner doesn't handle the extra_state feature of nn.Module
Describe the bug
Torch `nn.Module`s have a feature, enabled by `set_extra_state` and `get_extra_state`, that is used to store additional buffers in the `state_dict`. The current implementation of `get_fp32_state_dict_from_zero_checkpoint` doesn't seem to handle this well:
- The expectation is that the values of a `state_dict` are `torch.Tensor`s, as evidenced by accessing `.data_ptr()` (see the `parse_model_states` function).
- Even if the above were not a problem, the combined output doesn't have the `_extra_state` key in the state dict.
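For reference, here is a minimal standalone example of the PyTorch feature I mean (the class and field names are just for illustration, not from my actual model):

```python
import torch.nn as nn

class ModuleWithExtraState(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.extra = {"version": 1}  # arbitrary, possibly non-tensor metadata

    def get_extra_state(self):
        # Whatever is returned here is stored under the "<prefix>_extra_state"
        # key when state_dict() is called.
        return self.extra

    def set_extra_state(self, state):
        # Called by load_state_dict() with the stored object.
        self.extra = state

m = ModuleWithExtraState()
print(sorted(m.state_dict().keys()))
# ['_extra_state', 'linear.bias', 'linear.weight']
# The value under '_extra_state' need not be a torch.Tensor, so code that
# calls .data_ptr() on every state_dict value will fail on it.
```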
I see that the `_extra_state` key exists in the sharded state dict (inside the module), but once I run `get_fp32_state_dict_from_zero_checkpoint` it is not returned.
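This is roughly how I observe the problem (the checkpoint directory below is just a placeholder for the path I passed to `save_checkpoint`):

```python
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# "./checkpoint" is a placeholder for the directory used with model_engine.save_checkpoint()
state_dict = get_fp32_state_dict_from_zero_checkpoint("./checkpoint")

# The per-rank shards contain an '_extra_state' entry, but it is gone
# from the consolidated fp32 state dict:
print([k for k in state_dict if k.endswith("_extra_state")])  # -> []
```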
Is this intended behaviour? Can you suggest a quick fix for this?
Thank you!