How can I pretrain llama3 with a 128k sequence length on 8*A100?
According to the chart in the README this should be trainable, but I keep hitting OOM. My config:
import torch
from datasets import load_dataset
from mmengine.config import read_base
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune

with read_base():
    from .map_fn import pretrain_map_fn as dataset_map_fn
pretrained_model_name_or_path = '/opt/218/models/Meta-Llama-3-8B-Instruct-continue_pre'

data_path = './train_128k_1000.jsonl'
max_length = 128000
pack_to_max_length = True

batch_size = 1  # per_device
accumulative_counts = 1
dataloader_num_workers = 10
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 1
max_norm = 1  # grad clip

save_steps = 100
save_total_limit = 3  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 1000
SYSTEM = ''
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

train_dataset = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=dataset_map_fn,
    template_map_fn=None,
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=0.0,
    by_epoch=True,
    end=max_epochs,
    convert_to_iter_based=True)
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
custom_hooks = [dict(type=DatasetInfoHook, tokenizer=tokenizer)]
default_hooks = dict(
    timer=dict(type=IterTimerHook),
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    param_scheduler=dict(type=ParamSchedulerHook),
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    sampler_seed=dict(type=DistSamplerSeedHook),
)

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)
visualizer = None
log_level = 'INFO'
load_from = None
resume = False
randomness = dict(seed=None, deterministic=False)
log_processor = dict(by_epoch=False)
Training sequences this long requires sequence parallelism; see:
https://xtuner.readthedocs.io/zh-cn/docs/acceleration/train_extreme_long_sequence.html#id7
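Concretely, against the config above this means setting a sequence_parallel_size, scaling accumulative_counts by it, and swapping the dataloader sampler for SequenceParallelSampler. A minimal sketch, assuming a recent XTuner release that ships xtuner.parallel.sequence (names taken from the linked doc; verify the exact API against your installed version):

from xtuner.parallel.sequence import SequenceParallelSampler

# Assumed setting: shard every 128k-token sequence across all 8 GPUs.
sequence_parallel_size = 8

# Keep the effective optimization batch size unchanged: the doc scales
# gradient accumulation by the sequence-parallel size.
accumulative_counts = 1 * sequence_parallel_size

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    # Replace DefaultSampler so that all ranks in one sequence-parallel group
    # receive the same sample, which is then split along the length dimension.
    sampler=dict(type=SequenceParallelSampler, seed=1024),
    collate_fn=dict(type=default_collate_fn))

With sequence_parallel_size = 8 the 128k activations are split roughly 8x per card. Launching with DeepSpeed ZeRO is still advisable for the 8B weights and optimizer states, e.g. xtuner train <config>.py --deepspeed deepspeed_zero2 (switch to deepspeed_zero3 if it still does not fit).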