[BUG]: HybridParallelOptimizer holds unsharded model parameters after sharding

Open insujang opened this issue 1 year ago • 9 comments

🐛 Describe the bug

When using tensor parallelism, model parameters are sharded across GPUs to reduce memory consumption and enable parallel execution. However, the optimizer still holds the unsharded model parameters, which prevents the old unsharded parameters from being released and consumes extra memory.

Example code (adapted from examples/language/gpt2/hybridparallelism/finetune.py):

colossalai.launch_from_torch(config={})
plugin = HybridParallelPlugin(tp_size=4, pp_size=1)
optimizer = Adam(model.parameters())
# initialize dataloader

booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer, ...)
> model.module.transformer.wte.weight
Parameter containing:
tensor([[-0.1101, -0.0393, ...]], device='cuda:0', dtype=torch.float16, requires_grad=True)

> model.module.transformer.wte.weight.shape
torch.Size([12565, 768])

> optimizer.param_groups[0]["params"][0]
Parameter containing:
tensor([[-0.1101, -0.0393, ...]], device='cuda:0', requires_grad=True)

> optimizer.param_groups[0]["params"[0].shape
torch.Size([50257, 768])
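(For reference: with tp_size=4, the vocab size 50257 is padded to 50260, the next multiple of 4, and split across ranks, so each shard holds 50260 / 4 = 12565 rows. That matches the sharded shape above, while the optimizer still references the full, unpadded 50257-row weight.)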

This also affects MixedPrecisionOptimizer.master_to_working_map and MixedPrecisionOptimizer.working_to_master_map:

# model.module.transformer.wte.weight is supposed to be a working parameter
> model.module.transformer.wte.weight.shape
torch.Size([12565, 768])
> id(model.module.transformer.wte.weight)
139684649437120

# The first working parameter in the map does not refer to it
> list(iter(optimizer.master_to_working_map))[0].shape
torch.Size([50257, 768])
> id(list(iter(optimizer.master_to_working_map))[0])
139693862695728

Because of this, it seems only a portion of the parameters (i.e. the unsharded ones) are trained, as MixedPrecisionOptimizer.step() skips the sharded parameters: their gradients are never stored in the mismatched unsharded parameters held by the optimizer:

https://github.com/hpcaitech/ColossalAI/blob/df5e9c53cf23d44656470cc319ee0b470c40712f/colossalai/amp/naive_amp/mixed_precision_optimizer.py#L173-L175
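A quick way to see how many parameters are affected (a hedged diagnostic sketch, assuming optimizer.master_to_working_map behaves like a dict mapping fp32 master params to fp16 working params):

# Hedged diagnostic: count working params held by the optimizer that are no
# longer parameters of the sharded model.
live_ids = {id(p) for p in model.module.parameters()}
stale = [w for w in optimizer.master_to_working_map.values() if id(w) not in live_ids]
print(f"{len(stale)} working params in the optimizer are not parameters of the sharded model")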

Environment

PyTorch 2.2.1 / CUDA 12.1

insujang avatar Mar 31 '24 02:03 insujang

Hi, thanks for the issue. I reproduced the bug using this script: finetune.zip. This might be due to some unexpected model movement without ZeRO; mostly ZeRO is used and the params are sharded in place. I'm looking into this.

Edenzzzz avatar Apr 01 '24 08:04 Edenzzzz

This happens only when sequence parallelism is on and ZeRO is off. We are rebuilding the sequence parallel API with ring attention etc., so as a quick fix I've set it to False in enable_all_optimization.

Edenzzzz avatar Apr 03 '24 03:04 Edenzzzz

@Edenzzzz, thank you for your time looking into this issue. I am not sure the fix works: I tested with enable_all_optimization=False, enable_sequence_parallelism=False, and enable_sequence_overlap=False, and the same problem still happens on my side. Could you check again?

Edit: this is my plugin configuration used:

plugin = HybridParallelPlugin(
            tp_size=4,
            pp_size=1,
            num_microbatches=None,
            microbatch_size=1,
            enable_all_optimization=False,
            enable_sequence_parallelism=False,
            enable_sequence_overlap=False,
            zero_stage=0,
            precision="fp16",
            initial_scale=1,
        )

insujang avatar Apr 03 '24 14:04 insujang

This bug seems specific to a minority of TP plans. Will take another look.

Edenzzzz avatar Apr 06 '24 14:04 Edenzzzz

Looks like preprocess in each policy might be the reason:

https://github.com/hpcaitech/ColossalAI/blob/341263df48bbef1174c41b6c4f5f6785f895b0d4/colossalai/shardformer/policies/bert.py#L39-L51

https://github.com/hpcaitech/ColossalAI/blob/341263df48bbef1174c41b6c4f5f6785f895b0d4/colossalai/shardformer/policies/gpt2.py#L32-L43

Although all policies share the same resize logic, each model has a different default vocab size, so only bert and gpt2 in your tests need their embeddings resized; the resize creates a new embedding and fails:

from transformers import AutoConfig

def test_vocab_size_divisible_to_tp_size(model_name: str, tp_size: int):
    config = AutoConfig.from_pretrained(model_name)
    vocab_size = config.vocab_size

    print(f"model {model_name} vocab_size: {vocab_size}. Need to resize embeddings for tp degree {tp_size}? {vocab_size % tp_size != 0}")

test_vocab_size_divisible_to_tp_size("gpt2", 8)
test_vocab_size_divisible_to_tp_size("bert-base-uncased", 8)
test_vocab_size_divisible_to_tp_size("facebook/opt-125m", 8)
test_vocab_size_divisible_to_tp_size("tiiuae/falcon-rw-1b", 8)
model gpt2 vocab_size: 50257. Need to resize embeddings for tp degree 8? True
model bert-base-uncased vocab_size: 30522. Need to resize embeddings for tp degree 8? True
model facebook/opt-125m vocab_size: 50272. Need to resize embeddings for tp degree 8? False
model tiiuae/falcon-rw-1b vocab_size: 50304. Need to resize embeddings for tp degree 8? False

It creates a completely new nn.Embedding, so the weight's ID changes: https://github.com/huggingface/transformers/blob/76fa17c1663a0efeca7208c20579833365584889/src/transformers/modeling_utils.py#L2017-L2028

# Before calling `preprocess()` on gpt2:
id(model.transformer.wte.weight)
140670116084640
model.transformer.wte
Embedding(50257, 768)

# After calling `preprocess()` on gpt2:
id(model.transformer.wte.weight)
140670118343072
model.transformer.wte
Embedding(50260, 768)
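A minimal standalone illustration of why this breaks the optimizer (plain PyTorch, independent of ColossalAI): an optimizer created before the embedding is replaced keeps referencing the stale weight.

import torch
import torch.nn as nn

holder = nn.Module()
holder.wte = nn.Embedding(50257, 768)
opt = torch.optim.Adam(holder.parameters())

old_weight = holder.wte.weight
# resize_token_embeddings effectively swaps in a brand-new nn.Embedding:
holder.wte = nn.Embedding(50260, 768)

assert opt.param_groups[0]["params"][0] is old_weight            # optimizer still holds the old weight
assert opt.param_groups[0]["params"][0] is not holder.wte.weight  # the new weight is never trained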

insujang avatar Apr 08 '24 03:04 insujang

A quick potential patch is to avoid HF's resize_token_embeddings and instead use nn.functional.pad to resize the weight tensor in place, avoiding recreation of the nn.Embedding (not sure whether other attributes should also be modified):

import torch.nn as nn

def resize_token_embedding_inplace(new_num_tokens: int, embedding: nn.Embedding):
    # In-place resize of the token embedding: keep the same nn.Embedding and
    # Parameter objects and only pad the underlying weight tensor with zero rows.
    embedding.num_embeddings = new_num_tokens
    embedding.weight.data = nn.functional.pad(
        embedding.weight.data,
        (0, 0, 0, new_num_tokens - embedding.weight.size(0)),
        "constant",
        0,
    )
 
# In policy
def preprocess(self):
    # reshape the embedding layer
    r"""
    Reshape the Embedding layer to make the embedding dimension divisible by world_size
    """
    if self.shard_config.enable_tensor_parallelism:
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size

            resize_token_embedding_inplace(new_vocab_size, self.model.get_input_embeddings())
            # self.model.resize_token_embeddings(new_vocab_size)

    return self.model
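A quick sanity check for the in-place resize (a hypothetical test in plain PyTorch): the Parameter object is kept and only its underlying storage grows, so references held by an already-constructed optimizer stay valid.

emb = nn.Embedding(50257, 768)
weight_id = id(emb.weight)
resize_token_embedding_inplace(50260, emb)

assert id(emb.weight) == weight_id        # same Parameter object as before
assert emb.weight.shape == (50260, 768)   # weight padded with three zero rows
assert emb.num_embeddings == 50260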

@Edenzzzz Could you please check if it works? Thanks

insujang avatar Apr 08 '24 04:04 insujang

Maybe it is related to #5489 ?

insujang avatar Apr 08 '24 13:04 insujang

[Quoting insujang's proposed resize_token_embedding_inplace patch above.]

Thanks for the nice catch! This worked for both gpt2 and bert. Yes, a fix appears to be in progress; I will touch base with them tomorrow.

Edenzzzz avatar Apr 08 '24 13:04 Edenzzzz

Nice to hear that the fix will be merged very soon. Thank you!

insujang avatar Apr 08 '24 13:04 insujang