
Exception running inference with MCore Distributed Checkpoint with different TP setting than training

ryxli opened this issue · 6 comments

Describe the bug

I have an mcore distributed checkpoint trained with PP=1, TP=1. When I run inference from this checkpoint with TP set higher than 1, I get exceptions and intermittent hangs.

When running inference from the mcore distributed checkpoint with TP > 1, the following exception is raised:

Traceback (most recent call last):
  File "/workspace/src/3rdparty/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py", line 271, in main
    model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer, **kwargs)
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/models/nlp_model.py", line 397, in load_from_checkpoint
    checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
  File "/workspace/src/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 92, in load
    validate_sharding_integrity(nested_values(sharded_state_dict))
  File "/workspace/src/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 306, in validate_sharding_integrity
    _validate_sharding_for_key(shardings)
  File "/workspace/src/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 344, in _validate_sharding_for_key
    raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
megatron.core.dist_checkpointing.core.CheckpointingException: Invalid access pattern for ShardedTensor(key='model.embedding.word_embeddings.weight')



Error executing job with overrides: ['inference.greedy=True', 'inference.add_BOS=True', 'trainer.devices=8', 'trainer.num_nodes=1', 'tensor_model_parallel_size=8', 'pipeline_model_parallel_size=1']
Traceback (most recent call last):
  File "/workspace/src/3rdparty/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py", line 308, in main
    response = model.generate(
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1414, in generate
    return megatron_gpt_generate(
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 127, in megatron_gpt_generate
    output = generate(
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 645, in generate
    output = synced_generate(
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 510, in synced_generate
    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
  File "/workspace/src/3rdparty/NeMo/nemo/collections/nlp/modules/common/text_generation_utils.py", line 888, in sample_sequence_batch
    torch.distributed.broadcast(done, src, group)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 1947, in broadcast
    work = group.broadcast([tensor], opts)
torch.distributed.DistNetworkError: Broken pipe

Steps/Code to reproduce bug


  1. Have an existing distributed mcore gpt checkpoint saved in a directory trained with TP=1, PP=1
  2. Pass checkpoint_dir, checkpoint_name into megatron_gpt_eval.py
  3. [Optional] In my case, I enabled activation checkpointing, so I also had to pass kwargs["activations_checkpoint_granularity"] = None and kwargs["activations_checkpoint_method"] = None into MegatronGPTModel.load_from_checkpoint (see the sketch after this list)
  4. Run inference with TP > 1 on a machine with multiple GPUs
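
For reference, roughly how the checkpoint gets loaded in the steps above (a simplified sketch with placeholder paths and a bare-bones trainer, not the exact megatron_gpt_eval.py code):

from pytorch_lightning import Trainer
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

# Placeholders: point these at the real distributed checkpoint directory and hparams file.
checkpoint_path = "/path/to/mcore_dist_ckpt"
hparams_file = "/path/to/hparams.yaml"

# Bare-bones trainer; megatron_gpt_eval.py builds this from the Hydra config
# (trainer.devices=8, tensor_model_parallel_size=8, pipeline_model_parallel_size=1).
trainer = Trainer(strategy=NLPDDPStrategy(), accelerator="gpu", devices=8, num_nodes=1)

kwargs = {
    # Step 3: training used activation checkpointing, so it is disabled here for inference.
    "activations_checkpoint_granularity": None,
    "activations_checkpoint_method": None,
}

model = MegatronGPTModel.load_from_checkpoint(
    checkpoint_path,
    hparams_file=hparams_file,
    trainer=trainer,
    **kwargs,
)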

The exception shows up every time; a small fraction of runs still manage to complete, but most of the time the process ends up hanging.

Expected behavior

With mcore distributed checkpointing, I expect to be able to load an mcore model with different model parallel configs without any error using the example scripts for inference.

Environment overview (please complete the following information)

  • Environment location: Cloud (AWS p4de.24xlarge, A100)
  • Method of NeMo install: from source, using the NeMo r1.23.0 branch and its dependencies

Environment details

  • OS version - Ubuntu 22.04.3 LTS (pytorch 23.12 container)
  • PyTorch version - 2.2.0a0+81ea7a4
  • Python version - Python 3.10.12

Additional context

Attaching full logs: tp8_error.log

ryxli, Feb 20 '24

@dimapihtar @ericharper The same issue occurs when trying to load the distributed checkpoint for continued training / SFT.

Loading the distributed checkpoint on a single A100 works fine with gbs=1, tp=1, pp=1, mbs=1. When scaling to 8 GPUs and changing gbs to 8, loading the checkpoint fails.

Has the team been able to reproduce this internally? Currently using the following commits:

  • Megatron-LM: ad53b1e38689a0ceed75ade7821f4e6c7554abb4
  • NeMo: 9b64e390b534d4eb5ad7f28502bcfe4c7f0c6c39
  • TransformerEngine: da30634a6c9ccdbb6c587b6c93b1860e4b038204

[0]:    model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name)
[0]:  File "/workspace/src/3rdparty/NeMo/build/__editable__.nemo_toolkit-1.23.0rc0-py3-none-any/nemo/collections/nlp/models/nlp_model.py", line 397, in load_from_checkpoint
[0]:    checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
[0]:  File "/workspace/src/3rdparty/Megatron-LM/build/__editable__.megatron_core-0.5.0rc0-cp310-cp310-linux_x86_64/megatron/core/dist_checkpointing/serialization.py", line 99, in load
[0]:    validate_sharding_integrity(nested_values(sharded_state_dict))
[0]:  File "/workspace/src/3rdparty/Megatron-LM/build/__editable__.megatron_core-0.5.0rc0-cp310-cp310-linux_x86_64/megatron/core/dist_checkpointing/serialization.py", line 346, in validate_sharding_integrity
[0]:    _validate_sharding_for_key(shardings)
[0]:  File "/workspace/src/3rdparty/Megatron-LM/build/__editable__.megatron_core-0.5.0rc0-cp310-cp310-linux_x86_64/megatron/core/dist_checkpointing/serialization.py", line 384, in _validate_sharding_for_key
[0]:    raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
[0]:megatron.core.dist_checkpointing.core.CheckpointingException: Invalid access pattern for ShardedTensor(key='model.embedding.word_embeddings.weight')

ryxli, Mar 20 '24

Some debug logs as well, let me know if anything else could be useful to include:

> rank_sharding

[(0, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (1, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (2, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (3, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (4, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (5, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (6, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None)), (7, ShardedTensor(key='model.embedding.word_embeddings.weight', data=None, dtype=torch.float32, local_shape=(50304, 768), global_shape=(50304, 768), global_offset=(0, 0), axis_fragmentations=(1, 1), replica_id=(0, 0, 0), prepend_axis_num=0, allow_shape_mismatch=True, flattened_range=None))]
> torch.zeros(rank_sharding[0][1].axis_fragmentations)
tensor([[0.]])

The assertion checks that each shard is accessed exactly once across all ranks:

        if not torch.all(shard_access_cnt == 1):
            logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
            raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')

But this is the behavior I observe:

> shard_access_cnt
tensor([[8]], dtype=torch.int32)

> for rank, sharding in rank_sharding:
    print(sharding.replica_id)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)
(0, 0, 0)

Notice that there is a TODO listed to check the shard_access_cnt of replicas as well: https://github.com/NVIDIA/Megatron-LM/blob/0fecd76e995c136021d478c6c52caa57c2f9aa25/megatron/core/dist_checkpointing/serialization.py#L444C1-L447C59

But there should only be one replica per rank, which is the one above.

Since I'm setting tensor parallel to 1 here and gbs to 8, I don't think the embedding weights should be sharded: with tp=1 each TP group has size 1 and the 8 ranks make up the data-parallel dimension. I believe this dist ckpt was also trained with tp=1, so it should be fine for the embedding weights to be fully replicated across all of the ranks?
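
To sanity-check the arithmetic, here is a small standalone sketch (my own, not Megatron's implementation) of the main-replica counting described above for tp=1, dp=8:

import torch

def count_main_replica_accesses(replica_ids):
    # axis_fragmentations == (1, 1), so the tensor is a single 1x1 "shard grid".
    shard_access_cnt = torch.zeros((1, 1), dtype=torch.int32)
    for replica_id in replica_ids:
        # Only the "main" replica (replica_id all zeros) should be counted.
        if all(r == 0 for r in replica_id):
            shard_access_cnt[0, 0] += 1
    return shard_access_cnt

# What I observe: every rank reports (0, 0, 0), so the count is 8 and the check fails.
print(count_main_replica_accesses([(0, 0, 0)] * 8))                # tensor([[8]], dtype=torch.int32)

# What I'd expect with the dp rank encoded in replica_id: exactly one main replica.
print(count_main_replica_accesses([(0, 0, r) for r in range(8)]))  # tensor([[1]], dtype=torch.int32)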

Off topic: I also notice that loading from the distributed checkpoint loads the saved optimizer states as well, which is probably expected behavior. Is there any reference or guidance on how to load without the saved optimizer states, or does it not matter if they get overwritten later on?

ryxli, Mar 20 '24

I seem to have figured out the root cause.

Background

When loading from a distributed checkpoint, we call NLPModel.load_from_checkpoint(...) (ref). For a distributed checkpoint, loading the state_dict is deferred until after the model class has been instantiated.

The current logic is as follows:

            if 'cfg' in kwargs:
                model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
            else:
                model = ptl_load_state(cls, checkpoint, strict=strict, cfg=cfg, **kwargs)
                # cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg

            if checkpoint_dir is not None:
                sharded_state_dict = model.sharded_state_dict()
                checkpoint['state_dict'] = sharded_state_dict
                # dist checkpointing needs torch.distributed to load the checkpoint
                if parallel_state.is_unitialized():

                    def dummy():
                        return

                    if model.trainer.strategy.launcher is not None:
                        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                    model.trainer.strategy.setup_environment()

For MegatronGPTModel (MCoreGPTModel), this instantiates LanguageModelEmbedding() when cfg.pre_process is true.

On the line sharded_state_dict = model.sharded_state_dict(), we call MegatronGPTModel.sharded_state_dict(...), which in turn calls VocabParallelEmbedding.sharded_state_dict(...):

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = ()
    ) -> ShardedStateDict:
        """ Non-default implementation for embeddings due to `allow_shape_mismatch` param """
        state_dict = self.state_dict(prefix='', keep_vars=True)

        weight_prefix = f'{prefix}weight'
        return {
            weight_prefix: make_tp_sharded_tensor_for_checkpoint(
                tensor=state_dict['weight'],
                key=weight_prefix,
                allow_shape_mismatch=True,
                prepend_offsets=sharded_offsets,
            )
        }

make_tp_sharded_tensor_for_checkpoint:

def make_tp_sharded_tensor_for_checkpoint(
    tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
    prepend_axis_num = len(prepend_offsets)
    if replica_id is None:
        replica_id = (0, 0, parallel_state.get_data_parallel_rank(with_context_parallel=True))
    ...

The Issue

At this point parallel_state is not yet usable because torch.distributed.is_initialized() == False, so the data-parallel rank comes back as 0 on every process. That incorrectly sets every replica_id to (0, 0, 0), which in turn causes the assertion mentioned above to fail when validating the sharded tensors.
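
To make the failure mode concrete, here is a standalone illustration (my own, not Megatron's code); it assumes the dp-rank helper falls back to 0 whenever torch.distributed is not initialized, which is the behavior I observe:

def dp_rank_sketch(dist_initialized: bool, real_dp_rank: int) -> int:
    # With no initialized process group there is nothing to query,
    # so every process ends up with the fallback value 0.
    return real_dp_rank if dist_initialized else 0

world_size = 8  # tp=1, pp=1, dp=8

# Before trainer.strategy.setup_environment(): every rank builds (0, 0, 0).
print([(0, 0, dp_rank_sketch(False, r)) for r in range(world_size)])

# After torch.distributed is initialized: one distinct replica_id per dp rank.
print([(0, 0, dp_rank_sketch(True, r)) for r in range(world_size)])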

Potential Resolution

It seems like an easy fix would be to move that initialization logic before the model.sharded_state_dict() call, but I'm not informed enough to know what the downstream impact would be. For reference, the block in question again:

            if checkpoint_dir is not None:
                sharded_state_dict = model.sharded_state_dict()
                checkpoint['state_dict'] = sharded_state_dict
                # dist checkpointing needs torch.distributed to load the checkpoint
                if parallel_state.is_unitialized():

                    def dummy():
                        return

                    if model.trainer.strategy.launcher is not None:
                        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                    model.trainer.strategy.setup_environment()
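
Concretely, the reordering I have in mind would look something like the sketch below (purely illustrative: the same block with the torch.distributed setup hoisted above the sharded_state_dict() call, not a tested patch):

            if checkpoint_dir is not None:
                # dist checkpointing needs torch.distributed to load the checkpoint,
                # so set it up before building the sharded state dict; otherwise the
                # replica_ids are computed with a data-parallel rank of 0 everywhere
                if parallel_state.is_unitialized():

                    def dummy():
                        return

                    if model.trainer.strategy.launcher is not None:
                        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                    model.trainer.strategy.setup_environment()

                sharded_state_dict = model.sharded_state_dict()
                checkpoint['state_dict'] = sharded_state_dict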

Can you please let me know if this is a correct understanding?

@dimapihtar @ericharper

ryxli, Mar 20 '24

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions[bot], Apr 20 '24

any updates on this issue?

ryxli, Apr 20 '24

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions[bot], May 21 '24

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot], May 28 '24