feat(diffusers): add 'safety_check' pipeline argument
What does this PR do?
This pull request adds a `safety_check` argument to the `__call__` method of `StableDiffusionPipeline`. It lets users enable or disable the safety check on each call, instead of deciding once when the pipeline is loaded. The main motivation is to let users choose, per generation, whether NSFW content should be filtered from the generated images.
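A minimal sketch of the intended behaviour, assuming the new keyword defaults to `True` so that existing callers keep today's filtering. `ToyPipeline`, `ToyOutput` and the stub checker are hypothetical stand-ins for illustration only, not the actual diffusers implementation in this PR.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, List, Optional

import numpy as np


@dataclass
class ToyOutput:
    images: List[np.ndarray]
    # None means no safety check was run for this call.
    nsfw_content_detected: Optional[List[bool]]


class ToyPipeline:
    """Hypothetical stand-in for StableDiffusionPipeline, reduced to the safety-check branch."""

    def __init__(self, safety_checker: Optional[Callable[[List[np.ndarray]], List[bool]]] = None):
        self.safety_checker = safety_checker

    def _generate(self, prompt: str, num_images: int) -> List[np.ndarray]:
        # Placeholder for the denoising loop: returns uniform grey 8x8 RGB images.
        return [np.full((8, 8, 3), 128, dtype=np.uint8) for _ in range(num_images)]

    def __call__(self, prompt: str, num_images: int = 1, safety_check: bool = True) -> ToyOutput:
        images = self._generate(prompt, num_images)
        if not safety_check or self.safety_checker is None:
            # Skip the checker entirely; the flags are unknown, not False.
            return ToyOutput(images=images, nsfw_content_detected=None)
        flags = self.safety_checker(images)
        # Replace flagged images with black images, as the diffusers safety checker does.
        images = [np.zeros_like(img) if flagged else img for img, flagged in zip(images, flags)]
        return ToyOutput(images=images, nsfw_content_detected=flags)
```

Returning `None` rather than a list of `False` values when the check is skipped keeps callers from mistaking "not checked" for "checked and clean", and matches what the pipeline already reports when it is loaded with `safety_checker=None`.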
Alternatives Considered
I currently toggle the safety check with the snippet below, shared on the forum, which runs `StableDiffusionSafetyChecker` separately on the generated images. Integrating this functionality directly into the pipeline would streamline my code and could also benefit others.
```python
from __future__ import annotations  # allows `list[...]` and `str | None` hints on Python < 3.10

import numpy as np
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from transformers import CLIPFeatureExtractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the standalone safety checker and a CLIP preprocessor for its inputs.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
).to(device)
feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "openai/clip-vit-base-patch32"
)


def check_nsfw_images(
    images: list[Image.Image],
    output_type: str | None = "pil",
) -> tuple[list[Image.Image] | list[np.ndarray], list[bool]]:
    """Run the safety checker on already generated images and return per-image NSFW flags."""
    # CLIP-preprocess the images; the resulting pixel_values tensor is moved to `device`.
    safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
    images_np = [np.array(img) for img in images]
    # The checker replaces flagged entries of `images_np` with black images;
    # only the NSFW flags are used from its return value here.
    _, has_nsfw_concepts = safety_checker(
        images=images_np,
        clip_input=safety_checker_input.pixel_values,
    )
    if output_type == "pil":
        return images, has_nsfw_concepts
    return images_np, has_nsfw_concepts
```
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. - @rickstaa no.
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings. - @rickstaa I went through all doc pages that mention `StableDiffusionPipeline` but did not find any that needed editing.
- [ ] Did you write any new necessary tests? - @rickstaa I checked https://github.com/rickstaa/diffusers/blob/801484840a3fa71285a7e096cc07f93b1ae681b7/tests/pipelines/stable_diffusion/test_stable_diffusion.py but did not see tests for pipeline call arguments. Some guidance on whether we should add a test for the new `safety_check` argument would be much appreciated.
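One possible shape for such a test, sketched against a hypothetical minimal stand-in (`run_pipeline`) rather than the real `StableDiffusionPipeline`. A real test would presumably reuse the dummy components that the existing fast tests in that file build, and swap in a `unittest.mock.Mock` as the safety checker so it can assert whether the checker was invoked.

```python
import unittest
from unittest import mock

import numpy as np


def run_pipeline(safety_checker, num_images=1, safety_check=True):
    """Hypothetical stand-in for StableDiffusionPipeline.__call__, reduced to the safety branch."""
    images = [np.full((8, 8, 3), 128, dtype=np.uint8) for _ in range(num_images)]
    if safety_check and safety_checker is not None:
        return images, safety_checker(images)
    return images, None  # checker skipped: no NSFW flags to report


class SafetyCheckArgumentTests(unittest.TestCase):
    def test_checker_runs_by_default(self):
        checker = mock.Mock(return_value=[False])
        _, flags = run_pipeline(checker)
        checker.assert_called_once()
        self.assertEqual(flags, [False])

    def test_checker_skipped_when_disabled(self):
        checker = mock.Mock(return_value=[False])
        _, flags = run_pipeline(checker, safety_check=False)
        checker.assert_not_called()
        self.assertIsNone(flags)
```

The file can be run with `python -m pytest` or `python -m unittest`; the important assertion is `assert_not_called()`, which checks that disabling the flag skips the checker rather than merely discarding its result.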
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.