Path: tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionInpaintPipelineLegacy,
    UNet2DConditionModel,
    UNet2DModel,
    VQModel,
)
from diffusers.utils import floats_tensor, load_image, nightly, slow, torch_device
from diffusers.utils.testing_utils import load_numpy, require_torch_gpu


torch.backends.cuda.matmul.allow_tf32 = False


class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_uncond_unet(self):
        torch.manual_seed(0)
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        return model

    @property
    def dummy_cond_unet(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_cond_unet_inpaint(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=9,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vq_model(self):
        torch.manual_seed(0)
        model = VQModel(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=3,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_extractor(self):
        def extract(*args, **kwargs):
            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])

                def to(self, device):
                    self.pixel_values.to(device)
                    return self

            return Out()

        return extract

    def test_stable_diffusion_inpaint_legacy(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionInpaintPipelineLegacy(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
        )

        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_inpaint_legacy_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionInpaintPipelineLegacy(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        negative_prompt = "french fries"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_inpaint_legacy_num_images_per_prompt(self):
        device = "cpu"
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionInpaintPipelineLegacy(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"

        # test num_images_per_prompt=1 (default)
        images = sd_pipe(
            prompt,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
        ).images

        assert images.shape == (1, 32, 32, 3)

        # test num_images_per_prompt=1 (default) for batch of prompts
        batch_size = 2
        images = sd_pipe(
            [prompt] * batch_size,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
        ).images

        assert images.shape == (batch_size, 32, 32, 3)

        # test num_images_per_prompt for single prompt
        num_images_per_prompt = 2
        images = sd_pipe(
            prompt,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
            num_images_per_prompt=num_images_per_prompt,
        ).images

        assert images.shape == (num_images_per_prompt, 32, 32, 3)

        # test num_images_per_prompt for batch of prompts
        batch_size = 2
        images = sd_pipe(
            [prompt] * batch_size,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
            num_images_per_prompt=num_images_per_prompt,
        ).images

        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)


@slow
@require_torch_gpu
class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        inputs = {
            "prompt": "A red cat sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 3,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_inpaint_legacy_pndm(self):
        pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.5665, 0.6117, 0.6430, 0.4057, 0.4594, 0.5658, 0.1596, 0.3106, 0.4305])

        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_legacy_k_lms(self):
        pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None
        )
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.4534, 0.4467, 0.4329, 0.4329, 0.4339, 0.4220, 0.4244, 0.4332, 0.4426])

        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_legacy_intermediate_state(self):
        number_of_steps = 0

        def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
            callback_fn.has_been_called = True
            nonlocal number_of_steps
            number_of_steps += 1
            if step == 1:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([0.5977, 1.5449, 1.0586, -0.3250, 0.7383, -0.0862, 0.4631, -0.2571, -1.1289])

                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
            elif step == 2:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([0.5190, 1.1621, 0.6885, 0.2424, 0.3337, -0.1617, 0.6914, -0.1957, -0.5474])

                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3

        callback_fn.has_been_called = False

        pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        pipe(**inputs, callback=callback_fn, callback_steps=1)
        assert callback_fn.has_been_called
        assert number_of_steps == 2


@nightly
@require_torch_gpu
class StableDiffusionInpaintLegacyPipelineNightlyTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        inputs = {
            "prompt": "A red cat sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 50,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_inpaint_pndm(self):
        sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_pndm.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_ddim(self):
        sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_ddim.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_lms(self):
        sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_lms.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_dpm(self):
        sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 30
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_dpm_multi.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3