Support ZeRO-3 for FLUX training
Describe the bug
Due to memory limitations, I am attempting to use ZeRO-3 for FLUX training on 8 GPUs with 32 GB each. I encountered a bug similar to the one reported in this issue: https://github.com/huggingface/diffusers/issues/1865. I made modifications based on the solution proposed in this pull request: https://github.com/huggingface/diffusers/pull/3076. However, the same error persists. In my opinion, the fix does not work as expected, at least not entirely. Could you advise on how to modify it further?
The relevant code from https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1157 has been updated as follows:
```python
def deepspeed_zero_init_disabled_context_manager():
    """
    Returns either a context list that includes one that will disable zero.Init, or an empty context list.
    """
    deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
    print(f"deepspeed_plugin: {deepspeed_plugin}")
    if deepspeed_plugin is None:
        return []
    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]


with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )
```
Reproduction
deepspeed config:
```json
{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "stage3_gather_16bit_weights_on_model_save": false,
        "overlap_comm": false
    },
    "bf16": {
        "enabled": true
    },
    "fp16": {
        "enabled": false
    }
}
```
accelerate config:
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: "config/ds_config.json"
distributed_type: DEEPSPEED
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
```
training shell:
```bash
#!/bin/bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux"
export DS_SKIP_CUDA_CHECK=1
export ACCELERATE_CONFIG_FILE="config/accelerate_config.yaml"

ACCELERATE_CONFIG_FILE_PATH=${1:-$ACCELERATE_CONFIG_FILE}
FLUXOUTPUT_DIR=flux_lora_output
mkdir -p $FLUXOUTPUT_DIR

accelerate launch --config_file $ACCELERATE_CONFIG_FILE_PATH train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=4 \
  --guidance_scale=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --report_to="tensorboard" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=100 \
  --gradient_checkpointing \
  --seed="0"
```
Logs
```
RuntimeError: 'weight' must be 2-D
```
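For context (my own reading of the error, not taken from the traceback): `torch.nn.functional.embedding` requires its weight to be a 2-D `(vocab_size, hidden_dim)` matrix, but under ZeRO-3 each rank holds only a flattened 1-D shard of every parameter, so the token lookup fails its shape check. A dependency-free sketch of that check (the shapes are illustrative, not FLUX's real sizes):

```python
def embedding_lookup(weight_shape, token_id):
    # Mimics the shape check at the top of torch.nn.functional.embedding.
    if len(weight_shape) != 2:
        raise RuntimeError("'weight' must be 2-D")
    vocab_size, hidden_dim = weight_shape
    return (token_id, hidden_dim)  # stand-in for the looked-up row

full_matrix = (49408, 768)         # gathered CLIP-like embedding: (vocab, hidden)
zero3_shard = (49408 * 768 // 8,)  # one rank's flattened 1-D shard across 8 GPUs

print(embedding_lookup(full_matrix, 42))  # works with the gathered 2-D weight
try:
    embedding_lookup(zero3_shard, 42)
except RuntimeError as err:
    print(err)                            # 'weight' must be 2-D
```

This is why gathering the text encoder's parameters before encoding (as discussed below in this thread) makes the error disappear: the lookup then sees the full 2-D matrix again.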
System Info
pytorch: 2.1.0
deepspeed: 0.14.0
accelerate: 1.3.0
diffusers: develop
Who can help?
No response
LoRA + DeepSpeed won't work, unfortunately.
@bghira did it work on Megatron?
The problem is a bug in the interaction between Diffusers, Accelerate, PEFT, and DeepSpeed, which weren't involved in that Megatron training run :D
@bghira I see. Sorry for my phrasing; my question is whether we can use Megatron for FLUX training on 8 GPUs with 32 GB each, which hasn't been mentioned in relation to any issues.
This bug is caused by the embedding layer of the text encoder being sharded across different GPUs. If the text encoder's parameters are gathered first, the error is not raised. However, doing so yields only a slight memory reduction relative to ZeRO stage 2. Even after gathering the encoder parameters, I still could not fine-tune all parameters.
```python
with deepspeed.zero.GatheredParameters(text_encoders[0].parameters(), modifier_rank=None):
    pooled_prompt_embeds = _encode_prompt_with_clip(
        text_encoder=text_encoders[0],
        tokenizer=tokenizers[0],
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
    )
```
Thanks. I disabled zero.Init with the context manager instead of using this function.
@bghira Hi, I successfully ran the model under ZeRO-3 by disabling zero.Init only for the encoder models. Here are my modifications:
```python
from contextlib import contextmanager


@contextmanager
def zero3_init_context_manager(deepspeed_plugin, enable=False):
    old = deepspeed_plugin.zero3_init_flag
    if old == enable:
        yield
    else:
        deepspeed_plugin.zero3_init_flag = enable
        deepspeed_plugin.dschf = None
        yield
        deepspeed_plugin.zero3_init_flag = old
        deepspeed_plugin.dschf = None
        deepspeed_plugin.set_deepspeed_weakref()


def deepspeed_zero_init_disabled_context_manager():
    """
    Returns either a context list that includes one that will disable zero.Init, or an empty context list.
    """
    deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
    if deepspeed_plugin is None:
        return []
    return [zero3_init_context_manager(deepspeed_plugin, enable=False)]


with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )
```
And there is an acceptable margin of error between the ZeRO-2 and ZeRO-3 results.
However, I noticed an issue: diffusers doesn't shard the model parameters before loading them onto the GPUs. Instead, it loads the entire model and only shards it during the accelerator.prepare step.
I discovered that this arises because diffusers' model loading differs from the base model loading in transformers:
https://github.com/huggingface/transformers/blob/66f29aaaf55c8fe0c3dbcd24beede2ca4effac56/src/transformers/modeling_utils.py#L1336
https://github.com/huggingface/diffusers/blob/24c062aaa19f5626d03d058daf8afffa2dfd49f7/src/diffusers/models/modeling_utils.py#L252
Should we consider enhancing the implementation of the diffusers class?
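To make the difference concrete, here is a toy sketch of the two loading paths (my own simplification, not the actual library code; the real branching lives in the two `modeling_utils.py` files linked above, where transformers enters `deepspeed.zero.Init` when ZeRO-3 is detected and diffusers has no such branch):

```python
import contextlib

events = []

@contextlib.contextmanager
def fake_zero_init():
    # Stand-in for deepspeed.zero.Init(): the real context partitions
    # parameters across ranks as each submodule is constructed.
    events.append("sharded at construction")
    yield

def load_like_transformers(zero3_enabled, build_model):
    # transformers' from_pretrained wraps model construction in zero.Init
    # when ZeRO-3 is enabled, so no rank materializes the full model.
    ctx = fake_zero_init() if zero3_enabled else contextlib.nullcontext()
    with ctx:
        return build_model()

def load_like_diffusers(zero3_enabled, build_model):
    # diffusers' from_pretrained builds the full model unconditionally;
    # sharding only happens later, inside accelerator.prepare().
    events.append("full model materialized")
    return build_model()

load_like_transformers(True, lambda: "model")
load_like_diffusers(True, lambda: "model")
print(events)  # ['sharded at construction', 'full model materialized']
```

The practical consequence is peak host/GPU memory during loading: the transformers-style path never holds the whole model on one rank, while the diffusers-style path does, even when ZeRO-3 is configured.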
Following.
I recommend using the pure DeepSpeed framework rather than Accelerate + DeepSpeed.