
[BUG]: a bug in train.py (https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/sequence_parallel/train.py)

Open lambda7xx opened this issue 3 years ago • 6 comments

🐛 Describe the bug

I changed my config as below, and it runs on 1 GPU with a global batch size of 1:

```python
from colossalai.amp import AMP_TYPE

DATA_PATH = '/data/v-xxshi/coloss/raw_data/my-bert_text_sentence'
VOCAB_FILE_PATH = '/data/v-xxshi/coloss/raw_data/vocab/bert-large-uncased-vocab.txt'

# hyper-parameters
TRAIN_ITERS = 3
DECAY_ITERS = 20
WARMUP_FRACTION = 0.01
GLOBAL_BATCH_SIZE = 1   # dp world size * sentences per GPU
EVAL_ITERS = 3
EVAL_INTERVAL = 3
LR = 0.0001
MIN_LR = 1e-05
WEIGHT_DECAY = 0.01
SEQ_LENGTH = 2048

# BERT config
DEPTH = 24
NUM_ATTENTION_HEADS = 16
HIDDEN_SIZE = 1024

# model config
ADD_BINARY_HEAD = False

# random seed
SEED = 1234

# pipeline config
# only enabled when pipeline > 1
NUM_MICRO_BATCHES = 1

# colossalai config
parallel = dict(pipeline=1, tensor=dict(size=1, mode='sequence'))

fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True)

gradient_handler = [dict(type='SequenceParallelGradientHandler')]
```

Environment

No response

lambda7xx avatar Nov 16 '22 10:11 lambda7xx

Then I changed get_batch_for_sequence_parallel to generate synthetic data:

```python
import torch

# ColossalAI (legacy API) global context and parallel modes
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


def get_batch_for_sequence_parallel(data_iterator):
    """Build a synthetic batch; data_iterator is unused."""
    global_rank = torch.distributed.get_rank()
    local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) \
        else gpc.get_world_size(ParallelMode.TENSOR)
    seq_length = gpc.config.SEQ_LENGTH
    dp_size = gpc.get_world_size(ParallelMode.DATA)
    global_batch_size = gpc.config.GLOBAL_BATCH_SIZE
    micro_batch_size = global_batch_size // dp_size
    print(f"dp_size: {dp_size}, micro_batch_size: {micro_batch_size}")

    # each sequence-parallel rank keeps only its own slice of the sequence
    local_rank = global_rank % local_world_size
    sub_seq_length = seq_length // local_world_size
    sub_seq_start = local_rank * sub_seq_length
    sub_seq_end = (local_rank + 1) * sub_seq_length

    # random token ids in [0, 100) as synthetic data
    data = torch.randint(0, 100, (micro_batch_size, seq_length),
                         device=torch.cuda.current_device())
    tokens = data[:, sub_seq_start:sub_seq_end].long()
    types = data[:, sub_seq_start:sub_seq_end].long()
    mask = torch.ones((micro_batch_size, seq_length), device=torch.cuda.current_device())
    loss_mask = mask[:, sub_seq_start:sub_seq_end].float()
    # random 0/1 padding mask over the full sequence length
    padding_mask = torch.rand((micro_batch_size, seq_length),
                              device=torch.cuda.current_device())
    padding_mask = torch.where(padding_mask < 0.5, 0, 1)
    lm_labels = data[:, sub_seq_start:sub_seq_end].long()
    sentence_order = torch.ones(micro_batch_size).long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
```
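
For reference, a minimal shape check of what this synthetic loader returns might look like the sketch below. It assumes ColossalAI has already been launched so that gpc is initialized, SEQ_LENGTH = 2048, and a sequence (tensor) parallel size of 1, so sub_seq_length equals seq_length; data_iterator is unused, as in the snippet above.

```python
# Illustrative only: inspect the shapes of the synthetic batch.
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = \
    get_batch_for_sequence_parallel(None)
print(tokens.shape)        # (micro_batch_size, sub_seq_length)
print(padding_mask.shape)  # (micro_batch_size, seq_length) -- kept at full length
```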

lambda7xx avatar Nov 16 '22 10:11 lambda7xx

I run this code on 1 GPU. When GLOBAL_BATCH_SIZE = 1, it runs fine. But when I set GLOBAL_BATCH_SIZE to 32, 64, 128 or larger, it runs out of memory (OOM).

So, suppose I want to run with GLOBAL_BATCH_SIZE = 512 and micro_batch_size = 1, using SP + PP + DP. How should I set the config when running on 1/4/8/16/32 GPUs?
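
As an illustration only (not an answer from the maintainers), one possible layout for 8 GPUs might look like the sketch below. It assumes the relation global batch size = data parallel size * num_micro_batches * micro batch size stated later in the thread: with pipeline = 2 and sequence (tensor) size = 2, the data-parallel size is 8 // (2 * 2) = 2, so NUM_MICRO_BATCHES would have to be 512 // (2 * 1) = 256.

```python
# Hypothetical config fragment for 8 GPUs, GLOBAL_BATCH_SIZE = 512, micro batch size = 1.
GLOBAL_BATCH_SIZE = 512
NUM_MICRO_BATCHES = 256   # = GLOBAL_BATCH_SIZE // (dp_size * micro_batch_size) = 512 // (2 * 1)

parallel = dict(
    pipeline=2,                            # pipeline parallel size
    tensor=dict(size=2, mode='sequence'),  # sequence parallel size
)
```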

lambda7xx avatar Nov 16 '22 10:11 lambda7xx

This is our understanding of the config.

Without pipeline parallelism:
The meaning of global batch size: the total number of samples that all GPUs process in one forward pass.
The meaning of micro batch size: the number of samples that each GPU processes in one forward pass.

With pipeline parallelism:
The meaning of global batch size: the total number of samples that all GPUs process over num_microbatch forward passes.
The meaning of micro batch size: the number of samples that each GPU processes over num_microbatch forward passes.

lambda7xx avatar Nov 16 '22 10:11 lambda7xx

@binmakeswell @feifeibear

lambda7xx avatar Nov 16 '22 10:11 lambda7xx

This issue has been stale for a long time. Global batch size = data parallel size * num_micro_batch * micro_batch_size.
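
As a quick numeric check of this relation (the numbers below are illustrative, not taken from the thread):

```python
# global batch size = data parallel size * num_micro_batches * micro batch size
dp_size, num_micro_batches, micro_batch_size = 2, 256, 1
global_batch_size = dp_size * num_micro_batches * micro_batch_size
assert global_batch_size == 512
```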

FrankLeeeee avatar Nov 28 '22 02:11 FrankLeeeee

> This issue has been stale for a long time. Global batch size = data parallel size * num_micro_batch * micro_batch_size.

I get it. By the way, is there any mistake in my dataloader?


```python
def get_batch_for_sequence_parallel(data_iterator):
    """Build a synthetic batch; data_iterator is unused."""
    global_rank = torch.distributed.get_rank()
    local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) \
        else gpc.get_world_size(ParallelMode.TENSOR)
    seq_length = gpc.config.SEQ_LENGTH
    dp_size = gpc.get_world_size(ParallelMode.DATA)
    global_batch_size = gpc.config.GLOBAL_BATCH_SIZE
    micro_batch_size = global_batch_size // dp_size
    print(f"dp_size: {dp_size}, micro_batch_size: {micro_batch_size}")

    # each sequence-parallel rank keeps only its own slice of the sequence
    local_rank = global_rank % local_world_size
    sub_seq_length = seq_length // local_world_size
    sub_seq_start = local_rank * sub_seq_length
    sub_seq_end = (local_rank + 1) * sub_seq_length

    # random token ids in [0, 100) as synthetic data
    data = torch.randint(0, 100, (micro_batch_size, seq_length),
                         device=torch.cuda.current_device())
    tokens = data[:, sub_seq_start:sub_seq_end].long()
    types = data[:, sub_seq_start:sub_seq_end].long()
    mask = torch.ones((micro_batch_size, seq_length), device=torch.cuda.current_device())
    loss_mask = mask[:, sub_seq_start:sub_seq_end].float()
    # random 0/1 padding mask over the full sequence length
    padding_mask = torch.rand((micro_batch_size, seq_length),
                              device=torch.cuda.current_device())
    padding_mask = torch.where(padding_mask < 0.5, 0, 1)
    lm_labels = data[:, sub_seq_start:sub_seq_end].long()
    sentence_order = torch.ones(micro_batch_size).long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
```

lambda7xx avatar Nov 28 '22 05:11 lambda7xx

We have updated a lot. This issue was closed due to inactivity. Thanks.

binmakeswell avatar Apr 13 '23 10:04 binmakeswell