
Ability to change the strength of safety_checker

Open · suzukimain opened this pull request 1 year ago • 6 comments

What does this PR do?

Fixes #9003


"""
About safety_Level.
`int` or `float` or one of the following
'WEAK',
'MEDIUM',
'NOMAL',
'STRONG',
'MAX'.
"""

# To see the filter strength.
pipe.filter_level()  # 0.0 (default)

# --------------
# If you want to change the intensity.
pipe.safety_checker_level("STRONG")
pipe.filter_level()  # 1.0

# --------------
# If numbers are used.
pipe.safety_checker_level(3.0)
pipe.filter_level()  # 3.0

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline?
  • [ ] Did you read our philosophy doc (important for complex PRs)?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

suzukimain · Aug 03 '24 09:08

However, I could not solve the following two problems by myself.

~~1. Regarding the warning message when the safety_checker is weakened, the message shown when safety_checker=None is passed is reused as is. (I considered writing a modified message, but I was unsure about it, so I decided it was better not to change it and copied it unchanged.)~~ Resolved in #9404.


~~2. For the number that replaces a string such as 'NORMAL' when it is entered, I have set the following values, but I am not sure whether they are the best:~~ Fixed.

"WEAK": -1.0
"MEDIUM": -0.5
"NORMAL": 0.0
"STRONG": 0.5
"MAX": 1.0

Please let me know if there are any other problems. Thank you.
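
For illustration, a minimal sketch of how the string levels in the mapping above might be resolved to a numeric adjustment (LEVEL_MAP and resolve_level are hypothetical names, not the PR's actual implementation):

LEVEL_MAP = {"WEAK": -1.0, "MEDIUM": -0.5, "NORMAL": 0.0, "STRONG": 0.5, "MAX": 1.0}

def resolve_level(level):
    # Accept a raw number or one of the named levels above.
    if isinstance(level, (int, float)):
        return float(level)
    if isinstance(level, str) and level.upper() in LEVEL_MAP:
        return LEVEL_MAP[level.upper()]
    raise ValueError(f"level must be a number or one of {list(LEVEL_MAP)}")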

suzukimain · Aug 03 '24 09:08

Hello @yiyixuxu, is there any problem? I hope I have not caused any inconvenience. Thank you for your cooperation.

suzukimain · Aug 22 '24 01:08

Hi @suzukimain, what would be a use case for adjusting the strength of the safety_checker?

yiyixuxu · Oct 10 '24 18:10

Hi @yiyixuxu,

It is intended to be used in classes such as StableDiffusionPipelineSafe and StableDiffusionImg2ImgPipeline, which inherit from DiffusionPipeline and take safety_checker as an argument. Also, the module name may need to be changed to something more descriptive, although I couldn't come up with a good idea myself.

pip install git+https://github.com/suzukimain/diffusers@safety_checker

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# To see the filter strength.
print(f"Default filter strength: {pipe.filter_level()}")  # 0.0

pipe.safety_checker_level("STRONG")

print(f"Filter strength: {pipe.filter_level()}")  # 0.5

img = pipe("An image of a squirrel in Picasso style").images[0]
img

suzukimain · Oct 11 '24 03:10

Thanks! I'm trying to understand whether this would be a common/meaningful use case that people need. Could you explain a little bit?

yiyixuxu · Oct 22 '24 00:10

> Thanks! I'm trying to understand whether this would be a common/meaningful use case that people need. Could you explain a little bit?

Currently, the safety checker can only be enabled or disabled. This means that when the filtering strength feels too weak, there is no way to increase it, and conversely, when it feels too strong, the only option is to disable it entirely. This can be inconvenient, so it would be beneficial to give users the ability to adjust the filtering strength flexibly. Additionally, it might be helpful to refer to #5623.

suzukimain · Oct 22 '24 10:10

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Dec 12 '24 15:12

IMO having this option is not bad, but I really don't understand why we added the safety checker as part of diffusers in the first place. This is better handled outside of diffusers (in the app or library), so people can use whatever they want to filter the outputs; that is probably what all the services are doing right now.

asomoza · Dec 12 '24 18:12

Hi @hlky, fixed. Thank you.

suzukimain · Dec 13 '24 00:12

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Hello @hlky, some corrections have been made.

suzukimain · Dec 15 '24 02:12

I apologize for the inconvenience. I have corrected it again.

suzukimain · Dec 15 '24 03:12

@suzukimain could you run some tests to demonstrate the oversensitivity of the safety checker and find what level of adjustment works well? Choose some (safe) prompt(s), find seed(s) that trigger the safety checker, then adjust the filter strength until the image is allowed through.
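
A minimal sketch of the kind of test being requested, reusing the pipe from the snippet above and the safety_checker_level API proposed in this PR; the prompt and loop bound are illustrative:

import torch

prompt = "An image of a squirrel in Picasso style"
for seed in range(100):
    generator = torch.Generator().manual_seed(seed)
    out = pipe(prompt, generator=generator)
    if out.nsfw_content_detected and out.nsfw_content_detected[0]:
        print(f"Seed {seed} triggers the safety checker on a safe prompt.")
        # Weaken the filter, then re-generate with the same seed.
        pipe.safety_checker_level("WEAK")
        out = pipe(prompt, generator=torch.Generator().manual_seed(seed))
        print("Allowed through after adjustment:", not out.nsfw_content_detected[0])
        break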

hlky · Dec 17 '24 15:12

Hello @hlky, and apologies for the delayed response.

The adjustment of the filtering strength has been completed. Since I had difficulty selecting benchmark images and prompts, I generated 8,000 images with Stable Diffusion v1.5 and computed statistics on the concept scores of the roughly 1,600 images that triggered the Safety Checker.

As a result, I believe the following values are optimal. The criterion column states how each value was derived from the score statistics below:

| Level | Value | Criterion |
| --- | --- | --- |
| WEAK | -0.0690 | Maximum score |
| MEDIUM | -0.0175 | Average score among images that triggered the Safety Checker |
| NORMAL | 0.0 | Baseline, set to 0.0 |
| STRONG | 0.0150 | Average of all scores |
| MAX | 0.0740 | Absolute value of the minimum score |
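
Presumably the level value shifts the score comparison in the same way as the adjustment term in the _process_scores method below; a minimal sketch of that decision rule (is_flagged is an illustrative name, not part of the PR):

def is_flagged(concept_cos, concept_threshold, adjustment=0.0):
    # An image is flagged when any concept score exceeds zero, so a
    # positive adjustment (STRONG, MAX) filters more aggressively and a
    # negative one (WEAK, MEDIUM) filters less.
    return (concept_cos - concept_threshold + adjustment) > 0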

Additionally, the results have been saved in this file: img_score_8000.json

Threshold Calculation Method

Get concept score of generated images

import json
import os

import numpy as np
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
    cosine_distance,
)
from PIL import Image
from transformers import CLIPImageProcessor


class ScoreChecker:
    """
    Example:
        checker = ScoreChecker()
        score = checker.image_score('test.png')
    """

    def __init__(self):
        self.image_processor = VaeImageProcessor()
        self.feature_extractor = CLIPImageProcessor()
        # Load in float16 so the dtype matches the pixel values cast in
        # `_get_image_embeds` below.
        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
        ).to("cuda")

    @torch.no_grad()
    def image_score(self, image):
        converted_img = self._convert_image(image)
        feature_extractor_input = self._prepare_image(converted_img)
        safety_checker_input = self.feature_extractor(
            feature_extractor_input, return_tensors="pt"
        ).to("cuda")

        image_embeds = self._get_image_embeds(safety_checker_input)
        special_cos_dist, cos_dist = self._calculate_cosine_distances(image_embeds)

        results = [
            self._process_scores(i, special_cos_dist, cos_dist)
            for i in range(image_embeds.shape[0])
        ]
        return [self.img_score_output(result["concept_scores"]) for result in results]

    def _convert_image(self, image):
        if isinstance(image, str) and os.path.isfile(image):
            return np.array(Image.open(image))
        elif isinstance(image, Image.Image):
            return np.array(image)
        elif isinstance(image, np.ndarray):
            return image
        else:
            raise TypeError(f"Unsupported image type: {type(image)}")

    def _prepare_image(self, image):
        if torch.is_tensor(image):
            return self.image_processor.postprocess(image, output_type="pil")
        else:
            return self.image_processor.numpy_to_pil(image)

    def _get_image_embeds(self, safety_checker_input):
        pooled_output = self.safety_checker.vision_model(
            safety_checker_input.pixel_values.to(torch.float16)
        )[1]
        return self.safety_checker.visual_projection(pooled_output)

    def _calculate_cosine_distances(self, image_embeds):
        special_cos_dist = (
            cosine_distance(image_embeds, self.safety_checker.special_care_embeds)
            .cpu()
            .float()
            .numpy()
        )
        cos_dist = (
            cosine_distance(image_embeds, self.safety_checker.concept_embeds)
            .cpu()
            .float()
            .numpy()
        )
        return special_cos_dist, cos_dist

    def _process_scores(self, index, special_cos_dist, cos_dist):
        result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}}
        adjustment = 0.0

        for concept_idx in range(len(special_cos_dist[0])):
            concept_cos = special_cos_dist[index][concept_idx]
            concept_threshold = self.safety_checker.special_care_embeds_weights[
                concept_idx
            ].item()
            score = round(concept_cos - concept_threshold + adjustment, 3)
            result_img["special_scores"][concept_idx] = score
            if score > 0:
                result_img["special_care"].append({concept_idx, score})
                adjustment = 0.01

        for concept_idx in range(len(cos_dist[0])):
            concept_cos = cos_dist[index][concept_idx]
            concept_threshold = self.safety_checker.concept_embeds_weights[
                concept_idx
            ].item()
            result_img["concept_scores"][concept_idx] = round(
                concept_cos - concept_threshold + adjustment, 3
            )

        return result_img

    def img_score_output(self, scores):
        if not scores:
            return {"max": 0, "min": 0, "median": 0, "average": 0, "all": scores}

        values = list(scores.values())
        values.sort()
        return {
            "max": round(max(values), 10),
            "min": round(min(values), 10),
            "median": round(values[len(values) // 2], 10),
            "average": round(sum(values) / len(values), 10),
            "all": scores,
        }


class Generation(ScoreChecker):
    def __init__(self):
        self.limit = 8000
        self.save_file = "./img_score_8000.json"
        self.generator = torch.Generator()
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "stable-diffusion-v1-5/stable-diffusion-v1-5",
            torch_dtype=torch.float16,
            safety_checker=None,
        ).to("cuda")

        self.pipe.load_textual_inversion(
            "embed/negative",
            weight_name="EasyNegativeV2.safetensors",
            token="EasyNegative",
        )
        self.pipe.requires_safety_checker = False
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )
        super().__init__()

    def run(self, base_prompt):
        scores_dict = {}
        if os.path.exists(self.save_file):
            with open(self.save_file, "r") as f:
                scores_dict = json.load(f)

        for i in range(self.limit):
            if str(i) in scores_dict:
                continue

            image = self.pipe(
                prompt=f"masterpiece, best quality, high quality, {base_prompt}",
                negative_prompt="EasyNegative",
                num_inference_steps=20,
                generator=self.generator.manual_seed(i),
            ).images[0]
            img_score = self.image_score(image)
            scores_dict[str(i)] = img_score

            with open(self.save_file, "w") as f:
                json.dump(scores_dict, f, indent=4)

            print(f"{i+1}/{self.limit}")

        return scores_dict

if __name__ == "__main__":
    generation = Generation()
    generation.run("1girl")

Score statistics

import json
import numpy as np


def calculate_statistics(file_path):
    with open(file_path, "r") as f:
        data = json.load(f)

    positive_max_values = []
    all_max_values = []
    positive_stats = {}
    all_stats = {}

    for key, value in data.items():
        for score in value:
            all_max_values.append(score["max"])
            if score["max"] > 0:
                positive_max_values.append(score["max"])

    if positive_max_values:
        positive_stats = {
            "max": max(positive_max_values),
            "min": min(positive_max_values),
            "average": sum(positive_max_values) / len(positive_max_values),
            "median": np.median(positive_max_values),
            "quantity": len(positive_max_values),
        }

    if all_max_values:
        all_stats = {
            "max": max(all_max_values),
            "min": min(all_max_values),
            "average": sum(all_max_values) / len(all_max_values),
            "median": np.median(all_max_values),
            "quantity": len(all_max_values),
        }

    return positive_stats, all_stats


if __name__ == "__main__":
    file_path = "./img_score_8000.json"
    positive_stats, all_stats = calculate_statistics(file_path)
    print(f"Statistics of 'max' values greater than 0: {positive_stats}")
    print(f"Statistics of all 'max' values: {all_stats}")

result:

Statistics of 'max' values greater than 0:
    {'max': 0.069, 'min': 0.001, 'average': 0.01754809437386571, 'median': 0.014, 'quantity': 1653}

Statistics of all 'max' values:
    {'max': 0.069, 'min': -0.074, 'average': -0.015196250000000267, 'median': -0.018, 'quantity': 8000}
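
If I read the criteria above correctly, the chosen level values follow from these statistics roughly as below (a hypothetical recomputation with numbers hard-coded from the output above, rounded to match the reported values):

positive_stats = {"max": 0.069, "average": 0.01754809437386571}
all_stats = {"min": -0.074, "average": -0.015196250000000267}

levels = {
    "WEAK": -positive_stats["max"],                  # -0.0690, negated maximum score
    "MEDIUM": -round(positive_stats["average"], 4),  # -0.0175, negated average of triggering scores
    "NORMAL": 0.0,                                   # baseline
    "STRONG": round(abs(all_stats["average"]), 3),   # 0.015, |average of all scores|
    "MAX": abs(all_stats["min"]),                    # 0.0740, |minimum score|
}
print(levels)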

suzukimain · Dec 22 '24 07:12

Awesome research, thank you @suzukimain! The adjusted thresholds should reduce the number of false positives and make the safety checker more usable.

hlky · Dec 23 '24 08:12

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Jan 16 '25 15:01

Hello @yiyixuxu, could you consider merging this PR?

suzukimain · Apr 10 '25 07:04