Image dimension checking for ControlNet FLUX
What does this PR do?
Issue
This addresses an issue discussed in two PRs; see https://github.com/huggingface/diffusers/pull/9406#issuecomment-2359371560 and https://github.com/huggingface/diffusers/pull/9507#issuecomment-2375884213
The FLUX ControlNet pipeline currently lacks any checks on the shape or number of control images passed (for np.ndarray/torch.Tensor and PIL objects, respectively).
I will give a simple example. If you were to run the following code:
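```python
import torch
from diffusers import FluxControlNetPipeline

# base_model, vae and multi_controlnet are assumed to be defined already
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, vae=vae, controlnet=multi_controlnet, torch_dtype=torch.bfloat16
).to("cuda")

# image_t is a torch tensor of shape (2, 3, h, w): two control images but only one prompt
pipe(
    prompt=["test"],
    control_image=image_t,
    control_mode=0,
    num_images_per_prompt=1,
    num_inference_steps=2,
)
```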
you'd get the following error:
```
Traceback (most recent call last):
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers-tests/controlnet_pipeline_cleaner_api/flux.py", line 67, in test_torch_batched_ctrl_wrong_1ipp
    self.pipe(
  File "/home/beckhamc/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 742, in __call__
    control_image = self._pack_latents(
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 458, in _pack_latents
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
RuntimeError: shape '[1, 16, 32, 2, 32, 2]' is invalid for input of size 131072
```
This happens because the number of control images must match the number of prompts: here we passed a control image with batch size 2 but only 1 prompt. Because this mismatch is not caught, it surfaces as a downstream error in the packing of the latents.
It turns out that the SDXL ControlNet pipeline does check that the number of control images is consistent with the number of prompts (a control image batch size of 1 is also allowed, which is fine). I essentially ported over the check_image method from StableDiffusionControlNetPipeline and modified check_inputs so that it also checks the control image. Now if you run the above code you get the following error instead, which makes the problem much clearer:
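As a quick sanity check on that traceback, the sizes differ by exactly a factor of two: the view expects one sample, while the tensor holds two. The (2, 16, 64, 64) latent shape below is inferred from the view arguments (16 channels, height // 2 = width // 2 = 32); it is not printed by the traceback itself.

```python
from math import prod

# Shape that _pack_latents tries to view the control latents as (batch_size = 1).
target_shape = (1, 16, 32, 2, 32, 2)
# "input of size 131072" from the RuntimeError.
actual_numel = 131072

expected_numel = prod(target_shape)
print(expected_numel)                  # 65536
print(actual_numel // expected_numel)  # 2 -> two samples where one was expected

# Consistent with control latents of shape (2, 16, 64, 64), i.e. batch size 2:
print(prod((2, 16, 64, 64)))           # 131072
```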
```
Traceback (most recent call last):
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers-tests/controlnet_pipeline_cleaner_api/flux.py", line 67, in test_torch_batched_ctrl_wrong_1ipp
    self.pipe(
  File "/home/beckhamc/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 742, in __call__
    self.check_inputs(
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 475, in check_inputs
    self.check_image(image, prompt, prompt_embeds)
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 427, in check_image
    raise ValueError(
ValueError: If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: 2, prompt batch size: 1
```
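The batch-size rule behind that error fits in a few lines. Below is a minimal, dependency-free sketch of it, modelled on the condition in the error message; it is not the actual diffusers code. The function names are hypothetical, and any object with a `shape` attribute stands in for np.ndarray/torch.Tensor so the sketch runs without numpy or torch.

```python
from typing import Any, List, Union


def infer_image_batch_size(image: Any) -> int:
    """Batch size of a control image input (hypothetical helper)."""
    if isinstance(image, (list, tuple)):
        # A list of PIL images / arrays: one entry per image.
        return len(image)
    shape = getattr(image, "shape", None)
    if shape is not None:
        # np.ndarray / torch.Tensor: the leading dimension is the batch.
        return shape[0]
    # Anything else (e.g. a single PIL.Image.Image) counts as one image.
    return 1


def infer_prompt_batch_size(prompt: Union[str, List[str], None], prompt_embeds: Any = None) -> int:
    """Batch size implied by the prompt, falling back to prompt_embeds."""
    if isinstance(prompt, str):
        return 1
    if isinstance(prompt, list):
        return len(prompt)
    if prompt_embeds is not None:
        return prompt_embeds.shape[0]
    raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")


def check_image_prompt_batch(image: Any, prompt: Union[str, List[str], None], prompt_embeds: Any = None) -> None:
    """Raise early if the control image batch cannot be matched to the prompts."""
    image_batch_size = infer_image_batch_size(image)
    prompt_batch_size = infer_prompt_batch_size(prompt, prompt_embeds)
    # A batch of 1 is allowed (the single control image is reused for all prompts);
    # otherwise the two batch sizes must agree.
    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            "If image batch size is not 1, image batch size must be same as prompt batch size. "
            f"image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
```

With this rule, a (2, 3, h, w) control tensor and `prompt=["test"]` raise the ValueError shown above, while a single control image with several prompts passes.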
This fix should also work for MultiControlNet, which means you can do something like this:
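```python
import torch
from diffusers import FluxControlNetPipeline, FluxMultiControlNetModel

# controlnet, base_model and vae are assumed to be defined already
multi_controlnet = FluxMultiControlNetModel([controlnet] * 2)
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, vae=vae, controlnet=multi_controlnet, torch_dtype=torch.bfloat16
).to("cuda")

# images1 and images2: torch tensors with batch size 3, one control image per prompt
images = pipe(
    prompt=["1", "2", "3"],
    control_image=[images1, images2],
    controlnet_conditioning_scale=[0.6, 0.6],
    control_mode=0,
    num_images_per_prompt=2,
).images
```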
i.e. images1 and images2 are both torch.Tensor with a batch size of 3, and their corresponding ControlNet states (which effectively have double the batch size due to num_images_per_prompt=2) are summed together.
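To make that batch arithmetic concrete, here is a toy, list-based sketch (not diffusers code). It assumes, as in the SD ControlNet pipelines, that each control batch is expanded with repeat_interleave-style repetition, and that each ControlNet's residuals are scaled by its controlnet_conditioning_scale before the multi-ControlNet sums them. The helper names are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def expand_control_batch(control_batch: Sequence[T], num_images_per_prompt: int) -> List[T]:
    """Repeat each control image num_images_per_prompt times, keeping copies adjacent."""
    return [img for img in control_batch for _ in range(num_images_per_prompt)]


def sum_controlnet_residuals(residuals: Sequence[Sequence[float]], scales: Sequence[float]) -> List[float]:
    """Scale each ControlNet's per-sample residual and sum elementwise across ControlNets."""
    if len(residuals) != len(scales):
        raise ValueError("need one conditioning scale per ControlNet")
    batch = len(residuals[0])
    if any(len(r) != batch for r in residuals):
        raise ValueError("all ControlNets must produce the same batch size")
    return [sum(s * r[i] for s, r in zip(scales, residuals)) for i in range(batch)]


prompts = ["1", "2", "3"]
ctrl1 = expand_control_batch(["a", "b", "c"], num_images_per_prompt=2)
ctrl2 = expand_control_batch(["x", "y", "z"], num_images_per_prompt=2)
print(len(prompts) * 2, len(ctrl1), len(ctrl2))  # 6 6 6
```

Both control batches end up with 3 × 2 = 6 entries, matching the expanded prompt batch, which is what makes the elementwise sum well defined.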
I have some tests you can copy and paste from here: https://github.com/christopher-beckham/diffusers-tests/blob/4b548f8/controlnet_pipeline_cleaner_api/flux.py
(you can run with python -m unittest flux.py)
Other concerns
I do have some questions, however. Why is the image preprocessing skipped when the image is a torch.Tensor? See:
https://github.com/huggingface/diffusers/blob/9cd37557d581dd30fb6031ae30bd583443c3effd/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py#L526-L529
This also seems inconsistent with what is done in the SDXL ControlNet code:
https://github.com/huggingface/diffusers/blob/7071b7461b224bdc82b9dd2bde2c1842320ccc66/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L857
It may also lead to unexpected behaviour, because preprocess explicitly uses width and height when preprocessing the image (if they are None, a reasonable default is used instead, depending on the model). This logic is skipped entirely when a torch.Tensor is passed.
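A toy model of the branch in question may help. It only tracks the (height, width) a control image ends up with: DEFAULT_SIZE mirrors the 1024 default the pipeline falls back to when height/width are None, and always_preprocess=True models the SDXL-style behaviour of calling preprocess on every input. The names are hypothetical, and real preprocessing also does things this ignores (e.g. rounding sizes to a multiple of the VAE scale factor).

```python
from typing import Optional, Tuple

# Hypothetical stand-in for the pipeline's default height/width when None is passed.
DEFAULT_SIZE = 1024


def control_image_size(
    input_hw: Tuple[int, int],
    is_tensor: bool,
    height: Optional[int] = None,
    width: Optional[int] = None,
    always_preprocess: bool = False,
) -> Tuple[int, int]:
    """Toy model: the spatial size a control image has after image preparation."""
    height = height or DEFAULT_SIZE
    width = width or DEFAULT_SIZE
    if is_tensor and not always_preprocess:
        # Current FLUX branch: tensors bypass preprocess and keep their own size.
        return input_hw
    # PIL / np inputs (and every input in the SDXL pipeline) go through preprocess,
    # which resizes to height x width.
    return (height, width)


print(control_image_size((512, 512), is_tensor=True))                          # (512, 512)
print(control_image_size((512, 512), is_tensor=False))                         # (1024, 1024)
print(control_image_size((512, 512), is_tensor=True, always_preprocess=True))  # (1024, 1024)
```

The first two calls show the inconsistency: the same 512×512 content ends up at different sizes depending only on its type.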
Thanks.
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings
- [x] Did you write any new necessary tests? (Yes, but in my own standalone repo, which I linked to above.)
Who can review?
@yiyixuxu @wangqixun
@yiyixuxu Thanks, that is good to know. I pushed a change so that check_image now only checks that the prompt and image batch sizes are consistent. It could perhaps use a more descriptive name, maybe check_image_and_prompt. Let me know if it looks good. Thanks.
bump @yiyixuxu thanks!
Thanks for the comments above @yiyixuxu
Just one last thing, there is this to take care of:
https://github.com/huggingface/diffusers/blob/9cd37557d581dd30fb6031ae30bd583443c3effd/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py#L526-L529
As I said previously, I'm not sure why all of the preprocessing is skipped for torch.Tensor inputs (maybe it's an oversight by the original author), but this is not what the corresponding SDXL ControlNet pipeline does: it runs self.image_processor.preprocess no matter what.
Fixing this, however, would affect existing code that already passes torch.Tensor control images to this pipeline. Even if the user sets width=None and height=None in pipeline.__call__, those values are internally redefined to 1024:
https://github.com/huggingface/diffusers/blob/9cd37557d581dd30fb6031ae30bd583443c3effd/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py#L674-L675
I made the change in the latest commit, but it may be worth discussing further. If we go with my commit, it may also be worth emitting a warning when a torch.Tensor is passed.
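If the tensor path is changed to always go through preprocess, a warning along these lines could flag the otherwise silent resize. This is a hypothetical helper, not part of diffusers; it only compares the tensor's spatial size with the target height/width and uses the standard-library warnings module.

```python
import warnings
from typing import Tuple


def warn_if_tensor_resized(tensor_hw: Tuple[int, int], height: int, width: int) -> bool:
    """Warn when a tensor control image will be resized by preprocessing.

    Returns True if a warning was emitted (hypothetical helper).
    """
    if tuple(tensor_hw) != (height, width):
        warnings.warn(
            f"control_image is a torch.Tensor of size {tensor_hw[0]}x{tensor_hw[1]}; "
            f"it will be resized to {height}x{width} during preprocessing. "
            "Pass height and width explicitly to keep its original resolution.",
            UserWarning,
        )
        return True
    return False
```

In the pipeline, a check like this would run after height/width have been resolved (i.e. after the 1024 default is applied), so users relying on the old pass-through behaviour would see why their output size changed.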
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
re-bump @yiyixuxu