
How to pretrain llama3 with a 128k sequence length on 8*A100?

Open · 1518630367 opened this issue 1 year ago · 2 comments

The chart in the README suggests this can be trained, but I keep hitting OOM. My config is below:

1518630367 · May 13, 2024

```python
import torch
from datasets import load_dataset
from mmengine.config import read_base
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune

with read_base():
    from .map_fn import pretrain_map_fn as dataset_map_fn

pretrained_model_name_or_path = '/opt/218/models/Meta-Llama-3-8B-Instruct-continue_pre'

data_path = './train_128k_1000.jsonl'
max_length = 128000
pack_to_max_length = True

batch_size = 1  # per_device
accumulative_counts = 1
dataloader_num_workers = 10
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 1
max_norm = 1  # grad clip

save_steps = 100
save_total_limit = 3  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 1000
SYSTEM = ''
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

train_dataset = dict(
    type=process_hf_dataset,
    dataset=dict(
        type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=dataset_map_fn,
    template_map_fn=None,
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=0.0,
    by_epoch=True,
    end=max_epochs,
    convert_to_iter_based=True)

train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

custom_hooks = [dict(type=DatasetInfoHook, tokenizer=tokenizer)]

default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    sampler_seed=dict(type=DistSamplerSeedHook),
)

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

visualizer = None
log_level = 'INFO'
load_from = None
resume = False
randomness = dict(seed=None, deterministic=False)
log_processor = dict(by_epoch=False)
```

1518630367 · May 13, 2024
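For context on why this config OOMs: even with a 4-bit base model and LoRA, a full 128k-token sequence on one GPU is dominated by activation memory. A rough sketch of the two largest terms, using the published Llama-3-8B shapes (hidden 4096, 32 layers, vocab 128256); this is a simplification that ignores the recomputed per-layer activations, optimizer state, and CUDA overhead:

```python
# Back-of-envelope activation memory for Llama-3-8B at 128k tokens in fp16.
seq, hidden, layers, vocab = 128_000, 4096, 32, 128_256
bytes_fp16 = 2

# With gradient checkpointing, one hidden-state tensor is kept per layer.
checkpointed_gib = seq * hidden * bytes_fp16 * layers / 2**30
# The output logits alone, before any fp32 upcast in the loss.
logits_gib = seq * vocab * bytes_fp16 / 2**30

print(f"checkpointed hidden states ~= {checkpointed_gib:.1f} GiB")  # ~31.2 GiB
print(f"output logits              ~= {logits_gib:.1f} GiB")        # ~30.6 GiB
```

Both terms are paid on every GPU under plain data parallelism or ZeRO, because each rank still processes the full 128k sequence; adding more GPUs does not shrink them.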

Training extremely long sequences requires sequence parallelism:

https://xtuner.readthedocs.io/zh-cn/docs/acceleration/train_extreme_long_sequence.html#id7
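The recipe in that doc amounts to sharding each sequence across GPUs instead of giving every rank the full 128k tokens. A minimal sketch of the diff against the config above, with names taken from the xtuner sequence-parallel docs as I recall them (verify against your installed xtuner version; flash_attn must be installed):

```python
from xtuner.parallel.sequence import SequenceParallelSampler

# Shard every 128k sequence across all 8 GPUs, cutting per-GPU
# activation memory (hidden states, attention, logits) by roughly 8x.
sequence_parallel_size = 8

# The 8 ranks now jointly process ONE sample per step instead of 8,
# so scale gradient accumulation to keep the global batch size.
accumulative_counts = 1 * sequence_parallel_size

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    # Ranks in the same sequence-parallel group must draw the same sample,
    # so DefaultSampler is replaced by SequenceParallelSampler.
    sampler=dict(type=SequenceParallelSampler, seed=1024),
    collate_fn=dict(type=default_collate_fn))
```

As far as I know, sequence parallelism in xtuner is used together with a DeepSpeed launch, e.g. `NPROC_PER_NODE=8 xtuner train <your_config>.py --deepspeed deepspeed_zero2` (placeholder config name).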

pppppM · May 14, 2024