
[LoRA] fix: lora loading when using with a device_mapped model.

Open sayakpaul opened this issue 1 year ago • 8 comments

What does this PR do?

Fixes LoRA loading behaviour when used with a model that is sharded across multiple devices.

Minimal code
"""
Minimal example to show how to load a LoRA into the Flux transformer
that is sharded in two GPUs. 

Limitation:
* Latency
* If the LoRA has text encoder layers then this needs to be revisited.
"""

from diffusers import FluxTransformer2DModel, FluxPipeline 
import torch 

ckpt_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, 
    subfolder="transformer",
    device_map="auto",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=dtype
)
print(transformer.hf_device_map)
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=dtype
)
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
# print(pipeline.transformer.hf_device_map)

# Essentially you'd pre-compute these embeddings beforehand.
# Reference: https://gist.github.com/sayakpaul/a9266fe2d0d510ec44a9cdc385b3dd74. 
example_inputs = {
    "prompt_embeds": torch.randn(1, 512, 4096, dtype=dtype, device="cuda"),
    "pooled_projections": torch.randn(1, 768, dtype=dtype, device="cuda"),
}

_ = pipeline(
    prompt_embeds=example_inputs["prompt_embeds"],
    pooled_prompt_embeds=example_inputs["pooled_projections"],
    num_inference_steps=50,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    output_type="latent",
)
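Since the text encoders are not loaded here, the prompt embeddings have to be computed ahead of time (the gist above shows the full encoder pass). As a torch-only sketch of the hand-off, the precomputed tensors can be serialized once and reloaded on the machine hosting the sharded transformer; the shapes mirror the example inputs above, and the file name is hypothetical:

```python
import torch

# Shapes taken from the example above: (1, 512, 4096) for the T5 sequence
# embeddings and (1, 768) for the pooled CLIP projection. In practice these
# come from running the Flux text encoders once (see the referenced gist);
# random tensors stand in for them here.
embeds = {
    "prompt_embeds": torch.randn(1, 512, 4096, dtype=torch.bfloat16),
    "pooled_projections": torch.randn(1, 768, dtype=torch.bfloat16),
}
torch.save(embeds, "flux_prompt_embeds.pt")  # hypothetical file name

# Later, in the inference script that only loads the transformer:
loaded = torch.load("flux_prompt_embeds.pt")
print(loaded["prompt_embeds"].shape, loaded["pooled_projections"].shape)
```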

Some internal discussions:

  • https://huggingface.slack.com/archives/C03UQJENJTV/p1725527760353639
  • https://huggingface.slack.com/archives/C04L3MWLE6B/p1726470631333599

Cc: @philschmid for awareness as you were interested in this feature.

TODOs

  • [x] Tests
  • [x] Docs

Once I get a sanity review from Marc and Benjamin, will request a review from Yiyi.

sayakpaul avatar Sep 17 '24 01:09 sayakpaul

Does diffusers have multi GPU tests? If yes, would it make sense to add a test there and check that after LoRA loading, no parameter was transferred to meta device?

BenjaminBossan avatar Sep 17 '24 10:09 BenjaminBossan
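The check Benjamin describes can be expressed with plain torch: after loading the LoRA, no parameter should live on the meta device. A minimal CPU-only sketch with a toy model (not the actual diffusers test; `params_on_meta` is a made-up helper):

```python
import torch
import torch.nn as nn

def params_on_meta(model: nn.Module) -> list[str]:
    """Return the names of parameters that ended up on the meta device."""
    return [name for name, p in model.named_parameters() if p.device.type == "meta"]

# Toy model: one layer materialized normally, one created on the meta
# device, mimicking what a buggy LoRA load into a device_mapped model
# could leave behind.
model = nn.Sequential(nn.Linear(4, 4))
with torch.device("meta"):
    model.append(nn.Linear(4, 4))

offenders = params_on_meta(model)
print(offenders)  # the second layer's weight and bias are flagged
```

In the real test, `model` would be the pipeline's transformer after `load_lora_weights`, and the assertion would be that this list is empty.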

That is a TODO ;)

sayakpaul avatar Sep 17 '24 11:09 sayakpaul

Does diffusers have multi GPU tests?

@BenjaminBossan yes, we do: https://github.com/search?q=repo%3Ahuggingface%2Fdiffusers%20require_torch_multi_gpu&type=code

But not for the use case being described here. Will add them as part of this PR.

sayakpaul avatar Sep 17 '24 13:09 sayakpaul
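A multi-GPU test for this could look roughly like the following. `unittest.skipUnless` stands in for diffusers' `require_torch_multi_gpu` decorator (the real suite imports it from `diffusers.utils.testing_utils`), the test body reuses the repro from the PR description, and the class and method names are made up:

```python
import unittest

import torch

class LoraDeviceMapTests(unittest.TestCase):
    # Stand-in for diffusers' `require_torch_multi_gpu`; skips cleanly
    # on machines with fewer than two GPUs.
    @unittest.skipUnless(torch.cuda.device_count() > 1, "requires >= 2 GPUs")
    def test_lora_load_keeps_params_off_meta(self):
        from diffusers import FluxPipeline, FluxTransformer2DModel

        transformer = FluxTransformer2DModel.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="transformer",
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=transformer,
            text_encoder=None,
            text_encoder_2=None,
            tokenizer=None,
            tokenizer_2=None,
            vae=None,
            torch_dtype=torch.bfloat16,
        )
        pipeline.load_lora_weights(
            "TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors"
        )
        # The core assertion: loading the LoRA must not strand any
        # transformer parameter on the meta device.
        on_meta = [
            name
            for name, p in pipeline.transformer.named_parameters()
            if p.device.type == "meta"
        ]
        self.assertEqual(on_meta, [])

if __name__ == "__main__":
    unittest.main()
```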

@SunMarc a gentle ping when you find a moment.

sayakpaul avatar Sep 22 '24 10:09 sayakpaul

@yiyixuxu can you give this an initial look and once we agree, I will work on adding testing, docs, etc.

sayakpaul avatar Sep 24 '24 14:09 sayakpaul

@yiyixuxu a gentle ping for a first review as it touches pipeline_utils.py.

sayakpaul avatar Oct 02 '24 13:10 sayakpaul

@DN6 @BenjaminBossan could you give this another look? I have added tests and docs.

sayakpaul avatar Oct 19 '24 12:10 sayakpaul

Failing tests are unrelated.

sayakpaul avatar Oct 31 '24 15:10 sayakpaul