[LoRA] introduce LoraBaseMixin to promote reusability.
What does this PR do?
It is basically a mirror of https://github.com/huggingface/diffusers/pull/8670. I had accidentally merged it but I have reverted it in https://github.com/huggingface/diffusers/pull/8773. Apologies for this.
I have left inline comments to address the questions brought up by @yiyixuxu.
Nice initiative 👍🏽. A lot to unpack here, so perhaps it's best to start bit by bit. I just went over the pipeline-related components here.
Regarding the LoraBaseMixin, at the moment I think it might be doing a bit too much.
There are quite a few methods in there that make assumptions about the inheriting class, which isn't really how a base class should behave. Loading methods tied to specific model components, e.g. `load_lora_into_text_encoder`, are better left out. If such a method is used across different pipelines with no changes, it's better to turn it into a utility function and call it from the inheriting class, or redefine the method in the inheriting class and use the `# Copied from` mechanism.
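The utility-function approach could look roughly like this. This is a minimal, hypothetical sketch (the helper name, its signature, and the key-filtering logic are all illustrative, not the actual diffusers implementation):

```python
# Hypothetical sketch: component-specific loading lives in a module-level
# helper; each pipeline mixin calls it instead of inheriting it.


def _load_lora_into_text_encoder(text_encoder, state_dict, adapter_name=None):
    # Illustrative only: keep the keys belonging to the text encoder and
    # hand them over. The real logic in diffusers is more involved.
    prefix = "text_encoder."
    te_state_dict = {
        k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
    }
    text_encoder.loaded_keys = sorted(te_state_dict)
    return te_state_dict


class LoraBaseMixin:
    # The base class stays free of component-specific assumptions.
    pass


class ExamplePipelineLoraMixin(LoraBaseMixin):
    def load_lora_weights(self, state_dict, adapter_name=None):
        # The inheriting class knows it has a text encoder, so it is the
        # one that calls the helper.
        return _load_lora_into_text_encoder(self.text_encoder, state_dict, adapter_name)
```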
I would assume that these are the methods that need to be defined for managing LoRAs across all pipelines?
```python
from typing import Callable, Dict

import torch


class LoraBaseMixin:
    @classmethod
    def _optionally_disable_offloading(cls, _pipeline):
        raise NotImplementedError()

    @classmethod
    def _fetch_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict,
        weight_name,
        use_safetensors,
        local_files_only,
        cache_dir,
        force_download,
        resume_download,
        proxies,
        token,
        revision,
        subfolder,
        user_agent,
        allow_pickle,
    ):
        raise NotImplementedError()

    @classmethod
    def _best_guess_weight_name(
        cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
    ):
        raise NotImplementedError()

    @classmethod
    def save_lora_weights(cls, **kwargs):
        raise NotImplementedError("`save_lora_weights()` is not implemented.")

    @classmethod
    def lora_state_dict(cls, **kwargs):
        raise NotImplementedError("`lora_state_dict()` is not implemented.")

    def load_lora_weights(self, **kwargs):
        raise NotImplementedError("`load_lora_weights()` is not implemented.")

    def unload_lora_weights(self, **kwargs):
        raise NotImplementedError("`unload_lora_weights()` is not implemented.")

    def fuse_lora(self, **kwargs):
        raise NotImplementedError("`fuse_lora()` is not implemented.")

    def unfuse_lora(self, **kwargs):
        raise NotImplementedError("`unfuse_lora()` is not implemented.")

    def disable_lora(self):
        raise NotImplementedError("`disable_lora()` is not implemented.")

    def enable_lora(self):
        raise NotImplementedError("`enable_lora()` is not implemented.")

    def get_active_adapters(self):
        raise NotImplementedError("`get_active_adapters()` is not implemented.")

    def delete_adapters(self, adapter_names):
        raise NotImplementedError("`delete_adapters()` is not implemented.")

    def set_lora_device(self, adapter_names):
        raise NotImplementedError("`set_lora_device()` is not implemented.")

    @staticmethod
    def pack_weights(layers, prefix):
        raise NotImplementedError()

    @staticmethod
    def write_lora_layers(
        state_dict: Dict[str, torch.Tensor],
        save_directory: str,
        is_main_process: bool,
        weight_name: str,
        save_function: Callable,
        safe_serialization: bool,
    ):
        raise NotImplementedError()

    @property
    def lora_scale(self) -> float:
        raise NotImplementedError()
```
Quite a few of these methods, such as `load_lora_weights`/`unload_lora_weights` and `fuse_lora`/`unfuse_lora`, probably cannot be defined in the base class, since they deal with specific pipeline components. They might also require arguments specific to the pipeline type or its components.
I think it might be better to define these methods in a pipeline-specific class that inherits from `LoraBaseMixin`, or just as its own mixin class. I don't have a strong feeling about either approach. e.g. `StableDiffusionLoraLoaderMixin` could look like:
```python
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    _lora_loadable_modules = ["unet", "text_encoder"]

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        **kwargs,
    ):
        _load_lora_into_unet(**kwargs)
        _load_lora_into_text_encoder(**kwargs)

    def fuse_lora(self, components=["unet", "text_encoder"], **kwargs):
        for fuse_component in components:
            if fuse_component not in self._lora_loadable_modules:
                raise ValueError()

            model = getattr(self, fuse_component)
            # check if it's a diffusers model
            if isinstance(model, ModelMixin):
                model.fuse_lora()
            # handle transformers models
            if isinstance(model, PreTrainedModel):
                fuse_text_encoder()
```
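A runnable, framework-free sketch of this components-based dispatch (the dummy `_DiffusersModel` and `ExamplePipeline` classes stand in for real models and pipelines; they are illustrative only):

```python
class _DiffusersModel:
    # Stand-in for a diffusers model that supports LoRA fusion.
    def __init__(self):
        self.fused = False

    def fuse_lora(self):
        self.fused = True


class ExamplePipeline:
    # Single class attribute listing the LoRA-loadable components.
    _lora_loadable_modules = ["unet", "text_encoder"]

    def __init__(self):
        self.unet = _DiffusersModel()
        self.text_encoder = _DiffusersModel()

    def fuse_lora(self, components=("unet", "text_encoder")):
        for name in components:
            # Validate against the class attribute before dispatching.
            if name not in self._lora_loadable_modules:
                raise ValueError(f"{name!r} is not a LoRA-loadable component.")
            getattr(self, name).fuse_lora()
```

Callers can then fuse only a subset, e.g. `pipe.fuse_lora(components=("unet",))`, and an unknown component name fails loudly instead of being silently skipped.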
I saw this comment about using the term "fuse_denoiser" in the fusing methods. I'm not so sure about that. If we want to fuse the LoRA in a specific component, I think it's better to pass in the actual name of the component used in the pipeline, rather than track another attribute such as `denoiser`.
I also think the constants and class attributes such as `TEXT_ENCODER_NAME` and `is_unet_denoiser` might not be needed if we use a single class attribute holding a list of the names of the LoRA-loadable components.
> There are quite a few methods in there that are making assumptions about the inheriting class using the method, which isn't really how a base class should behave. So loading methods related to specific model components are better left out e.g. load_lora_into_text_encoder. If this method is used across different pipelines with no changes, then it's better to create a utility function that does this and call it from the inheriting class. Or redefine the method in the inheriting class and use copied from.
This is better. I will go with that.
> Quite a few of these methods probably cannot be defined in the base class, such as load_lora_weights and unload_lora_weights, fuse_lora and unfuse_lora, since they deal with specific pipeline components. They might also require arguments specific to the pipeline type or pipeline components.
I don't think that is the case, and that should be evident from how the child classes are implemented. They differ mostly in their loading logic, which is why each child class implements its own. Fusion and unloading can be independent of that.
> I don't have a strong feeling about either approach. e.g. StableDiffusionLoraLoaderMixin could look like:
I can refactor `fuse_lora()` to take the `components` argument you're suggesting. I like the idea. So for
> I think if we want to fuse the LoRA in a specific component, it's better to pass in the actual name of the component used in pipeline, rather than track another attribute such as denoiser
we need to deprecate `fuse_unet` and `fuse_transformer` nicely, which is already being done.
> I also think the constants and class attributes such as TEXT_ENCODER_NAME and is_unet_denoiser might not be needed if we use a single class attribute with a list of the names of the lora loadable components.
Right. But I'm not sure about `TEXT_ENCODER_NAME`, since we use those constants for serializing the weights. We could just maintain a mapping on top of the modules to do that?
So:
```python
{
    "unet": UNET_NAME,
    "transformer": TRANSFORMER_NAME,
    ...
}
```
This mapping would then be easy to use across pipelines.
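Applied to serialization, that idea could look like the following sketch. The constant values and the `pack_weights` helper are assumptions for illustration, not the actual diffusers code:

```python
# Hypothetical serialization prefixes; the real constants live in diffusers.
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
TEXT_ENCODER_NAME = "text_encoder"

# One mapping from component attribute to serialization prefix,
# instead of per-class constants on every loader mixin.
LORA_SERIALIZATION_PREFIXES = {
    "unet": UNET_NAME,
    "transformer": TRANSFORMER_NAME,
    "text_encoder": TEXT_ENCODER_NAME,
}


def pack_weights(layers, component):
    # Prefix every layer key with the component's serialization name.
    prefix = LORA_SERIALIZATION_PREFIXES[component]
    return {f"{prefix}.{key}": value for key, value in layers.items()}
```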
Edit: Decided not to go with the above because the code was getting complex for a seemingly simple task. Consider this line of code: https://github.com/huggingface/diffusers/blob/06ee4db3e7a5342871404ae445cf71665bc6a580/src/diffusers/loaders/lora.py#L398
Without a `cls.text_encoder_name`, it would be hard to determine the `only_text_encoder` variable. Similarly here: https://github.com/huggingface/diffusers/blob/06ee4db3e7a5342871404ae445cf71665bc6a580/src/diffusers/loaders/lora.py#L673
@DN6 apart from
> I also think the constants and class attributes such as TEXT_ENCODER_NAME and is_unet_denoiser might not be needed if we use a single class attribute with a list of the names of the lora loadable components.
I was able to remove `is_unet_denoiser`, etc. successfully but couldn't do that for `TEXT_ENCODER_NAME`, `UNET_NAME`, and `TRANSFORMER_NAME` for the reasons explained in the last part of https://github.com/huggingface/diffusers/pull/8774#issuecomment-2205488122
God, this timeout issue is getting out of hand >_<
@DN6 as discussed over Slack, I have unified the `PeftAdapterMixin` class too, so that we can have methods like `fuse_lora()`, `delete_lora()`, `enable_lora()`, etc. under one umbrella without having to define and copy-paste them for each model-specific loader mixin, such as `UNet2DConditionLoadersMixin`.
One thing to note: I still had to keep loaders/transformer_sd3.py to implement `set_adapters()`, as this method varies from UNet to transformer because the block naming differs between these models. This is also why you will see `set_adapters()` in `UNet2DConditionLoadersMixin`.
We could have two additional classes under loaders/peft.py:

- `TransformerPeftAdapterMixin(PeftAdapterMixin)`
- `UNet2DConditionPeftAdapterMixin(PeftAdapterMixin)`

to reimplement this method there and use them accordingly.
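A minimal sketch of that split, assuming the only model-specific part is the block naming (the block names and the `active_adapters` bookkeeping here are purely illustrative, not the real `set_adapters()` logic):

```python
class PeftAdapterMixin:
    # Shared adapter management would live here; only set_adapters()
    # depends on how a given model names its blocks.
    def set_adapters(self, adapter_names):
        raise NotImplementedError("Subclasses know their block naming.")


class TransformerPeftAdapterMixin(PeftAdapterMixin):
    def set_adapters(self, adapter_names):
        # Transformer models expose e.g. `transformer_blocks` (illustrative).
        self.active_adapters = {"transformer_blocks": list(adapter_names)}


class UNet2DConditionPeftAdapterMixin(PeftAdapterMixin):
    def set_adapters(self, adapter_names):
        # UNets expose e.g. `down_blocks`/`mid_block`/`up_blocks` (illustrative).
        self.active_adapters = {"down_blocks": list(adapter_names)}
```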
LMK.
@DN6 I think this is ready for another review now.
@DN6 anything else you would like me to address?
Thanks for the massive help and guidance, Dhruv!