DeepSpeed initialization for models in the transformers library
Dear authors,
I found that Collie cannot initialize DeepSpeed when using models from the transformers library. For example, when I replace this line of the script with the from_pretrained interface of the transformers library (to which no config of type CollieConfig can be passed), even the monitors cannot be registered correctly, since DeepSpeed is not initialized ("DeepSpeed backend not set, please initialize it using init_process_group()"). Is there a workaround for this issue, or can Collie only support training its internally reimplemented models?
Hi @DesperateExplorer , Collie can use models from transformers in the case of ZeRO parallelism, but you need to call setup_distribution manually:
from collie import setup_distribution, CollieConfig
from transformers import AutoModelForCausalLM

model_name = "openlm-research/open_llama_7b_v2"
config = CollieConfig.from_pretrained(model_name)
setup_distribution(config)  # initialize the distributed / DeepSpeed backend before loading the model
model = AutoModelForCausalLM.from_pretrained(model_name)
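If you want ZeRO, the DeepSpeed settings can go into the config before setup_distribution is called. A minimal sketch, assuming CollieConfig exposes a ds_config dict that is passed through to DeepSpeed (adjust to your setup):

config.ds_config = {
    "fp16": {"enabled": True},          # optional mixed precision
    "zero_optimization": {"stage": 3},  # ZeRO stage 1/2/3 depending on memory budget
}
setup_distribution(config)  # must be called after ds_config is filled in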
Why is the memory consumption of the LLaMA-7B from transformers much larger than that of Collie's internal implementation? Taking LLaMA-7B with AdamW as an example: with the internal implementation, train_micro_batch_size_per_gpu = 2 does not cause OOM on a V100 with the ShareGPT dataset (max context = 2048), but with the transformers implementation, train_micro_batch_size_per_gpu = 1 already causes OOM. Even after switching to Lomo, I cannot fit a single sample (train_micro_batch_size_per_gpu = 1) into the 32 GB of memory without OOM.
Collie's LLaMA uses flash attention for MHA, which can reduce memory usage. If use_flash is True, memory usage is lower than with the transformers implementation.
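For context, a minimal sketch of what that looks like with Collie's own implementation (the import path and from_pretrained signature of Collie's LlamaForCausalLM are assumptions; check your Collie version):

from collie import CollieConfig, setup_distribution
from collie.models import LlamaForCausalLM  # Collie's internal LLaMA (assumed import path)

model_name = "openlm-research/open_llama_7b_v2"
config = CollieConfig.from_pretrained(model_name)
config.use_flash = True  # use flash attention in Collie's MHA to reduce memory
setup_distribution(config)
model = LlamaForCausalLM.from_pretrained(model_name, config=config)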
Actually, no. On V100 (Volta architecture), flash attention of any kind is not supported.
You can try setting pretrained_config.gradient_checkpointing to True, like this:
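(A minimal sketch of that suggestion, assuming the model is loaded via transformers' AutoConfig and AutoModelForCausalLM; the exact flag depends on your transformers version.)

from transformers import AutoConfig, AutoModelForCausalLM

model_name = "openlm-research/open_llama_7b_v2"
pretrained_config = AutoConfig.from_pretrained(model_name)
pretrained_config.gradient_checkpointing = True  # recompute activations in backward to save memory
model = AutoModelForCausalLM.from_pretrained(model_name, config=pretrained_config)
# On newer transformers versions you can call model.gradient_checkpointing_enable() instead.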
config.checkpointing=True also works now.
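That is, reusing the config from above (a sketch, assuming checkpointing is a plain CollieConfig field set before setup_distribution):

config.checkpointing = True  # enable gradient checkpointing via the CollieConfig
setup_distribution(config)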
