[feature] Proposal for supporting callbacks
What does this PR do?
Fixes #7736.
This is a work-in-progress attempt at a proposal for officially supporting callbacks in diffusers. It is inspired by the design of transformers library callbacks.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@DN6 @sayakpaul
Started a rough-cut implementation for this last night to see whether callbacks in the style of transformers could be supported. I'm not fond of the current design because the way callbacks are invoked would require modifying the `__call__` of every pipeline, but I do think the emit/on_event design is the right direction. There is no support yet for modifying inference-local variables based on callback results, which I'll try to add soon.
```python
from typing import Any

import torch

from diffusers import StableDiffusionPipeline
from diffusers.callbacks import BaseCallback, CallbackInput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline


class GuidanceCallback(BaseCallback):
    def __init__(self, guidance_scale: float = 7.5, change_after: float = 0.6) -> None:
        self.guidance_scale = guidance_scale
        self.change_after = change_after

    def on_step_end(self, pipe: StableDiffusionPipeline, args: CallbackInput, control: Any, **kwargs: Any) -> Any:
        # Lower the guidance scale once denoising passes the `change_after` fraction.
        current_timestep = args.t.item()
        print(f"Current timestep: {current_timestep}")
        if current_timestep < int(self.change_after * pipe.scheduler.config.num_train_timesteps):
            pipe._guidance_scale = 2.5


class HelloWorldCallback(BaseCallback):
    def on_inference_begin(self, pipe: DiffusionPipeline, args: CallbackInput, control: Any, **kwargs: Any) -> Any:
        print("Hello, world!")

    def on_inference_end(self, pipe: DiffusionPipeline, args: CallbackInput, control: Any, **kwargs: Any) -> Any:
        print("Goodbye, world!")


pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe",
    torch_dtype=torch.float16,
    safety_checker=None,
).to("cuda")
pipe.add_callback(GuidanceCallback())
pipe.add_callback(HelloWorldCallback())

image = pipe(
    prompt="A photo of a cat",
    height=32,
    width=32,
    num_inference_steps=20,
    output_type="pil",
).images[0]
```
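For reference, here is a minimal sketch of how the emit/on_event dispatch could work on the pipeline side. `CallbackMixin`, `emit`, and the event method names here are assumptions for illustration based on this proposal, not existing diffusers APIs:

```python
from typing import Any, List


class BaseCallback:
    """Callbacks override only the events they care about."""

    def on_inference_begin(self, pipe: Any, args: Any, control: Any, **kwargs: Any) -> Any:
        pass

    def on_step_end(self, pipe: Any, args: Any, control: Any, **kwargs: Any) -> Any:
        pass

    def on_inference_end(self, pipe: Any, args: Any, control: Any, **kwargs: Any) -> Any:
        pass


class CallbackMixin:
    """Mixin a pipeline could inherit to register and dispatch callbacks."""

    def add_callback(self, callback: BaseCallback) -> None:
        if not hasattr(self, "_callbacks"):
            self._callbacks: List[BaseCallback] = []
        self._callbacks.append(callback)

    def emit(self, event: str, args: Any = None, control: Any = None, **kwargs: Any) -> Any:
        # Invoke `on_<event>` on every registered callback in order;
        # a non-None return value replaces `control` for later callbacks,
        # which is how callback results could feed back into inference.
        result = control
        for callback in getattr(self, "_callbacks", []):
            out = getattr(callback, event)(self, args, result, **kwargs)
            if out is not None:
                result = out
        return result
```

A pipeline's `__call__` would then call `self.emit("on_inference_begin", ...)` before the denoising loop, `self.emit("on_step_end", ...)` inside it, and so on, which is the part that currently requires touching every pipeline.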
Would love to hear feedback/expectations and how this could be improved. Would also like to know what kinds of callbacks Diffusers would want to support as an offering (guidance scale, differential diffusion, etc.) @yiyixuxu @DN6 @asomoza @sayakpaul.
Thanks for the proposal! However, I don't think we have use cases for such a complex callback system right now.
> Would also like to know what kinds of callbacks Diffusers would want to support as an offering (guidance scale, differential diffusion, etc.)

For this, @asomoza is putting together a couple of initial official callbacks we offer, and we can iterate from there. I imagine it will be a community-driven effort moving forward.