Exception when running inference with an MCore distributed checkpoint using a different TP setting than training
Describe the bug
I have an mcore distributed checkpoint trained with PP=1, TP=1. When I run inference with this checkpoint and set TP higher than 1, I get exceptions and intermittent hangs.
Running inference with the mcore distributed checkpoint and TP > 1 raises the following exception:
Traceback (most recent call last):
File "/workspace/src/3rdparty/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py", line 271, in main
model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer, **kwargs)
File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/models/nlp_model.py", line 397, in load_from_checkpoint
checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
File "/workspace/src/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 92, in load
validate_sharding_integrity(nested_values(sharded_state_dict))
File "/workspace/src/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 306, in validate_sharding_integrity
_validate_sharding_for_key(shardings)
File "/workspace/src/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 344, in _validate_sharding_for_key
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
megatron.core.dist_checkpointing.core.CheckpointingException: Invalid access pattern for ShardedTensor(key='model.embedding.word_embeddings.weight')
Error executing job with overrides: ['inference.greedy=True', 'inference.add_BOS=True', 'trainer.devices=8', 'trainer.num_nodes=1', 'tensor_model_parallel_size=8', 'pipeline_model_parallel_size=1']
Traceback (most recent call last):
File "/workspace/src/3rdparty/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py", line 308, in main
response = model.generate(
File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1414, in generate
return megatron_gpt_generate(
File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 127, in megatron_gpt_generate
output = generate(
File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 645, in generate
output = synced_generate(
File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 510, in synced_generate
for tokens, lengths, output_logits, full_logits in batch_token_iterator:
File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 888, in sample_sequence_batch
torch.distributed.broadcast(done, src, group)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 1947, in broadcast
work = group.broadcast([tensor], opts)
torch.distributed.DistNetworkError: Broken pipe
Steps/Code to reproduce bug
- Have an existing distributed mcore GPT checkpoint saved in a directory, trained with TP=1, PP=1
- Pass checkpoint_dir and checkpoint_name into megatron_gpt_eval.py
- [Optional] In my case, I enabled activation checkpointing, so I also had to pass kwargs["activations_checkpoint_granularity"] = None and kwargs["activations_checkpoint_method"] = None into MegatronGPTModel.load_from_checkpoint (see the sketch after this list)
- Run inference with TP > 1 on a device with multiple GPUs
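For reference, a rough sketch of the load call described above (the paths and the trainer object here are placeholders, not the exact eval script):

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

# Placeholder paths; in megatron_gpt_eval.py these come from the Hydra config.
checkpoint_path = "/path/to/dist_ckpt_dir"
hparams_file = "/path/to/hparams.yaml"

# Only needed because this checkpoint was trained with activation checkpointing enabled.
kwargs = {
    "activations_checkpoint_granularity": None,
    "activations_checkpoint_method": None,
}

# `trainer` is the PTL Trainer built by the eval script (omitted here).
model = MegatronGPTModel.load_from_checkpoint(
    checkpoint_path, hparams_file=hparams_file, trainer=trainer, **kwargs
)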
This results in the exception above every time; a small percentage of runs still manage to complete, but most of the time the process ends up hanging.
Expected behavior
With mcore distributed checkpointing, I expect to be able to load an mcore model with a different model-parallel config than training, without any error, using the example inference scripts.
Environment overview
- Environment location: AWS p4de.24xlarge (A100)
- Method of NeMo install: from source, using the NeMo r1.23.0 branch and dependencies
Environment details
- OS version - Ubuntu 22.04.3 LTS (pytorch 23.12 container)
- PyTorch version - 2.2.0a0+81ea7a4
- Python version - Python 3.10.12
Additional context
Attaching full logs: tp8_error.log
@dimapihtar @ericharper The same issue occurs when trying to load the distributed checkpoint for continued training / SFT.
Loading the distributed checkpoint on a single A100 works fine with gbs=1, tp=1, pp=1, mbs=1. When scaling to 8 GPUs and changing gbs to 8, loading the checkpoint fails.
Has the team been able to reproduce this internally? Currently using the following commits:
- Megatron-LM: ad53b1e38689a0ceed75ade7821f4e6c7554abb4
- NeMo: 9b64e390b534d4eb5ad7f28502bcfe4c7f0c6c39
- TransformerEngine: da30634a6c9ccdbb6c587b6c93b1860e4b038204
[0]: model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name)
[0]: File "/workspace/src/3rdparty/NeMo/build/__editable__.nemo_toolkit-1.23.0rc0-py3-none-any/nemo/collections/nlp/models/nlp_model.py", line 397, in load_from_checkpoint
[0]: checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
[0]: File "/workspace/src/3rdparty/Megatron-LM/build/__editable__.megatron_core-0.5.0rc0-cp310-cp310-linux_x86_64/megatron/core/dist_checkpointing/serialization.py", line 99, in load
[0]: validate_sharding_integrity(nested_values(sharded_state_dict))
[0]: File "/workspace/src/3rdparty/Megatron-LM/build/__editable__.megatron_core-0.5.0rc0-cp310-cp310-linux_x86_64/megatron/core/dist_checkpointing/serialization.py", line 346, in validate_sharding_integrity
[0]: _validate_sharding_for_key(shardings)
[0]: File "/workspace/src/3rdparty/Megatron-LM/build/__editable__.megatron_core-0.5.0rc0-cp310-cp310-linux_x86_64/megatron/core/dist_checkpointing/serialization.py", line 384, in _validate_sharding_for_key
[0]: raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
[0]:megatron.core.dist_checkpointing.core.CheckpointingException: Invalid access pattern for ShardedTensor(key='model.embedding.word_embeddings.weight')
Some debug logs as well, let me know if anything else could be useful to include:
> rank_sharding
[(0, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (1, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (2, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (3, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (4, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (5, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (6, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (7, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None))]
> torch.zeros(rank_sharding[0][1].axis_fragmentations)
tensor([[0.]])
The assertion checks that each shard is accessed exactly once across all ranks:
if not torch.all(shard_access_cnt == 1):
    logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
    raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
But this is the behavior I observe:
> shard_access_cnt
tensor([[8]], dtype=torch.int32)
> for rank, sharding in rank_sharding:
      print(sharding.replica_id)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
Notice that there is a TODO listed to check the shard_access_cnt of replicas as well: https://github.com/NVIDIA/Megatron-LM/blob/0fecd76e995c136021d478c6c52caa57c2f9aa25/megatron/core/dist_checkpointing/serialization.py#L444C1-L447C59
But there should only be one replica per rank, which is the one above.
Since I'm setting tensor parallel to 1 here and gbs to 8, the embedding weights should not be expected to be sharded: with tp=1 each rank holds the full tensor, and the 8 ranks form a single data-parallel group. I believe this dist ckpt was also trained with tp=1, so it should be fine for the embedding weights to be fully replicated across all of the ranks?
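To make that concrete, here is a small self-contained sketch of the check as I understand it from the debug output above (a paraphrase for illustration, not the actual Megatron-LM implementation):

import torch

def is_main_replica(replica_id):
    # A shard only counts toward shard_access_cnt when every replica_id coordinate is 0.
    return all(r == 0 for r in replica_id)

def shard_access_count(replica_ids, axis_fragmentations=(1, 1)):
    # One counter per shard position; with TP=1 the embedding is a single shard at (0, 0).
    cnt = torch.zeros(axis_fragmentations, dtype=torch.int32)
    for replica_id in replica_ids:
        if is_main_replica(replica_id):
            cnt[0, 0] += 1
    return cnt

# What I observe: all 8 ranks report replica_id == (0, 0, 0).
print(shard_access_count([(0, 0, 0)] * 8))                # tensor([[8]]) -> exception
# What I would expect with DP=8: replica_id == (0, 0, dp_rank).
print(shard_access_count([(0, 0, r) for r in range(8)]))  # tensor([[1]]) -> passes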
Off topic: I also notice that loading from the distributed checkpoint loads the saved optimizer states as well, which is probably expected behavior. Is there any reference or guidance on how to load without these saved optimizer states? Or does it not matter if they get overwritten later on?
I seem to have figured out the root cause.
Background
When loading from a distributed checkpoint, we call NLPModel.load_from_checkpoint(...) (ref).
For a distributed checkpoint, loading the state_dict is deferred until the class is initialized.
The current logic is as follows:
if 'cfg' in kwargs:
    model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
else:
    model = ptl_load_state(cls, checkpoint, strict=strict, cfg=cfg, **kwargs)
    # cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg

if checkpoint_dir is not None:
    sharded_state_dict = model.sharded_state_dict()
    checkpoint['state_dict'] = sharded_state_dict
    # dist checkpointing needs torch.distributed to load the checkpoint
    if parallel_state.is_unitialized():

        def dummy():
            return

        if model.trainer.strategy.launcher is not None:
            model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
        model.trainer.strategy.setup_environment()
For MegatronGPTModel (wrapping MCoreGPTModel), this instantiates LanguageModelEmbedding() if cfg.pre_process is true.
On the line sharded_state_dict = model.sharded_state_dict(), we call MegatronGPTModel.sharded_state_dict(...), which in turn calls VocabParallelEmbedding.sharded_state_dict(...):
def sharded_state_dict(
    self, prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = ()
) -> ShardedStateDict:
    """ Non-default implementation for embeddings due to `allow_shape_mismatch` param """
    state_dict = self.state_dict(prefix='', keep_vars=True)

    weight_prefix = f'{prefix}weight'
    return {
        weight_prefix: make_tp_sharded_tensor_for_checkpoint(
            tensor=state_dict['weight'],
            key=weight_prefix,
            allow_shape_mismatch=True,
            prepend_offsets=sharded_offsets,
        )
    }
make_tp_sharded_tensor_for_checkpoint:
def make_tp_sharded_tensor_for_checkpoint(
    tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
    prepend_axis_num = len(prepend_offsets)

    if replica_id is None:
        replica_id = (0, 0, parallel_state.get_data_parallel_rank(with_context_parallel=True))

    ...
The Issue
At this point, parallel_state is not yet available since torch.distributed.is_initialized() == False, so any data-parallel rank query returns 0, which incorrectly sets all replica_ids to (0, 0, 0). This in turn causes the assertion mentioned above to fail when validating the sharded tensors.
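A quick way to see this in isolation (assuming, as described above, that the parallel_state rank helpers fall back to 0 while torch.distributed is not initialized):

import torch
from megatron.core import parallel_state

# Run in a fresh process, before trainer.strategy.setup_environment() is called.
assert not torch.distributed.is_initialized()

# The DP rank falls back to 0 on every process at this point, so every rank
# builds the same "main" replica_id for the embedding weight.
dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True)
print((0, 0, dp_rank))  # (0, 0, 0) on all ranks -> shard_access_cnt ends up at 8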
Potential Resolution
It seems like an easy fix would be to move the initialization logic before the model.sharded_state_dict() call (a sketch of the reordering follows the snippet below), but I'm not informed enough to know what the downstream impact would be.
if checkpoint_dir is not None:
    sharded_state_dict = model.sharded_state_dict()
    checkpoint['state_dict'] = sharded_state_dict
    # dist checkpointing needs torch.distributed to load the checkpoint
    if parallel_state.is_unitialized():

        def dummy():
            return

        if model.trainer.strategy.launcher is not None:
            model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
        model.trainer.strategy.setup_environment()
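For clarity, a sketch of the reordering I have in mind; the statements are unchanged from the snippet above, only their order differs (untested):

if checkpoint_dir is not None:
    # Initialize torch.distributed / parallel_state first, so that the DP rank
    # used for replica_id in make_tp_sharded_tensor_for_checkpoint is correct.
    if parallel_state.is_unitialized():

        def dummy():
            return

        if model.trainer.strategy.launcher is not None:
            model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
        model.trainer.strategy.setup_environment()

    # Only then build the sharded state dict passed to dist_checkpointing.load().
    sharded_state_dict = model.sharded_state_dict()
    checkpoint['state_dict'] = sharded_state_dict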
Can you please let me know if this is a correct understanding?
@dimapihtar @ericharper
This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.
Any updates on this issue?
This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.
This issue was closed because it has been inactive for 7 days since being marked as stale.