Add pruna integration for loading model through diffusers.from_pretrained / pipeline
**Is your feature request related to a problem? Please describe.**
Pruna is an open-source AI model optimisation framework.
As discussed with @SunMarc about doing this for transformers, we could do something similar for diffusers as well: something like a `*Pipeline.from_pretrained` interface as an alternative to the `PrunaModel` interface.
Currently, the code looks as follows.
```python
from pruna import PrunaModel

loaded_model = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed")
```
We could go for something like:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("PrunaAI/FLUX.1-dev-smashed")
```
**Describe the solution you'd like.**
It would be a nice integration.
**Describe alternatives you've considered.**
We could also add the library as an explicit tab selector within the Hub, similar to llama-cpp/unsloth and other frameworks.
**Additional context.**
NA
cc @sayakpaul here I think you just need to register with hub to add a tab selector, no?
Yeah that is what my feeling is.
Hi @yiyixuxu @sayakpaul,
I think we can implement one or more of the approaches below.
**Tab Selector: with `tag == pruna-ai`**

For example: https://huggingface.co/PrunaAI/FLUX.1-schnell-smashed

Proposed snippet:

```python
!pip install pruna[full]  # https://docs.pruna.ai/en/stable/setup/install.html

from pruna import PrunaModel

# Load the optimised model from the Hub
pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-schnell-smashed")

# Run inference
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
**Tab Selector: with `library == diffusers`**

For example: https://huggingface.co/black-forest-labs/FLUX.1-schnell

Proposed snippet:

```python
!pip install pruna[full]  # https://docs.pruna.ai/en/stable/setup/install.html

from diffusers import DiffusionPipeline
from pruna import SmashConfig, smash

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")

# Create and configure the optimisation config
# https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html
smash_config = SmashConfig()
smash_config["compiler"] = "torch_compile"  # can always be applied

# Optimise the pipeline
optimized_pipe = smash(model=pipe, smash_config=smash_config)

# Run inference
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = optimized_pipe(prompt).images[0]
```
**Adding `optimization_config` to `from_pretrained` and passing a config**

Like shown here: https://huggingface.co/docs/diffusers/main/en/quantization/quanto

```python
import torch
from diffusers import FluxTransformer2DModel
from pruna import SmashConfig

smash_config = SmashConfig()
smash_config["compiler"] = "torch_compile"  # can always be applied

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    optimization_config=smash_config,  # proposed kwarg, does not exist yet
    torch_dtype=torch.bfloat16,
)
```
However, instead of a `quantization_config`, we could implement something that relies on a more general `optimization_config`.
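To make the idea concrete, here is a minimal sketch of how a generic `optimization_config` hook could be dispatched inside a `from_pretrained`-style loader. All names here (`OptimizationConfig`, `CompilerConfig`, the simplified `from_pretrained`) are hypothetical illustrations, not existing diffusers or pruna APIs:

```python
# Hypothetical sketch: a generic optimisation hook that from_pretrained could
# dispatch on, instead of special-casing quantization_config.

class OptimizationConfig:
    """Base protocol: any config that knows how to apply itself to a model."""

    def apply(self, model):
        raise NotImplementedError


class CompilerConfig(OptimizationConfig):
    """Toy stand-in for e.g. a Pruna SmashConfig carrying compiler settings."""

    def __init__(self, compiler):
        self.compiler = compiler

    def apply(self, model):
        # A real implementation would call e.g. pruna.smash(...) here.
        model.compiled_with = self.compiler
        return model


def from_pretrained(model, optimization_config=None):
    """Simplified loader: load the model, then hand it to the config hook."""
    if optimization_config is not None:
        model = optimization_config.apply(model)
    return model


class DummyModel:
    pass


model = from_pretrained(DummyModel(), optimization_config=CompilerConfig("torch_compile"))
print(model.compiled_with)  # torch_compile
```

The point of the base-class hook is that quantization, compilation, or any other optimisation backend could plug into the same keyword argument without the loader knowing about each one.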
**Hiding the complexity and handling it internally in `from_pretrained`**
Consider the following model: https://huggingface.co/PrunaAI/FLUX.1-schnell-smashed

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("PrunaAI/FLUX.1-schnell-smashed")
```
Here we would check for a `smash_config.json` file in the repository and use it to re-apply all algorithms in the background.
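The detection logic could look roughly like the sketch below, operating on a local repo snapshot. The helper name and the return values are hypothetical; a real implementation would live inside `DiffusionPipeline.from_pretrained` and call into `pruna` to re-apply the algorithms:

```python
import json
import os

SMASH_CONFIG_FILE = "smash_config.json"  # assumed filename, as described above


def load_with_optional_smash_config(local_repo_path):
    """Hypothetical dispatch: if the repo snapshot contains a smash_config.json,
    route loading through Pruna; otherwise fall back to plain diffusers loading."""
    config_path = os.path.join(local_repo_path, SMASH_CONFIG_FILE)
    if os.path.isfile(config_path):
        with open(config_path) as f:
            smash_config = json.load(f)
        # A real implementation would call pruna here, e.g. re-applying the
        # algorithms listed in smash_config to the freshly loaded pipeline.
        return ("pruna", smash_config)
    # Plain diffusers path: no optimisation config present.
    return ("diffusers", None)
```

The advantage of this approach is that users keep the familiar `from_pretrained` call and the optimised variant is transparent to them.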
"Tab Selector: With library==diffusers" is better, IMO. If we see traction, we can opt for more integrations. @yiyixuxu @DN6 WDYT?
Lovely! @sayakpaul, besides the code snippet above, is there a PR we can prepare from our side, or does this need to be done internally on the HF side? We directly wrap around pipelines or models, so apart from being the optimised variant from Pruna, a `PrunaModel` should be usable interchangeably with the ones in the diffusers code snippets.
Hi @sayakpaul @yiyixuxu, just a brief reminder.
I am not sure if there's a TODO on our side?
Can I open a PR for the tab selector @yiyixuxu @DN6 or should that be done on your side?
@yiyixuxu @DN6 reminder.
@davidberenstein1957 I believe the tab selector would have to be added to the metadata in the model card in the PrunaAI repos? Is there something else you had in mind here?
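For reference, a sketch of what that model-card metadata could look like, assuming the standard Hub front-matter keys `library_name`, `tags`, and `base_model` (whether `pruna` snippets then appear also depends on the library being registered on the Hub side):

```yaml
# YAML front matter at the top of README.md in e.g. PrunaAI/FLUX.1-schnell-smashed
# (sketch; library_name is the key the Hub uses to pick the snippet library)
library_name: pruna
tags:
  - pruna-ai
base_model: black-forest-labs/FLUX.1-schnell
```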
Hi @DN6, yes, I could not find any documentation on how to add this to a model card, so I assumed it would need to be handled on your side. Could you provide an example showing how to do this? I also took a look at this repo, which opens a tab selector for the model2vec library but does not show any configs in the README.
This seems like something related to the Hub platform. Maybe @Wauplin has more pointers?
I'm not sure what you are all referring to by "tab selector" in this thread, to be honest.
If you want to register "pruna" as a library and expose code snippets dedicated to these models, you can do that by following the instructions in this guide: https://huggingface.co/docs/hub/models-adding-libraries
If this is not what you are looking for, please let me know
Thanks @Wauplin, I did not realise this was something we needed to do on our side. I opened a PR here :)
https://github.com/huggingface/huggingface.js/pull/1684/files
Closing the issue for now.