[BUG] Unexpected High Memory Usage (OOM) when finetuning Llama2-7B
Describe the bug
I keep hitting Out-of-Memory (OOM) errors while fine-tuning Llama2-7b-hf. I'm training on two A100 80GB GPUs, without any offloading. The dataset is small, only about 0.5GB (https://huggingface.co/datasets/wikitext), and the training script is the Transformers example (transformers/examples/pytorch/language-modeling/run_clm.py). I noticed that the OOM happens at roughly the first optimizer step that actually updates (the first step without an fp16 overflow). Does anyone know what could be a possible reason? In principle, 2*80GB should be quite sufficient for training a 7B model (~14GB fp16 parameters, ~14GB fp16 gradients, ~84GB fp32 optimizer states). The batch size per device is 4.
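For reference, here is the back-of-envelope arithmetic behind that expectation, as a minimal Python sketch (the parameter count is taken from the trainer log below; activation memory and framework buffers are not included):

```python
# Rough memory estimate for fp16 mixed-precision training with Adam:
# fp16 weights + fp16 grads + fp32 master weights + fp32 Adam momentum/variance.
# Activations and framework buffers are NOT included.
params = 6_738_415_616  # trainable parameters from the trainer log

fp16_weights  = 2 * params   # ~13.5 GB
fp16_grads    = 2 * params   # ~13.5 GB
fp32_master   = 4 * params   # ~27 GB
fp32_momentum = 4 * params   # ~27 GB
fp32_variance = 4 * params   # ~27 GB

total = fp16_weights + fp16_grads + fp32_master + fp32_momentum + fp32_variance
print(f"total model/grad/optimizer states: {total / 2**30:.0f} GiB")  # ~100 GiB
# ZeRO-3 should shard these states across the 2 GPUs (~50 GiB each),
# which is why I expected 2 x 80GB to be enough before activations.
```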
To Reproduce
The command to finetune:
```bash
deepspeed --include localhost:x,x --master_port xxx \
    $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \
    --deepspeed dsconfig/ds_config_profile_240508.json \
    --model_name_or_path $MODEL_DIR \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --fp16 \
    --learning_rate 2e-5 \
    --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
    --num_train_epochs $NEPOCHS \
    --output_dir ${OUTPUT_DIR}_z3 \
    --overwrite_output_dir \
    --save_steps 0 \
    --max_steps $MAX_STEPS \
    --save_strategy "no"
```
The ds_config file:
{ "train_micro_batch_size_per_gpu": "auto", "zero_optimization": { "stage": 3, "offload_optimizer": { "device": "none" }, "offload_param": { "device": "none" } }, "fp16": { "enabled": true, "initial_scale_power": 13 } }
More Information
- I receive the warnings "[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3" and "[WARNING] using untested triton version (2.3.0), only 1.0.0 is known to be compatible". I'm not sure whether these are problematic.
- I also tried the other ZeRO stages (0 and 2), but they did not work either.
- With CPU offload enabled (ratio=0.3), the training works well.
- torch and deepspeed are installed with pip (default versions); transformers is installed from source on GitHub.
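Since the OOM seems to hit right when the optimizer states get materialized at the first real step, here is a small diagnostic I'm considering adding to run_clm.py to log per-rank GPU memory after each step. This is a hypothetical sketch, not part of the run above; the callback name is my own:

```python
# Hypothetical diagnostic callback: print per-rank CUDA memory after every optimizer step.
import torch
from transformers import TrainerCallback

class CudaMemoryCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        allocated = torch.cuda.memory_allocated() / 2**30
        peak = torch.cuda.max_memory_allocated() / 2**30
        print(f"[rank {args.local_rank}] step {state.global_step}: "
              f"allocated {allocated:.1f} GiB, peak {peak:.1f} GiB")
        torch.cuda.reset_peak_memory_stats()

# In run_clm.py, register it before trainer.train():
#     trainer.add_callback(CudaMemoryCallback())
```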
CLI Output
```
[INFO|trainer.py:2078] 2024-05-08 10:39:03,977 >> ***** Running training *****
[INFO|trainer.py:2079] 2024-05-08 10:39:03,977 >> Num examples = 2,778
[INFO|trainer.py:2080] 2024-05-08 10:39:03,977 >> Num Epochs = 1
[INFO|trainer.py:2081] 2024-05-08 10:39:03,977 >> Instantaneous batch size per device = 4
[INFO|trainer.py:2084] 2024-05-08 10:39:03,977 >> Total train batch size (w. parallel, distributed & accumulation) = 8
[INFO|trainer.py:2085] 2024-05-08 10:39:03,977 >> Gradient Accumulation steps = 1
[INFO|trainer.py:2086] 2024-05-08 10:39:03,977 >> Total optimization steps = 30
[INFO|trainer.py:2087] 2024-05-08 10:39:03,978 >> Number of trainable parameters = 6,738,415,616
```