
[BUG] Checkpoint combiner doesn't handle the nn.Module extra_state feature

Open · prabhuteja12 opened this issue 2 years ago · 0 comments

Describe the bug: PyTorch nn.Module supports extra state through the get_extra_state and set_extra_state methods, which let a module store additional state (not necessarily tensors) in its state_dict. The current implementation of get_fp32_state_dict_from_zero_checkpoint does not seem to handle this correctly.
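A minimal sketch of the mechanism, assuming a PyTorch version that supports get_extra_state/set_extra_state. ScaledLinear and Wrapper are hypothetical toy modules for illustration, not part of PyTorch or DeepSpeed. The sketch shows that the entry is stored under "<prefix>_extra_state", that its value need not be a tensor, and that a state dict missing the key leaves the setting unrestored on load.

```python
import torch
from torch import nn


class ScaledLinear(nn.Module):
    """Hypothetical toy module whose non-tensor setting travels via extra state."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.scale = scale  # plain Python float: neither a parameter nor a buffer

    def get_extra_state(self):
        # Saved under "<prefix>_extra_state"; any picklable object is allowed.
        return {"scale": self.scale}

    def set_extra_state(self, state):
        # Called by load_state_dict when the "_extra_state" key is present.
        self.scale = state["scale"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * self.scale


class Wrapper(nn.Module):
    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.block = ScaledLinear(scale)


src = Wrapper(scale=2.0)
sd = src.state_dict()
print(sorted(sd))  # ['block._extra_state', 'block.linear.bias', 'block.linear.weight']
print(torch.is_tensor(sd["block._extra_state"]))  # False: a dict has no .data_ptr()

# Simulate a combined checkpoint that dropped the extra-state entry.
dst = Wrapper(scale=1.0)
stripped = {k: v for k, v in sd.items() if not k.endswith("_extra_state")}
result = dst.load_state_dict(stripped, strict=False)
print(result.missing_keys)  # ['block._extra_state']
print(dst.block.scale)  # 1.0, the saved value 2.0 was not restored
```

With strict=True (the default), the same load raises an error reporting the missing key instead.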

  • The code expects every value in a state_dict to be a torch.Tensor, as evidenced by the call to .data_ptr() (see the parse_model_states function). The object returned by get_extra_state is not required to be a tensor.
  • Even if that were not a problem, the combined output does not contain the '_extra_state' key.
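One way a combiner could avoid the tensor-only assumption described in these points is to partition entries before doing any tensor work and carry the remainder through unchanged. The helper below is hypothetical, not DeepSpeed code. It approximates "tensor-like" by the presence of a data_ptr attribute, mirroring the call mentioned above; in real code, torch.is_tensor would be the more natural check.

```python
from typing import Any, Dict, Tuple


def split_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Partition state-dict entries into tensor-like values and everything else.

    Hypothetical helper: "tensor-like" means the value has a data_ptr attribute,
    the method the combiner is reported to call on each value.
    """
    tensors: Dict[str, Any] = {}
    others: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if hasattr(value, "data_ptr"):
            tensors[key] = value  # safe to call value.data_ptr() on these
        else:
            others[key] = value  # e.g. whatever get_extra_state() returned
    return tensors, others
```

A combiner built this way would run its consolidation logic on the first dict only, then update its output with the second so that '_extra_state' entries survive.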

The _extra_state key is present in the sharded state dict (inside the module), but it is missing from the state dict returned by get_fp32_state_dict_from_zero_checkpoint.

Is this intended behaviour? Can you suggest a quick fix for this?
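As a possible stopgap (a sketch under assumptions, not a confirmed fix), the missing entries could be copied back from one sharded model-states file after consolidation. This assumes every rank stores the same extra state and that the sharded file keeps it under its "module" entry, as observed above. reattach_extra_state and is_extra_state_key are hypothetical helpers.

```python
from typing import Any, Dict, List

# Key suffix nn.Module uses for values returned by get_extra_state().
EXTRA_STATE_SUFFIX = "_extra_state"


def is_extra_state_key(key: str) -> bool:
    """True for the root module's key or for '<prefix>._extra_state' keys."""
    return key == EXTRA_STATE_SUFFIX or key.endswith("." + EXTRA_STATE_SUFFIX)


def reattach_extra_state(
    fp32_state_dict: Dict[str, Any], module_state_dict: Dict[str, Any]
) -> List[str]:
    """Copy extra-state entries missing from a consolidated state dict (hypothetical helper).

    fp32_state_dict:   consolidated weights; modified in place.
    module_state_dict: the "module" entry of one sharded model-states file,
                       assumed to hold the same extra state on every rank.
    Returns the keys that were added, in the order they were found.
    """
    added: List[str] = []
    for key, value in module_state_dict.items():
        if is_extra_state_key(key) and key not in fp32_state_dict:
            fp32_state_dict[key] = value
            added.append(key)
    return added
```

In use, fp32_state_dict would come from get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) and module_state_dict from torch.load(path, map_location="cpu")["module"] on one sharded file (the exact file name is an assumption that depends on the checkpoint layout). This only helps when the consolidation call completes; it does not address a failure on .data_ptr().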

Thank you!

prabhuteja12, May 11 '23 09:05