Path: blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
1450 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionInpaintPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


# Disable TF32 matmuls on CUDA for the whole module (presumably so GPU outputs stay close enough
# to the stored fp32 reference values used below).
torch.backends.cuda.matmul.allow_tf32 = False


class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast CPU tests of `StableDiffusionInpaintPipeline` built from tiny, seeded components."""

    pipeline_class = StableDiffusionInpaintPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS

    def get_dummy_components(self):
        """Build tiny, deterministically initialised UNet / VAE / CLIP components.

        Returns:
            dict: keyword arguments for `StableDiffusionInpaintPipeline`; the safety checker and
            feature extractor are disabled (`None`).
        """
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            # 9 input channels vs. 4 output channels: presumably noisy latents + mask + masked-image
            # latents concatenated, as the inpainting pipeline expects — TODO confirm against the pipeline.
            in_channels=9,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = PNDMScheduler(skip_prk_steps=True)
        # Re-seed before each model so every component's weights are independently reproducible.
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Return deterministic call kwargs: 64x64 PIL init image and mask, seeded generator, 2 steps.

        Args:
            device: device on which the seeded `torch.Generator` is created.
            seed: seed for both the random image content and the generator.
        """
        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        # (1, C, H, W) -> (H, W, C) on CPU so it can be turned into a PIL image.
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
        # MPS uses the global default generator instead of a device-local one
        # (presumably because device-local generators were unsupported there).
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_inpaint(self):
        """End-to-end run on CPU matches a stored 3x3 corner slice of the last channel."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_inpaint_image_tensor(self):
        """Tensor image/mask inputs give (nearly) the same result as the equivalent PIL inputs."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = sd_pipe(**inputs)
        out_pil = output.images

        # Same inputs as tensors: image scaled to [-1, 1] as (1, 3, H, W); mask scaled to [0, 1]
        # and reduced to its first channel as (1, 1, H, W).
        inputs = self.get_dummy_inputs(device)
        inputs["image"] = torch.tensor(np.array(inputs["image"]) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0)
        inputs["mask_image"] = torch.tensor(np.array(inputs["mask_image"]) / 255).permute(2, 0, 1)[:1].unsqueeze(0)
        output = sd_pipe(**inputs)
        out_tensor = output.images

        assert out_pil.shape == (1, 64, 64, 3)
        assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2


@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
    """GPU tests of the pretrained `runwayml/stable-diffusion-inpainting` checkpoint (3 steps)."""

    def tearDown(self):
        # Release GPU memory between tests so large pipelines don't accumulate.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        """Return call kwargs using the remote bench image/mask and a seeded generator.

        Args:
            device: unused here; kept for signature compatibility with callers.
            generator_device: device of the seeded `torch.Generator`.
            dtype: unused here — the image inputs are PIL images, so no tensor is cast.
            seed: generator seed.
        """
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        inputs = {
            "prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_inpaint_ddim(self):
        """Default-scheduler fp32 run matches a stored centre slice."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0427, 0.0460, 0.0483, 0.0460, 0.0584, 0.0521, 0.1549, 0.1695, 0.1794])

        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_fp16(self):
        """fp16 weights produce an output close (5e-2) to a stored centre slice."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, safety_checker=None
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.1350, 0.1123, 0.1350, 0.1641, 0.1328, 0.1230, 0.1289, 0.1531, 0.1687])

        assert np.abs(expected_slice - image_slice).max() < 5e-2

    def test_stable_diffusion_inpaint_pndm(self):
        """PNDM scheduler run matches a stored centre slice."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])

        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_k_lms(self):
        """LMS discrete scheduler run matches a stored centre slice."""
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.9314, 0.7575, 0.9432, 0.8885, 0.9028, 0.7298, 0.9811, 0.9667, 0.7633])

        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
        """With sequential CPU offload, peak GPU memory during inference stays below 2.2 GB."""
        torch.cuda.empty_cache()
        # `reset_peak_memory_stats` resets all peak counters, including the max-allocated one,
        # so the deprecated `reset_max_memory_allocated` call is not needed.
        torch.cuda.reset_peak_memory_stats()

        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float16
        )
        # Do NOT move the pipeline to the GPU first: uploading every weight before enabling offload
        # would itself raise the measured peak and defeat the purpose of this test.
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        _ = pipe(**inputs)

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.2 GB is allocated
        assert mem_bytes < 2.2 * 10**9


@nightly
@require_torch_gpu
class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
    """Nightly full-length runs compared pixel-wise against stored reference images (.npy)."""

    def tearDown(self):
        # Release GPU memory between tests so large pipelines don't accumulate.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        """Return call kwargs using the remote bench image/mask, a seeded generator and 50 steps.

        Args:
            device: unused here; kept for signature compatibility with callers.
            generator_device: device of the seeded `torch.Generator`.
            dtype: unused here — the image inputs are PIL images, so no tensor is cast.
            seed: generator seed.
        """
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        inputs = {
            "prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 50,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_inpaint_ddim(self):
        """Default-scheduler output matches the stored reference image within 1e-3."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_ddim.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_pndm(self):
        """PNDM output matches the stored reference image within 1e-3."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.scheduler = PNDMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_pndm.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_lms(self):
        """LMS output matches the stored reference image within 1e-3."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_lms.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_dpm(self):
        """DPM-Solver multistep output (30 steps) matches the stored reference image within 1e-3."""
        sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 30
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/stable_diffusion_inpaint_dpm_multi.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3


class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
    """Unit tests of `prepare_mask_and_masked_image` input handling and validation."""

    def test_pil_inputs(self):
        """PIL image + binary mask yield float32 4D tensors with the expected shapes and ranges."""
        im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        im = Image.fromarray(im)
        mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
        mask = Image.fromarray((mask * 255).astype(np.uint8))

        t_mask, t_masked = prepare_mask_and_masked_image(im, mask)

        self.assertIsInstance(t_mask, torch.Tensor)
        self.assertIsInstance(t_masked, torch.Tensor)

        self.assertEqual(t_mask.ndim, 4)
        self.assertEqual(t_masked.ndim, 4)

        self.assertEqual(t_mask.shape, (1, 1, 32, 32))
        self.assertEqual(t_masked.shape, (1, 3, 32, 32))

        self.assertEqual(t_mask.dtype, torch.float32)
        self.assertEqual(t_masked.dtype, torch.float32)

        self.assertTrue(t_mask.min() >= 0.0)
        self.assertTrue(t_mask.max() <= 1.0)
        self.assertTrue(t_masked.min() >= -1.0)
        # Upper bound must check the maximum (the original re-checked the minimum here).
        self.assertTrue(t_masked.max() <= 1.0)

        self.assertTrue(t_mask.sum() > 0.0)

    def test_np_inputs(self):
        """NumPy inputs give exactly the same tensors as the equivalent PIL inputs."""
        im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        im_pil = Image.fromarray(im_np)
        mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
        mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))

        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
        t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil)

        self.assertTrue((t_mask_np == t_mask_pil).all())
        self.assertTrue((t_masked_np == t_masked_pil).all())

    def test_torch_3D_2D_inputs(self):
        """(3, H, W) image tensor + (H, W) mask match the NumPy path."""
        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy().transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_3D_3D_inputs(self):
        """(3, H, W) image tensor + (1, H, W) mask match the NumPy path."""
        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy().transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_4D_2D_inputs(self):
        """(1, 3, H, W) image tensor + (H, W) mask match the NumPy path."""
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_4D_3D_inputs(self):
        """(1, 3, H, W) image tensor + (1, H, W) mask match the NumPy path."""
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_4D_4D_inputs(self):
        """(1, 3, H, W) image tensor + (1, 1, H, W) mask match the NumPy path."""
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0][0]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_batch_4D_3D(self):
        """Batched (B, 3, H, W) image + (B, H, W) mask equal per-sample NumPy results concatenated."""
        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5

        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
        mask_nps = [mask.numpy() for mask in mask_tensor]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
        t_mask_np = torch.cat([n[0] for n in nps])
        t_masked_np = torch.cat([n[1] for n in nps])

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_batch_4D_4D(self):
        """Batched (B, 3, H, W) image + (B, 1, H, W) mask equal per-sample NumPy results concatenated."""
        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5

        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
        mask_nps = [mask.numpy()[0] for mask in mask_tensor]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
        t_mask_np = torch.cat([n[0] for n in nps])
        t_masked_np = torch.cat([n[1] for n in nps])

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_shape_mismatch(self):
        """Mismatched spatial size or batch size between image and mask raises AssertionError."""
        # test height and width
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64))
        # test batch dim with a (B, H, W) mask; spatial sizes match so only the batch check can fail
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 32, 32))
        # test batch dim with a (B, 1, H, W) mask; spatial sizes match so only the batch check can fail
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 32, 32))

    def test_type_mismatch(self):
        """Mixing a torch tensor with a NumPy array raises TypeError in either argument order."""
        # test tensors-only
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy())
        # test tensors-only
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32))

    def test_channels_first(self):
        """A channels-last 3D image tensor is rejected."""
        # test channels first for 3D tensors
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32))

    def test_tensor_range(self):
        """Image tensors outside [-1, 1] or masks outside [0, 1] raise ValueError."""
        # test im <= 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32))
        # test im >= -1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32))
        # test mask <= 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2)
        # test mask >= 0
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1)