Type hint for `callback_on_step_end` in pipeline `__call__` is incorrect
Describe the bug
As far as I can tell, the type hint for `callback_on_step_end` is incorrect in all pipelines (as seen in this search).
Taking the SDXL pipeline as an example, the type hint in `__call__` is:
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
but it is called like this:
if callback_on_step_end is not None:
    callback_kwargs = {}
    for k in callback_on_step_end_tensor_inputs:
        callback_kwargs[k] = locals()[k]
    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
    latents = callback_outputs.pop("latents", latents)
    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
    negative_pooled_prompt_embeds = callback_outputs.pop(
        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
    )
    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
That is, the type hint describes a function that takes three arguments and returns nothing, but the pipeline actually calls it with four arguments (the pipeline, the step index, the timestep, and the callback kwargs) and expects a Dict back. A callback written to match the type hint fails at runtime, while a correct four-argument callback produces type-checker errors.
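To make the mismatch concrete, here is a minimal sketch of a hint that matches the call site quoted above. The names `StepEndCallback`, `run_steps`, and `double_latents` are hypothetical and not part of diffusers, and the loop is a toy stand-in for the denoising loop, not the real implementation. The timestep is typed as `int` only to mirror the existing hint; in a real pipeline it may be a tensor.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical alias (not from diffusers):
# (pipeline, step index, timestep, callback kwargs) -> callback kwargs
StepEndCallback = Callable[[Any, int, int, Dict[str, Any]], Dict[str, Any]]


def run_steps(
    num_steps: int,
    callback_on_step_end: Optional[StepEndCallback] = None,
) -> Dict[str, Any]:
    """Toy loop that invokes the callback the same way the quoted call site does."""
    latents: Any = 0
    for i in range(num_steps):
        t = num_steps - i  # placeholder timestep
        latents = latents + 1  # placeholder for the real denoising update
        if callback_on_step_end is not None:
            callback_kwargs = {"latents": latents}
            # Four positional arguments; the returned dict is read back.
            callback_outputs = callback_on_step_end(None, i, t, callback_kwargs)
            latents = callback_outputs.pop("latents", latents)
    return {"latents": latents}


def double_latents(
    _pipe: Any, step: int, _timestep: int, kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Example callback: modifies a value and returns the kwargs dict."""
    kwargs["latents"] = kwargs["latents"] * 2
    return kwargs


result = run_steps(3, callback_on_step_end=double_latents)
```

With a hint of this shape, a four-argument callback that returns a dict should type-check, and a three-argument callback that returns `None` should be rejected before it can fail at runtime.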
Reproduction
Using a Python type checker (Pyright in my case), attempt to use a correctly defined callback with an SDXL pipeline:
import typing

import torch
from diffusers import AutoPipelineForText2Image, StableDiffusionXLPipeline

sdxl_pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
)
sdxl_pipe = typing.cast(StableDiffusionXLPipeline, sdxl_pipe)


def callback_on_step_end(_pipe, step, _timestep, kwargs):
    # Called after every denoising step; must return the kwargs dict.
    print(step)
    return kwargs


output = sdxl_pipe(
    prompt="a prompt",
    negative_prompt="a negative prompt",
    num_inference_steps=4,
    guidance_scale=0.0,
    return_dict=True,
    width=512,
    height=512,
    callback_on_step_end=callback_on_step_end,
)
The type-checker will error on callback_on_step_end:
Argument of type "(_pipe: Unknown, step: Unknown, _timestep: Unknown, kwargs: Unknown) -> Unknown" cannot be assigned to parameter "callback_on_step_end" of type "((int, int, Dict[Unknown, Unknown]) -> None) | None" in function "__call__"
  Type "(_pipe: Unknown, step: Unknown, _timestep: Unknown, kwargs: Unknown) -> Unknown" cannot be assigned to type "((int, int, Dict[Unknown, Unknown]) -> None) | None"
    Type "(_pipe: Unknown, step: Unknown, _timestep: Unknown, kwargs: Unknown) -> Unknown" cannot be assigned to type "(int, int, Dict[Unknown, Unknown]) -> None"
      Function accepts too few positional parameters; expected 4 but received 3
    "function" is incompatible with "None"
This can be worked around using `# type: ignore`, which is what I'm doing.
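As an alternative to a blanket ignore comment, a narrower workaround might be to cast the callback to the declared type at the call site. The sketch below is only an illustration: `DeclaredCallback` and `fake_pipeline_call` are hypothetical names, and `fake_pipeline_call` stands in for the pipeline's `__call__` so that the example runs without diffusers.

```python
import typing
from typing import Any, Callable, Dict, Optional

# The hint as currently declared on __call__ (copied from the issue text).
DeclaredCallback = Optional[Callable[[int, int, Dict], None]]


def fake_pipeline_call(callback_on_step_end: DeclaredCallback = None) -> Any:
    """Hypothetical stand-in for pipeline.__call__; returns what it was given."""
    return callback_on_step_end


def callback_on_step_end(
    _pipe: Any, step: int, _timestep: Any, kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """A correctly shaped four-argument callback that returns the kwargs."""
    return kwargs


# typing.cast does nothing at runtime; it only tells the type checker to
# treat the four-argument callback as the declared three-argument type.
accepted = fake_pipeline_call(typing.cast(DeclaredCallback, callback_on_step_end))
```

Like the ignore comment, the cast only hides the mismatch; the underlying fix is still to correct the hint on `__call__`.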
Thanks for creating this issue. Do you want to open a PR to fix the type hint?
~~Hi, sorry for the late response! Yes, I can take a look at it; should hopefully be a relatively straightforward fix.~~
Edit: I'm not currently working with diffusers, so I haven't been able to work on this fix. Free for anyone else to take it.