
feat(diffusers): add 'safety_check' pipeline argument

rickstaa opened this pull request 1 year ago

What does this PR do?

This pull request introduces a safety_check argument to the __call__ method of StableDiffusionPipeline. The argument lets users enable or disable the safety checker dynamically, per pipeline invocation. The primary motivation is to give users the option to filter NSFW content when generating images, depending on their specific needs.
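To make the intended behaviour concrete, here is a minimal, self-contained sketch of how such a flag could gate the NSFW filter in a pipeline's postprocessing step. The names `run_safety_checker` and `postprocess` are illustrative stand-ins, not the actual diffusers internals:

```python
def run_safety_checker(images):
    # Stand-in for StableDiffusionSafetyChecker: flag every image here
    # so the gating behaviour is easy to see in the example output.
    return images, [True for _ in images]

def postprocess(images, safety_check=True):
    """Return (images, nsfw_flags); skip the checker when safety_check=False."""
    if safety_check:
        images, has_nsfw_concepts = run_safety_checker(images)
    else:
        # Checker disabled: nothing is filtered, all flags stay False.
        has_nsfw_concepts = [False] * len(images)
    return images, has_nsfw_concepts

imgs = ["img0", "img1"]
print(postprocess(imgs, safety_check=False)[1])  # -> [False, False]
print(postprocess(imgs, safety_check=True)[1])   # -> [True, True]
```

The real change would thread the flag through `StableDiffusionPipeline.__call__` so the existing safety-checker call is simply skipped when the flag is off.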

Alternatives Considered

I'm currently using a code snippet provided on the forum to dynamically toggle the safety check, but integrating this functionality directly into the diffusion pipeline would streamline my code and could also benefit others.

from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
import numpy as np
import torch
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
).to(device)
feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "openai/clip-vit-base-patch32"
)

def check_nsfw_images(
    images: list[Image.Image],
    output_type: str | None = "pil",
) -> tuple[list[Image.Image], list[bool]]:
    # Preprocess the images for CLIP and keep NumPy copies for the checker.
    safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
    images_np = [np.array(img) for img in images]

    _, has_nsfw_concepts = safety_checker(
        images=images_np,
        clip_input=safety_checker_input.pixel_values.to(device),
    )
    if output_type == "pil":
        return images, has_nsfw_concepts
    return images_np, has_nsfw_concepts

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline?
  • [x] Did you read our philosophy doc (important for complex PRs)?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. - @rickstaa no.
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings. - @rickstaa I went through all doc pages that mentioned StableDiffusionPipeline but did not see any pages I should edit.
  • [ ] Did you write any new necessary tests? - @rickstaa I checked https://github.com/rickstaa/diffusers/blob/801484840a3fa71285a7e096cc07f93b1ae681b7/tests/pipelines/stable_diffusion/test_stable_diffusion.py but did not see any argument-specific tests. Some guidance on whether we should add a test for the new safety_check argument would be much appreciated.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

rickstaa avatar May 05 '24 15:05 rickstaa