Path: blob/main/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, slow

from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableDiffusionInpaintPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=9,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            # SD2-specific config below
            attention_head_dim=(2, 4),
            use_linear_projection=True,
        )
        scheduler = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=512,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

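    # Note: in_channels=9 on the dummy UNet above mirrors the real SD2 inpainting
    # UNet, whose input is the channel-wise concatenation of 4 noisy-latent
    # channels, 1 mask channel, and 4 masked-image latent channels.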
    def get_dummy_inputs(self, device, seed=0):
        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_inpaint(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4727, 0.5735, 0.3941, 0.5446, 0.5926, 0.4394, 0.5062, 0.4654, 0.4476])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_inpaint_pipeline(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-inpaint/init_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint"
            "/yellow_cat_sitting_on_a_park_bench.npy"
        )

        model_id = "stabilityai/stable-diffusion-2-inpainting"
        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-3

    def test_stable_diffusion_inpaint_pipeline_fp16(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-inpaint/init_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint"
            "/yellow_cat_sitting_on_a_park_bench_fp16.npy"
        )

        model_id = "stabilityai/stable-diffusion-2-inpainting"
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 5e-1

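    # enable_sequential_cpu_offload() keeps submodules on the CPU and moves each
    # one to the GPU only for its forward pass, so peak VRAM usage should stay
    # well below the size of the full fp16 model.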
    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-inpaint/init_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
        )

        model_id = "stabilityai/stable-diffusion-2-inpainting"
        pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            safety_checker=None,
            scheduler=pndm,
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

        generator = torch.manual_seed(0)
        _ = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.65 GB is allocated
        assert mem_bytes < 2.65 * 10**9
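
# Usage sketch (assumes a diffusers source checkout; RUN_SLOW gates the @slow
# integration tests above):
#   pytest tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py -k FastTests
#   RUN_SLOW=1 pytest tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py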