
Multi-controlnet formatting issue

Open rebel-kblee opened this issue 1 year ago • 1 comments

Describe the bug

Hi. There is an inconsistency between from_pretrained and save_pretrained in the MultiControlNetModel class. from_pretrained expects a directory layout like this: controlnet, controlnet_1, controlnet_2, whereas save_pretrained writes this: controlnet, controlnet_1, controlnet_1_2. As a result, when loading a saved model that has 3 controlnets, the last controlnet is not loaded. The same problem occurs whenever there are more than 2 controlnets.
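To make the mismatch concrete, here is a minimal sketch in plain Python that only mimics the naming logic and does not call diffusers. It assumes the saving side appends each index to the previous path while the loading side appends it to the base path, which matches the directory names listed above. The helper names `saved_dirs` and `expected_dirs` are hypothetical and exist only for illustration.

```python
def saved_dirs(base: str, n: int) -> list[str]:
    """Directory names when each suffix is appended to the previous path (assumed save behavior)."""
    paths = []
    path = base
    for idx in range(n):
        paths.append(path)
        # The suffix accumulates because it is added to the already-suffixed path.
        path = path + f"_{idx + 1}"
    return paths


def expected_dirs(base: str, n: int) -> list[str]:
    """Directory names when each suffix is appended to the base path (assumed load behavior)."""
    return [base if idx == 0 else f"{base}_{idx}" for idx in range(n)]


print(saved_dirs("controlnet", 3))     # ['controlnet', 'controlnet_1', 'controlnet_1_2']
print(expected_dirs("controlnet", 3))  # ['controlnet', 'controlnet_1', 'controlnet_2']
```

With 1 or 2 controlnets the two lists agree, so the difference only shows up from the third controlnet onward.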

Reproduction

I don't think a reproduction script is needed, as the issue is pretty clear. For compatibility, how about changing the save_pretrained function in MultiControlNetModel to look like the code below?

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Callable = None,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful when in distributed training like
                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
                the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                need to replace `torch.save` by another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            variant (`str`, *optional*):
                If specified, weights are saved in the format pytorch_model.<variant>.bin.
        """
        # os.fspath lets the string concatenation below work for os.PathLike inputs too
        model_path_to_save = os.fspath(save_directory)
        for idx, controlnet in enumerate(self.nets):
            suffix = "" if idx == 0 else f"_{idx}"
            controlnet.save_pretrained(
                model_path_to_save + suffix,
                is_main_process=is_main_process,
                save_function=save_function,
                safe_serialization=safe_serialization,
                variant=variant,
            )
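As a rough check of the proposed suffix rule, the sketch below creates empty directories instead of saving real models and then probes them the way the loader is assumed to: base, base_1, base_2, and so on until a directory is missing. `save_all` and `count_loadable` are hypothetical stand-ins, not diffusers functions.

```python
import os
import tempfile


def save_all(base: str, n: int) -> None:
    """Create one directory per net using the proposed suffix rule."""
    for idx in range(n):
        suffix = "" if idx == 0 else f"_{idx}"
        os.makedirs(base + suffix, exist_ok=True)


def count_loadable(base: str) -> int:
    """Count directories found by probing base, base_1, base_2, ... until one is missing."""
    idx = 0
    path = base
    while os.path.isdir(path):
        idx += 1
        path = base + f"_{idx}"
    return idx


with tempfile.TemporaryDirectory() as tmp:
    base = os.path.join(tmp, "controlnet")
    save_all(base, 3)
    found = count_loadable(base)
    print(found)  # 3
```

Under these assumptions all three directories are found, so the third controlnet would no longer be skipped.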

Logs

No response

System Info

Diffusers 0.27.2

Who can help?

@sayakpaul

rebel-kblee, Apr 29 '24 05:04

Thanks for reporting.

Would you be willing to take a stab at opening a PR for this as you have a solution already?

sayakpaul, Apr 29 '24 06:04