Megatron-LM

[QUESTION] checkpointing/loading memory overhead

JinjieNi opened this issue 1 year ago • 4 comments

It seems that in the current implementation, torch_dist checkpointing and loading introduce around 2 GB of GPU memory overhead on rank 0 (for a 400M model), which causes OOM if the baseline GPU utilization is already high.

Is there a way to free this GPU memory after checkpointing/loading completes, so that the overhead does not affect the subsequent training process? Otherwise it is quite annoying, since it requires an extra test (and optimization, if OOM occurs) just for checkpointing and loading.

(I tried adding torch.cuda.empty_cache() to the checkpointing function, but it didn't help. The load_checkpoint function already calls torch.cuda.empty_cache() at the end.)

Thank you!

JinjieNi avatar Feb 06 '25 03:02 JinjieNi

+1

Ethan-yt avatar Mar 21 '25 07:03 Ethan-yt

I found that when the merge functions are applied, torch.cat() introduces extra memory. The tensors being concatenated already share the same underlying storage, so there is no need to concatenate them again.

Try this:

Change sh_ten_merge_fn (https://github.com/NVIDIA/Megatron-LM/blob/11996c9fd1a2d0aaef6dafc1fd4219aa795f188f/megatron/core/transformer/mlp.py#L261) to:

def memory_saving_sh_ten_merge_fn(sub_state_dict):
    with torch.no_grad():
        shared_storage = sub_state_dict[0].untyped_storage()
        # If every shard is a view into the same underlying storage, and the
        # shards together cover that storage exactly, rebuild the merged
        # tensor as a zero-copy view instead of materializing a torch.cat copy.
        if all(shared_storage.data_ptr() == tensor.untyped_storage().data_ptr() for tensor in sub_state_dict):
            element_size = sub_state_dict[0].element_size()
            total_numel = sum(tensor.numel() for tensor in sub_state_dict)
            if shared_storage.nbytes() == total_numel * element_size:
                dim_0 = sum(tensor.shape[0] for tensor in sub_state_dict)
                shape = (dim_0,) + sub_state_dict[0].shape[1:]
                combined_tensor = torch.empty(
                    shape, dtype=sub_state_dict[0].dtype, device=sub_state_dict[0].device
                ).set_(shared_storage, 0, shape)
                return combined_tensor
        # Fall back to an actual concatenation when the shards do not share storage.
        return torch.cat(sub_state_dict)
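To see why the view-based merge avoids the extra allocation, here is a stand-alone sketch (independent of Megatron-LM, runnable on CPU) showing that shards created as views of one storage can be merged back without copying:

```python
import torch

# Shards produced by torch.chunk() are views into one storage, so the
# "merged" tensor can be rebuilt as a zero-copy view instead of torch.cat.
full = torch.arange(12, dtype=torch.float32).reshape(6, 2)
shards = list(torch.chunk(full, 3, dim=0))

# All shards point at the same untyped storage ...
assert all(
    s.untyped_storage().data_ptr() == full.untyped_storage().data_ptr()
    for s in shards
)

# ... so a view over that storage reproduces torch.cat without a copy.
shape = (sum(s.shape[0] for s in shards),) + shards[0].shape[1:]
merged = torch.empty(0, dtype=shards[0].dtype).set_(
    shards[0].untyped_storage(), 0, shape
)

assert torch.equal(merged, torch.cat(shards))
# merged shares memory with the original tensor; a torch.cat result would not.
assert merged.data_ptr() == full.data_ptr()
```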

Ethan-yt avatar Mar 21 '25 08:03 Ethan-yt

@ko3n1g If this code works, could you please add me as a coauthor of Megatron-LM? Thanks!

Ethan-yt avatar Mar 21 '25 09:03 Ethan-yt

Marking as stale. No activity in 60 days.

github-actions[bot] avatar May 23 '25 18:05 github-actions[bot]

I have met the same problem, and Ethan-yt's code doesn't work for me. I found that the rank with the ~2 GB GPU memory overhead is the coordinator_rank of torch's _DistWrapper.reduce_scatter during saving and loading, and that the overhead appears after the torch.distributed.gather_object() call inside _DistWrapper.reduce_scatter. But I haven't found a solution yet.
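For anyone trying to reproduce where that allocation comes from, here is a minimal single-process sketch of torch.distributed.gather_object() (using the CPU gloo backend; the setup and object contents are illustrative, not Megatron's actual call site). The destination rank materializes one deserialized object per rank, which is why the coordinator rank pays an overhead the other ranks do not:

```python
import os
import torch.distributed as dist

# Single-process demo: initialize a world of size 1 on the gloo backend.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group("gloo", rank=0, world_size=1)

local = {"rank": dist.get_rank(), "plan": "checkpoint metadata"}
gathered = [None]  # one slot per rank, filled only on the destination rank
dist.gather_object(local, gathered, dst=0)
# The dst rank now holds every rank's deserialized object in `gathered`.
dist.destroy_process_group()
```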

dibaotian-xing avatar Jun 23 '25 13:06 dibaotian-xing

I also find that the memory overhead stays the same (~2 GB) across different model sizes.

dibaotian-xing avatar Jun 23 '25 15:06 dibaotian-xing

I found that the overhead disappears after setting 'hetereogenous_dist_checkpoint' in transformer_config to True.

dibaotian-xing avatar Jul 07 '25 13:07 dibaotian-xing