[LoRA] fix: lora loading when using with a device_mapped model.
What does this PR do?
Fixes LoRA loading behaviour when used with a model that is sharded across multiple devices.
Minimal code

```python
"""
Minimal example showing how to load a LoRA into the Flux transformer
that is sharded across two GPUs.

Limitations:
* Latency
* If the LoRA has text encoder layers, this needs to be revisited.
"""
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel

ckpt_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    device_map="auto",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=dtype,
)
print(transformer.hf_device_map)

pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=dtype,
)
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
# print(pipeline.transformer.hf_device_map)

# In practice, you would pre-compute these embeddings beforehand.
# Reference: https://gist.github.com/sayakpaul/a9266fe2d0d510ec44a9cdc385b3dd74.
example_inputs = {
    "prompt_embeds": torch.randn(1, 512, 4096, dtype=dtype, device="cuda"),
    "pooled_projections": torch.randn(1, 768, dtype=dtype, device="cuda"),
}

_ = pipeline(
    prompt_embeds=example_inputs["prompt_embeds"],
    pooled_prompt_embeds=example_inputs["pooled_projections"],
    num_inference_steps=50,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    output_type="latent",
)
```
Some internal discussions:
- https://huggingface.slack.com/archives/C03UQJENJTV/p1725527760353639
- https://huggingface.slack.com/archives/C04L3MWLE6B/p1726470631333599
Cc: @philschmid for awareness as you were interested in this feature.
TODOs
- [x] Tests
- [x] Docs
Once I get a sanity review from Marc and Benjamin, will request a review from Yiyi.
Does diffusers have multi-GPU tests? If yes, would it make sense to add a test there and check that, after LoRA loading, no parameter was transferred to the meta device?
That is a TODO ;)
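For reference, such a check could be sketched as below. This is a hedged sketch: `params_on_meta` is a hypothetical helper, demonstrated on a toy `nn.Linear` rather than the actual sharded transformer.

```python
import torch.nn as nn


def params_on_meta(module: nn.Module) -> list:
    # Names of parameters that still live on the meta device, i.e. have
    # no real storage attached.
    return [name for name, p in module.named_parameters() if p.device.type == "meta"]


# Toy demonstration; an actual test would assert this on pipeline.transformer
# after load_lora_weights, under a multi-GPU test decorator.
assert params_on_meta(nn.Linear(4, 4)) == []
assert params_on_meta(nn.Linear(4, 4, device="meta")) == ["weight", "bias"]
```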
> Does diffusers have multi GPU tests?
@BenjaminBossan yes, we do: https://github.com/search?q=repo%3Ahuggingface%2Fdiffusers%20require_torch_multi_gpu&type=code
But not for the use case being described here. Will add them as part of this PR.
@SunMarc a gentle ping when you find a moment.
@yiyixuxu can you give this an initial look and once we agree, I will work on adding testing, docs, etc.
@yiyixuxu a gentle ping for a first review as it touches pipeline_utils.py.
@DN6 @BenjaminBossan could you give this another look? I have added tests and docs.
Failing tests are unrelated.