
Redundant reinitialization of text encoders in train_dreambooth_lora_flux

Open CyberDragon93 opened this issue 1 year ago • 2 comments

Describe the bug

In the train_dreambooth_lora_flux.py script, the text encoders text_encoder_one and text_encoder_two are reinitialized on every call to log_validation: https://github.com/huggingface/diffusers/blob/8ba90aa706a733f45d83508a5b221da3c59fe4cd/examples/dreambooth/train_dreambooth_lora_flux.py#L1768 This happens even when the text encoders are not being trained (i.e., when args.train_text_encoder is False). The redundant reinitialization wastes resources and can trigger CUDA out-of-memory errors, especially on GPUs with less than 48 GiB of VRAM.

Since the validation prompt is fixed (only a single prompt is used), its text embeddings can be precomputed during instance-prompt preprocessing. With this change the model fits within 40 GiB of VRAM, avoiding the CUDA OOM.
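To see why precomputing saves VRAM: once the embeddings are cached and args.train_text_encoder is False, the encoder objects themselves can be dropped so their memory is reclaimed (in the real script this would additionally call torch.cuda.empty_cache()). A minimal sketch of this pattern, using a hypothetical stand-in class rather than the actual encoders:

```python
import gc

class FakeEncoder:
    """Hypothetical stand-in for a large text encoder held in VRAM."""
    alive = 0

    def __init__(self):
        FakeEncoder.alive += 1

    def __del__(self):
        FakeEncoder.alive -= 1

text_encoder_one, text_encoder_two = FakeEncoder(), FakeEncoder()

# ... precompute instance/validation prompt embeddings here ...

# With the embeddings cached, the encoders are no longer needed:
del text_encoder_one, text_encoder_two
gc.collect()  # in the real script, follow with torch.cuda.empty_cache()

print(FakeEncoder.alive)  # 0 — both encoder objects were reclaimed
```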

Proposed Fix

To address this issue, precompute the validation-prompt embeddings once, alongside the instance-prompt embeddings, when the text encoders are not being trained and no custom instance prompts are used:

if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
    instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(
        args.instance_prompt, text_encoders, tokenizers
    )

    if args.validation_prompt is not None:
        validation_prompt_hidden_states, validation_pooled_prompt_embeds, _ = compute_text_embeddings(
            args.validation_prompt, text_encoders, tokenizers
        )

This change will prevent the unnecessary reinitialization of text encoders and reduce the VRAM usage during training.
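The precompute-once pattern behind this fix can be sketched with a toy stand-in for compute_text_embeddings (the helper name and cache logic below are illustrative, not the script's actual code): encode the fixed validation prompt a single time before the training loop, then reuse the cached result at every validation epoch instead of reloading the encoders.

```python
# Sketch of the precompute-once pattern; fake_compute_text_embeddings
# is a hypothetical stand-in for the script's compute_text_embeddings.
calls = {"count": 0}

def fake_compute_text_embeddings(prompt):
    """Pretend to run the (expensive) text encoders once per call."""
    calls["count"] += 1
    return f"embeds({prompt})"

# Before the loop: encode the fixed validation prompt exactly once.
validation_prompt = "A photo of sks dog in a bucket"
validation_embeds = fake_compute_text_embeddings(validation_prompt)

# During training: each validation epoch reuses the cached embeddings
# instead of reinitializing the encoders and re-encoding the prompt.
for epoch in range(1, 501):
    if epoch % 25 == 0:
        log_inputs = validation_embeds  # no encoder reload here

print(calls["count"])  # 1 — the encoders ran only once
```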

Reproduction

from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)

export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux-lora"

accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

Logs

No response

System Info

  • 🤗 Diffusers version: 0.31.0.dev0
  • Platform: Linux-5.15.0-47-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.8
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.24.5
  • Transformers version: 4.43.4
  • Accelerate version: 0.30.1
  • PEFT version: 0.12.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.3
  • xFormers version: not installed
  • Accelerator: NVIDIA A40, 49140 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul

CyberDragon93 avatar Sep 03 '24 17:09 CyberDragon93

Cc: @linoytsaban

sayakpaul avatar Sep 04 '24 01:09 sayakpaul

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Oct 17 '24 15:10 github-actions[bot]