
[feature] Proposal for supporting callbacks

Open a-r-r-o-w opened this issue 1 year ago • 2 comments

What does this PR do?

Fixes #7736.

This is a work-in-progress attempt at a proposal for officially supporting callbacks in diffusers. It is inspired by the design of transformers library callbacks.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@DN6 @sayakpaul

a-r-r-o-w avatar Apr 22 '24 13:04 a-r-r-o-w

Started a rough-cut implementation for this last night to see whether callbacks in the style of transformers could be supported. I currently don't like the design very much, because the way the callbacks are invoked would require modifying the `__call__` of every pipeline, but I do think the emit/on_event design is the right direction. There is no support yet for modifying the inference-local variables based on results from callbacks, which I'll try to add soon.

```python
from typing import Any, Dict

import torch
from diffusers import StableDiffusionPipeline
from diffusers.callbacks import BaseCallback, CallbackInput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline


class GuidanceCallback(BaseCallback):
    def __init__(self, guidance_scale: float = 7.5, change_after: float = 0.6) -> None:
        self.guidance_scale = guidance_scale
        self.change_after = change_after

    def on_step_end(self, pipe: StableDiffusionPipeline, args: CallbackInput, control: Any, **kwargs: Dict[str, Any]) -> Any:
        current_timestep = args.t.item()
        print(f"Current timestep: {current_timestep}")
        if current_timestep < int(self.change_after * pipe.scheduler.config.num_train_timesteps):
            pipe._guidance_scale = 2.5


class HelloWorldCallback(BaseCallback):
    def on_inference_begin(self, pipe: DiffusionPipeline, args: CallbackInput, control: Any, **kwargs: Dict[str, Any]) -> Any:
        print("Hello, world!")

    def on_inference_end(self, pipe: DiffusionPipeline, args: CallbackInput, control: Any, **kwargs: Dict[str, Any]) -> Any:
        print("Goodbye, world!")


pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe",
    torch_dtype=torch.float16,
    safety_checker=None,
).to("cuda")
pipe.add_callback(GuidanceCallback())
pipe.add_callback(HelloWorldCallback())

image = pipe(
    prompt="A photo of a cat",
    height=32,
    width=32,
    num_inference_steps=20,
    output_type="pil",
).images[0]
```
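To make the emit/on_event idea concrete, here is a minimal, self-contained sketch of how the dispatch side could work. This is not the actual diffusers implementation; `CallbackHandler`, `CallbackInput`, `emit`, and the event names are hypothetical stand-ins for what a pipeline's `__call__` might invoke internally.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class CallbackInput:
    # Hypothetical container for inference-local state handed to callbacks.
    step: Optional[int] = None
    t: Optional[Any] = None
    extras: Dict[str, Any] = field(default_factory=dict)


class BaseCallback:
    # No-op hooks; subclasses override only the events they care about.
    def on_inference_begin(self, pipe, args: CallbackInput, control: Any, **kwargs) -> Any: ...
    def on_step_end(self, pipe, args: CallbackInput, control: Any, **kwargs) -> Any: ...
    def on_inference_end(self, pipe, args: CallbackInput, control: Any, **kwargs) -> Any: ...


class CallbackHandler:
    """Dispatches named events to registered callbacks (hypothetical)."""

    def __init__(self) -> None:
        self.callbacks: List[BaseCallback] = []

    def add_callback(self, callback: BaseCallback) -> None:
        self.callbacks.append(callback)

    def emit(self, event: str, pipe, args: CallbackInput, control: Any = None, **kwargs) -> Any:
        # A pipeline would call e.g. handler.emit("on_step_end", ...) inside its
        # denoising loop; a non-None return value updates the shared control state.
        for cb in self.callbacks:
            result = getattr(cb, event)(pipe, args, control, **kwargs)
            if result is not None:
                control = result
        return control


# Usage: a callback that records which steps it observed.
class RecordingCallback(BaseCallback):
    def __init__(self) -> None:
        self.seen: List[int] = []

    def on_step_end(self, pipe, args: CallbackInput, control: Any, **kwargs) -> Any:
        self.seen.append(args.step)


handler = CallbackHandler()
recorder = RecordingCallback()
handler.add_callback(recorder)
for step in range(3):
    handler.emit("on_step_end", pipe=None, args=CallbackInput(step=step))
print(recorder.seen)  # [0, 1, 2]
```

The `control` return channel is one possible answer to the "modify inference-local variables from callbacks" point above: callbacks could return an updated control object and the pipeline would read it back after each `emit`.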

Would love to hear feedback/expectations on how this could be improved. I'd also like to know what kinds of callbacks Diffusers would want to support as an offering (guidance scale, differential diffusion, etc.) @yiyixuxu @DN6 @asomoza @sayakpaul.

a-r-r-o-w avatar Apr 23 '24 04:04 a-r-r-o-w

Thanks for the proposal! However, I don't think we have use cases for such a complex callback system right now.

> I'd also like to know what kinds of callbacks Diffusers would want to support as an offering (guidance scale, differential diffusion, etc.)

For this, @asomoza is putting together a couple of initial official callbacks we offer, and we can iterate from there. I imagine it will be a community-driven effort moving forward.

yiyixuxu avatar Apr 23 '24 07:04 yiyixuxu