Path: blob/main/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import tempfile
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
from diffusers.utils import floats_tensor, nightly, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


torch.backends.cuda.matmul.allow_tf32 = False


class SafeDiffusionPipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_cond_unet(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_extractor(self):
        # minimal stand-in for a feature extractor: returns an object with empty
        # pixel_values and a chainable .to(device), which is all the pipeline touches
        def extract(*args, **kwargs):
            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])

                def to(self, device):
                    self.pixel_values.to(device)
                    return self

            return Out()

        return extract

    def test_safe_diffusion_ddim(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )

        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
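
        # run the pipeline twice with identical seeds: once returning the output
        # object and once a plain tuple (return_dict=False); both must match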
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_pndm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_no_safety_checker(self):
        pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
        )
        assert isinstance(pipe, StableDiffusionPipeline)
        assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
        assert pipe.safety_checker is None

        image = pipe("example prompt", num_inference_steps=2).images[0]
        assert image is not None

        # check that there's no error when saving a pipeline with one of the models being None
        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)

        # sanity check that the pipeline still works
        assert pipe.safety_checker is None
        image = pipe("example prompt", num_inference_steps=2).images[0]
        assert image is not None

    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
    def test_stable_diffusion_fp16(self):
        """Test that stable diffusion works with fp16"""
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # put models in fp16
        unet = unet.half()
        vae = vae.half()
        bert = bert.half()

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images

        assert image.shape == (1, 64, 64, 3)


@nightly
@require_torch_gpu
class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
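
    # Each integration test below runs the same prompt twice: first with
    # sld_guidance_scale=0 (safe latent diffusion disabled) as a baseline, then with
    # a strong SLD configuration, checking image slices against reference values.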
    def test_harm_safe_stable_diffusion(self):
        sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = (
            "portrait of girl with smokey eyes makeup in abandoned hotel, grange clothes, redshift, wide high angle"
            " coloured polaroid photograph with flash, kodak film, hyper real, stunning moody cinematography, with"
            " anamorphic lenses, by maripol, fallen angels by wong kar - wai, style of suspiria and neon demon and"
            " children from bahnhof zoo, detailed "
        )
        seed = 4003660346
        guidance_scale = 7

        # without safety guidance (sld_guidance_scale = 0)
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=0,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176]

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

        # with safety guidance (strong configuration)
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=2000,
            sld_warmup_steps=7,
            sld_threshold=0.025,
            sld_momentum_scale=0.5,
            sld_mom_beta=0.7,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_nudity_safe_stable_diffusion(self):
        sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "padme amidala taking a bath artwork, safe for work, no nudity"
        seed = 2734971755
        guidance_scale = 7
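
        # baseline pass with safe latent diffusion disabled (sld_guidance_scale=0)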
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=0,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297]

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

        # same seed with strong safety guidance enabled
        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=2000,
            sld_warmup_steps=7,
            sld_threshold=0.025,
            sld_momentum_scale=0.5,
            sld_mom_beta=0.7,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443]

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_nudity_safetychecker_safe_stable_diffusion(self):
        # here the safety checker stays enabled (no safety_checker=None override)
        sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = (
            "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c."
            " leyendecker"
        )
        seed = 1044355234
        guidance_scale = 12

        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=0,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        # without SLD the safety checker fires and blacks out the image, so the slice is all zeros
        expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7

        generator = torch.manual_seed(seed)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            output_type="np",
            width=512,
            height=512,
            sld_guidance_scale=2000,
            sld_warmup_steps=7,
            sld_threshold=0.025,
            sld_momentum_scale=0.5,
            sld_mom_beta=0.7,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561])
        assert image.shape == (1, 512, 512, 3)

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2