[Design Discussion] allowing `from_pretrained()` to also load single file checkpoints
Since we were considering adding an option like `single_file_format` to `save_pretrained()` of `DiffusionPipeline`, it makes sense to have something similar in `from_pretrained()` for better feature parity.
We currently support loading single-file checkpoints in `DiffusionPipeline` via `from_single_file()`. Some examples below:
```python
import torch

from diffusers import StableDiffusionPipeline

# Download pipeline from huggingface.co and cache.
pipeline = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
)

# Load pipeline from a local file
# (file was downloaded to ./v1-5-pruned-emaonly.ckpt)
pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")

# Enable float16 and move to GPU
pipeline = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
    torch_dtype=torch.float16,
)
pipeline.to("cuda")
```
(Taken from the docs here)
Proposed API design
Calling `from_pretrained()` on a `DiffusionPipeline` requires users to pass `pretrained_model_name_or_path`, which can be a repo id on the Hub or a local directory containing checkpoints in the diffusers format.
(Docs)
Now, if we want to add support for loading a compatible single-file checkpoint in `from_pretrained()`, we could have an API like so:

```python
from diffusers import DiffusionPipeline

repo_id = "WarriorMama777/OrangeMixs"
pipe = DiffusionPipeline.from_pretrained(repo_id, weight_name="Models/AbyssOrangeMix/AbyssOrangeMix.safetensors")
```
- Like before, `repo_id` could either be an actual repo id on the Hub or a local directory.
- `weight_name` can either be just the filename of the single-file checkpoint to be loaded or the relative path to the checkpoint (w.r.t. the underlying repo / directory).
- When `weight_name` is provided in `from_pretrained()`:
  - We immediately check if the file exists in the repository or the directory and raise an error in case it's not found.
  - Once it's checked, we hit the code path that we're currently hitting when using `from_single_file()`. The logic to do that should be completely separated out as a utility and should not live in `from_pretrained()`; we can just call the utility from `from_pretrained()`.
- How can we detect errors here as early as possible? What if the checkpoint is not compatible or doesn't have all the components we need (e.g., what if the `vae` or any other component is missing)? Is there any robust way?
- Once this support is foolproof, we can start deprecating the use of `from_single_file()`.
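The early existence check plus delegation described above could be sketched roughly as follows. This is a hypothetical, simplified illustration of the proposed dispatch, not actual diffusers code; the function names and return values are stand-ins, and the remote Hub check is only indicated in a comment:

```python
import os

# Hypothetical sketch of the proposed weight_name dispatch inside
# from_pretrained(). The return strings stand in for the real loading
# utilities; nothing here is the actual diffusers implementation.
def from_pretrained_sketch(pretrained_model_name_or_path, weight_name=None):
    if weight_name is None:
        # No single-file checkpoint requested: existing diffusers-format path.
        return "diffusers-format loading"

    if os.path.isdir(pretrained_model_name_or_path):
        checkpoint = os.path.join(pretrained_model_name_or_path, weight_name)
        # Flag missing files as early as possible, before any heavy work.
        if not os.path.isfile(checkpoint):
            raise FileNotFoundError(
                f"{weight_name} not found in {pretrained_model_name_or_path}"
            )
    else:
        # For a Hub repo id, the existence check would have to happen
        # remotely (e.g. via huggingface_hub's file-listing helpers --
        # an assumption, elided here).
        checkpoint = f"{pretrained_model_name_or_path}/{weight_name}"

    # Delegate to the (factored-out) single-file utility instead of
    # inlining that logic into from_pretrained().
    return f"single-file loading of {checkpoint}"
```

The key design point is that the single-file code path stays in its own utility; `from_pretrained()` only validates inputs and dispatches.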
Some thoughts
- I don't think this is a very new design. Users are already familiar with `weight_name` and how it's used through `load_lora_weights()` (which is quite popular at this point, IMO).
- I think we must require users to pass `weight_name` explicitly. Too much intelligent guessing here would lead to ugly consequences in the code, and I'm not sure the benefits would be worth it.
Cc: @patrickvonplaten @DN6
`weight_name` or `file_name` makes sense to me! Let's maybe make sure we have the same loading logic here as in `load_lora`

> Let's maybe make sure we have the same loading logic here as in `load_lora`

Elaborate a bit?
Hello, a dev from SDNext here.
`FromSingleFileMixin` refers to `download_from_original_stable_diffusion_ckpt`, but the former does not accept a `state_dict` and the latter does.
While you are redesigning this workflow, would there be a clean way to enable a `state_dict` to be passed to `FromSingleFileMixin` to bypass loading from disk/Hub?
Use case: multiple `state_dict` objects stored in RAM for quick swapping into the active pipeline.
Thanks for letting us know about it. What do you mean by "the former"?
@sayakpaul `FromSingleFileMixin` breaks the ability to pass a `state_dict` here: https://github.com/huggingface/diffusers/blob/e0f349c2b07975810b7c4faeeafe2f4124f3cfc9/src/diffusers/loaders/single_file.py#L234
Whereas the downstream function https://github.com/huggingface/diffusers/blob/e0f349c2b07975810b7c4faeeafe2f4124f3cfc9/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py#L1134 can receive a `state_dict`.
I haven't fully worked through the implications, but it appears that the pathing and downloading might be bypassed with `if isinstance(pretrained_model_link_or_path, dict)`, and that keyword overrides for `vae` etc. should still function.
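The `isinstance` dispatch suggested here could look roughly like the sketch below. This is purely illustrative: `resolve_checkpoint` and `load_from_link_or_path` are hypothetical stand-ins for the existing download/convert code path, not real diffusers functions.

```python
# Hypothetical sketch: let the single-file entry point accept an
# in-memory state_dict in addition to a path or URL.
def resolve_checkpoint(pretrained_model_link_or_path):
    if isinstance(pretrained_model_link_or_path, dict):
        # Already a state_dict in RAM: skip pathing/downloading entirely.
        return pretrained_model_link_or_path
    # Otherwise fall through to the existing download/load-from-disk logic.
    return load_from_link_or_path(pretrained_model_link_or_path)

def load_from_link_or_path(link_or_path):
    # Stand-in for the real download/convert code path.
    return {"source": link_or_path}
```

Keyword overrides such as `vae=...` would be applied after this resolution step, so they should keep working unchanged.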
Yes, that is exactly what we will do. Similar to what we do in `load_lora_weights()`.
Keep sharing your inputs with us, it's very helpful!
> Let's maybe make sure we have the same loading logic here as in `load_lora`
>
> Elaborate a bit?

Same function signature names, same loading functions (also cc @DN6 here), as we talked about. This PR is very much related, btw: https://github.com/huggingface/diffusers/pull/6428
Yes, definitely. Makes sense to work on this issue after #6428 is merged?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Since #6428 is merged, perhaps it's a good time to follow up here?
Yes, this is on my mind. Will start working on this soon!
@yiyixuxu @sayakpaul not stale?
No need to consider this anymore, as @DN6 has worked a great deal on making `from_single_file()` more and more robust. Closing this.