MultiControlNetModel: save_pretrained directory names do not match what from_pretrained expects
Describe the bug
Hi.
There is an inconsistency between from_pretrained and save_pretrained in the MultiControlNetModel class.
from_pretrained expects the ControlNets to be stored in directories named controlnet, controlnet_1, controlnet_2, ...,
whereas save_pretrained appends each new suffix to the previous path, producing controlnet, controlnet_1, controlnet_1_2, ...
As a result, when a model with three or more ControlNets is saved and then loaded again, from_pretrained stops at the first missing directory (controlnet_2), so only the first two ControlNets are loaded and the rest are silently dropped.
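The mismatch can be shown without diffusers by simulating the two naming loops. The sketch below is a plain-Python model of the 0.27.2 behavior described above, not the library code itself; `saved_dirs_current` and `count_loaded` are hypothetical helper names.

```python
from typing import List, Set


def saved_dirs_current(base: str, n: int) -> List[str]:
    """Directories written by the current save_pretrained: each suffix is appended to the previous path."""
    dirs, path = [], base
    for idx in range(1, n + 1):
        dirs.append(path)
        path = path + f"_{idx}"
    return dirs


def count_loaded(base: str, existing: Set[str]) -> int:
    """Mimics from_pretrained: try base, base_1, base_2, ... and stop at the first missing name."""
    count, path = 0, base
    while path in existing:
        count += 1
        path = f"{base}_{count}"
    return count


saved = saved_dirs_current("controlnet", 3)
print(saved)  # ['controlnet', 'controlnet_1', 'controlnet_1_2']
print(count_loaded("controlnet", set(saved)))  # 2
```

With four ControlNets the result is still 2, because controlnet_2 never exists under the current naming.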
Reproduction
I don't think a reproduction is needed, since the issue is clear from the code.
To keep the two methods compatible, how about changing save_pretrained in MultiControlNetModel to the code below?
import os
from typing import Callable, Optional, Union


def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    save_function: Callable = None,
    safe_serialization: bool = True,
    variant: Optional[str] = None,
):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    [`~pipelines.controlnet.MultiControlNetModel.from_pretrained`] class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful during distributed training (for
            example on TPUs) when this function must be called on all processes. In that case, set
            `is_main_process=True` only on the main process to avoid race conditions.
        save_function (`Callable`):
            The function to use to save the state dictionary. Useful during distributed training (for example on
            TPUs) when `torch.save` needs to be replaced by another method. Can be configured with the environment
            variable `DIFFUSERS_SAVE_MODE`.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
        variant (`str`, *optional*):
            If specified, weights are saved in the format pytorch_model.<variant>.bin.
    """
    # os.fspath() lets save_directory be a pathlib.Path as well as a str
    # (a plain `Path + str` would raise a TypeError).
    model_path_to_save = os.fspath(save_directory)
    for idx, controlnet in enumerate(self.nets):
        # The first ControlNet goes to save_directory, the others to save_directory_1, save_directory_2, ...,
        # which is the layout from_pretrained iterates over.
        suffix = "" if idx == 0 else f"_{idx}"
        controlnet.save_pretrained(
            model_path_to_save + suffix,
            is_main_process=is_main_process,
            save_function=save_function,
            safe_serialization=safe_serialization,
            variant=variant,
        )
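A round trip through a real temporary directory is one way to check that the proposed naming lines up with the load loop. This is a sketch under stated assumptions: `FakeNet` is a hypothetical stand-in for ControlNetModel whose "weights" are a single text file, and `save_all` / `load_all` are hypothetical functions that copy only the directory-naming logic of the proposed save_pretrained and of the existing from_pretrained loop.

```python
import os
import tempfile
from pathlib import Path
from typing import List, Union


class FakeNet:
    """Hypothetical stand-in for ControlNetModel: saving just writes its name into the directory."""

    def __init__(self, name: str):
        self.name = name

    def save_pretrained(self, path: Union[str, os.PathLike]) -> None:
        os.makedirs(path, exist_ok=True)
        Path(path, "name.txt").write_text(self.name)

    @classmethod
    def from_pretrained(cls, path: Union[str, os.PathLike]) -> "FakeNet":
        return cls(Path(path, "name.txt").read_text())


def save_all(nets: List[FakeNet], save_directory: Union[str, os.PathLike]) -> None:
    # Proposed naming: base, base_1, base_2, ...
    base = os.fspath(save_directory)
    for idx, net in enumerate(nets):
        suffix = "" if idx == 0 else f"_{idx}"
        net.save_pretrained(base + suffix)


def load_all(pretrained_path: Union[str, os.PathLike]) -> List[FakeNet]:
    # Same stopping rule as from_pretrained: load until the next directory is missing.
    base = os.fspath(pretrained_path)
    nets, idx, path = [], 0, base
    while os.path.isdir(path):
        nets.append(FakeNet.from_pretrained(path))
        idx += 1
        path = f"{base}_{idx}"
    return nets


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "controlnet"
    save_all([FakeNet("a"), FakeNet("b"), FakeNet("c")], target)
    loaded_names = [net.name for net in load_all(target)]

print(loaded_names)  # ['a', 'b', 'c']
```

Passing a `pathlib.Path` as the target also exercises the `os.fspath` conversion, so all three nets are recovered in order.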
System Info
Diffusers 0.27.2
Who can help?
@sayakpaul
Thanks for reporting.
Would you be willing to take a stab at opening a PR for this, since you already have a solution?