Path: blob/main/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
1450 views
# coding=utf-81# Copyright 2023 HuggingFace Inc.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7# http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.1415import gc16import random17import tempfile18import unittest1920import numpy as np21import torch22from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer2324from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel25from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipeline as StableDiffusionPipeline26from diffusers.utils import floats_tensor, nightly, torch_device27from diffusers.utils.testing_utils import require_torch_gpu282930torch.backends.cuda.matmul.allow_tf32 = False313233class SafeDiffusionPipelineFastTests(unittest.TestCase):34def tearDown(self):35# clean up the VRAM after each test36super().tearDown()37gc.collect()38torch.cuda.empty_cache()3940@property41def dummy_image(self):42batch_size = 143num_channels = 344sizes = (32, 32)4546image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)47return image4849@property50def dummy_cond_unet(self):51torch.manual_seed(0)52model = UNet2DConditionModel(53block_out_channels=(32, 64),54layers_per_block=2,55sample_size=32,56in_channels=4,57out_channels=4,58down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),59up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),60cross_attention_dim=32,61)62return model6364@property65def dummy_vae(self):66torch.manual_seed(0)67model = AutoencoderKL(68block_out_channels=[32, 64],69in_channels=3,70out_channels=3,71down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],72up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],73latent_channels=4,74)75return model7677@property78def dummy_text_encoder(self):79torch.manual_seed(0)80config = CLIPTextConfig(81bos_token_id=0,82eos_token_id=2,83hidden_size=32,84intermediate_size=37,85layer_norm_eps=1e-05,86num_attention_heads=4,87num_hidden_layers=5,88pad_token_id=1,89vocab_size=1000,90)91return CLIPTextModel(config)9293@property94def dummy_extractor(self):95def extract(*args, **kwargs):96class Out:97def __init__(self):98self.pixel_values = torch.ones([0])99100def to(self, device):101self.pixel_values.to(device)102return self103104return Out()105106return extract107108def test_semantic_diffusion_ddim(self):109device = "cpu" # ensure determinism for the device-dependent torch.Generator110unet = self.dummy_cond_unet111scheduler = DDIMScheduler(112beta_start=0.00085,113beta_end=0.012,114beta_schedule="scaled_linear",115clip_sample=False,116set_alpha_to_one=False,117)118119vae = self.dummy_vae120bert = self.dummy_text_encoder121tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")122123# make sure here that pndm scheduler skips prk124sd_pipe = StableDiffusionPipeline(125unet=unet,126scheduler=scheduler,127vae=vae,128text_encoder=bert,129tokenizer=tokenizer,130safety_checker=None,131feature_extractor=self.dummy_extractor,132)133sd_pipe = sd_pipe.to(device)134sd_pipe.set_progress_bar_config(disable=None)135136prompt = "A painting of a squirrel eating a burger"137138generator = torch.Generator(device=device).manual_seed(0)139output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")140image = output.images141142generator = torch.Generator(device=device).manual_seed(0)143image_from_tuple = sd_pipe(144[prompt],145generator=generator,146guidance_scale=6.0,147num_inference_steps=2,148output_type="np",149return_dict=False,150)[0]151152image_slice = image[0, -3:, -3:, -1]153image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]154155assert image.shape == (1, 64, 64, 3)156expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792])157158assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2159assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2160161def test_semantic_diffusion_pndm(self):162device = "cpu" # ensure determinism for the device-dependent torch.Generator163unet = self.dummy_cond_unet164scheduler = PNDMScheduler(skip_prk_steps=True)165vae = self.dummy_vae166bert = self.dummy_text_encoder167tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")168169# make sure here that pndm scheduler skips prk170sd_pipe = StableDiffusionPipeline(171unet=unet,172scheduler=scheduler,173vae=vae,174text_encoder=bert,175tokenizer=tokenizer,176safety_checker=None,177feature_extractor=self.dummy_extractor,178)179sd_pipe = sd_pipe.to(device)180sd_pipe.set_progress_bar_config(disable=None)181182prompt = "A painting of a squirrel eating a burger"183generator = torch.Generator(device=device).manual_seed(0)184output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")185186image = output.images187188generator = torch.Generator(device=device).manual_seed(0)189image_from_tuple = sd_pipe(190[prompt],191generator=generator,192guidance_scale=6.0,193num_inference_steps=2,194output_type="np",195return_dict=False,196)[0]197198image_slice = image[0, -3:, -3:, -1]199image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]200201assert image.shape == (1, 64, 64, 3)202expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945])203204assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2205assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2206207def test_semantic_diffusion_no_safety_checker(self):208pipe = StableDiffusionPipeline.from_pretrained(209"hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None210)211assert isinstance(pipe, StableDiffusionPipeline)212assert isinstance(pipe.scheduler, LMSDiscreteScheduler)213assert pipe.safety_checker is None214215image = pipe("example prompt", num_inference_steps=2).images[0]216assert image is not None217218# check that there's no error when saving a pipeline with one of the models being None219with tempfile.TemporaryDirectory() as tmpdirname:220pipe.save_pretrained(tmpdirname)221pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)222223# sanity check that the pipeline still works224assert pipe.safety_checker is None225image = pipe("example prompt", num_inference_steps=2).images[0]226assert image is not None227228@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")229def test_semantic_diffusion_fp16(self):230"""Test that stable diffusion works with fp16"""231unet = self.dummy_cond_unet232scheduler = PNDMScheduler(skip_prk_steps=True)233vae = self.dummy_vae234bert = self.dummy_text_encoder235tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")236237# put models in fp16238unet = unet.half()239vae = vae.half()240bert = bert.half()241242# make sure here that pndm scheduler skips prk243sd_pipe = StableDiffusionPipeline(244unet=unet,245scheduler=scheduler,246vae=vae,247text_encoder=bert,248tokenizer=tokenizer,249safety_checker=None,250feature_extractor=self.dummy_extractor,251)252sd_pipe = sd_pipe.to(torch_device)253sd_pipe.set_progress_bar_config(disable=None)254255prompt = "A painting of a squirrel eating a burger"256image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images257258assert image.shape == (1, 64, 64, 3)259260261@nightly262@require_torch_gpu263class SemanticDiffusionPipelineIntegrationTests(unittest.TestCase):264def tearDown(self):265# clean up the VRAM after each test266super().tearDown()267gc.collect()268torch.cuda.empty_cache()269270def test_positive_guidance(self):271torch_device = "cuda"272pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")273pipe = pipe.to(torch_device)274pipe.set_progress_bar_config(disable=None)275276prompt = "a photo of a cat"277edit = {278"editing_prompt": ["sunglasses"],279"reverse_editing_direction": [False],280"edit_warmup_steps": 10,281"edit_guidance_scale": 6,282"edit_threshold": 0.95,283"edit_momentum_scale": 0.5,284"edit_mom_beta": 0.6,285}286287seed = 3288guidance_scale = 7289290# no sega enabled291generator = torch.Generator(torch_device)292generator.manual_seed(seed)293output = pipe(294[prompt],295generator=generator,296guidance_scale=guidance_scale,297num_inference_steps=50,298output_type="np",299width=512,300height=512,301)302303image = output.images304image_slice = image[0, -3:, -3:, -1]305expected_slice = [3060.34673113,3070.38492733,3080.37597352,3090.34086335,3100.35650748,3110.35579205,3120.3384763,3130.34340236,3140.3573271,315]316317assert image.shape == (1, 512, 512, 3)318319assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2320321# with sega enabled322# generator = torch.manual_seed(seed)323generator.manual_seed(seed)324output = pipe(325[prompt],326generator=generator,327guidance_scale=guidance_scale,328num_inference_steps=50,329output_type="np",330width=512,331height=512,332**edit,333)334335image = output.images336image_slice = image[0, -3:, -3:, -1]337expected_slice = [3380.41887826,3390.37728766,3400.30138272,3410.41416335,3420.41664985,3430.36283392,3440.36191246,3450.43364465,3460.43001732,347]348349assert image.shape == (1, 512, 512, 3)350351assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2352353def test_negative_guidance(self):354torch_device = "cuda"355pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")356pipe = pipe.to(torch_device)357pipe.set_progress_bar_config(disable=None)358359prompt = "an image of a crowded boulevard, realistic, 4k"360edit = {361"editing_prompt": "crowd, crowded, people",362"reverse_editing_direction": True,363"edit_warmup_steps": 10,364"edit_guidance_scale": 8.3,365"edit_threshold": 0.9,366"edit_momentum_scale": 0.5,367"edit_mom_beta": 0.6,368}369370seed = 9371guidance_scale = 7372373# no sega enabled374generator = torch.Generator(torch_device)375generator.manual_seed(seed)376output = pipe(377[prompt],378generator=generator,379guidance_scale=guidance_scale,380num_inference_steps=50,381output_type="np",382width=512,383height=512,384)385386image = output.images387image_slice = image[0, -3:, -3:, -1]388expected_slice = [3890.43497998,3900.91814065,3910.7540739,3920.55580205,3930.8467265,3940.5389691,3950.62574506,3960.58897763,3970.50926757,398]399400assert image.shape == (1, 512, 512, 3)401402assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2403404# with sega enabled405# generator = torch.manual_seed(seed)406generator.manual_seed(seed)407output = pipe(408[prompt],409generator=generator,410guidance_scale=guidance_scale,411num_inference_steps=50,412output_type="np",413width=512,414height=512,415**edit,416)417418image = output.images419image_slice = image[0, -3:, -3:, -1]420expected_slice = [4210.3089719,4220.30500144,4230.29016042,4240.30630964,4250.325687,4260.29419225,4270.2908091,4280.28723598,4290.27696294,430]431432assert image.shape == (1, 512, 512, 3)433434assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2435436def test_multi_cond_guidance(self):437torch_device = "cuda"438pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")439pipe = pipe.to(torch_device)440pipe.set_progress_bar_config(disable=None)441442prompt = "a castle next to a river"443edit = {444"editing_prompt": ["boat on a river, boat", "monet, impression, sunrise"],445"reverse_editing_direction": False,446"edit_warmup_steps": [15, 18],447"edit_guidance_scale": 6,448"edit_threshold": [0.9, 0.8],449"edit_momentum_scale": 0.5,450"edit_mom_beta": 0.6,451}452453seed = 48454guidance_scale = 7455456# no sega enabled457generator = torch.Generator(torch_device)458generator.manual_seed(seed)459output = pipe(460[prompt],461generator=generator,462guidance_scale=guidance_scale,463num_inference_steps=50,464output_type="np",465width=512,466height=512,467)468469image = output.images470image_slice = image[0, -3:, -3:, -1]471expected_slice = [4720.75163555,4730.76037145,4740.61785,4750.9189673,4760.8627701,4770.85189694,4780.8512813,4790.87012076,4800.8312857,481]482483assert image.shape == (1, 512, 512, 3)484485assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2486487# with sega enabled488# generator = torch.manual_seed(seed)489generator.manual_seed(seed)490output = pipe(491[prompt],492generator=generator,493guidance_scale=guidance_scale,494num_inference_steps=50,495output_type="np",496width=512,497height=512,498**edit,499)500501image = output.images502image_slice = image[0, -3:, -3:, -1]503expected_slice = [5040.73553365,5050.7537271,5060.74341905,5070.66480356,5080.6472925,5090.63039416,5100.64812905,5110.6749717,5120.6517102,513]514515assert image.shape == (1, 512, 512, 3)516517assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2518519def test_guidance_fp16(self):520torch_device = "cuda"521pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)522pipe = pipe.to(torch_device)523pipe.set_progress_bar_config(disable=None)524525prompt = "a photo of a cat"526edit = {527"editing_prompt": ["sunglasses"],528"reverse_editing_direction": [False],529"edit_warmup_steps": 10,530"edit_guidance_scale": 6,531"edit_threshold": 0.95,532"edit_momentum_scale": 0.5,533"edit_mom_beta": 0.6,534}535536seed = 3537guidance_scale = 7538539# no sega enabled540generator = torch.Generator(torch_device)541generator.manual_seed(seed)542output = pipe(543[prompt],544generator=generator,545guidance_scale=guidance_scale,546num_inference_steps=50,547output_type="np",548width=512,549height=512,550)551552image = output.images553image_slice = image[0, -3:, -3:, -1]554expected_slice = [5550.34887695,5560.3876953,5570.375,5580.34423828,5590.3581543,5600.35717773,5610.3383789,5620.34570312,5630.359375,564]565566assert image.shape == (1, 512, 512, 3)567568assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2569570# with sega enabled571# generator = torch.manual_seed(seed)572generator.manual_seed(seed)573output = pipe(574[prompt],575generator=generator,576guidance_scale=guidance_scale,577num_inference_steps=50,578output_type="np",579width=512,580height=512,581**edit,582)583584image = output.images585image_slice = image[0, -3:, -3:, -1]586expected_slice = [5870.42285156,5880.36914062,5890.29077148,5900.42041016,5910.41918945,5920.35498047,5930.3618164,5940.4423828,5950.43115234,596]597598assert image.shape == (1, 512, 512, 3)599600assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2601602603