Path: blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionImg2ImgPipeline,
    UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableDiffusionImg2ImgPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"):
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        if input_image_type == "pt":
            input_image = image
        elif input_image_type == "np":
            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
        elif input_image_type == "pil":
            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
            input_image = VaeImageProcessor.numpy_to_pil(input_image)
        else:
            raise ValueError(f"unsupported input_image_type {input_image_type}.")

        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"unsupported output_type {output_type}")

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": input_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": output_type,
        }
        return inputs

    def test_stable_diffusion_img2img_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_img2img_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        negative_prompt = "french fries"
        output = sd_pipe(**inputs, negative_prompt=negative_prompt)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_img2img_multiple_init_images(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = [inputs["prompt"]] * 2
        inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
        image = sd_pipe(**inputs).images
        image_slice = image[-1, -3:, -3:, -1]

        assert image.shape == (2, 32, 32, 3)
        expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_img2img_k_lms(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = LMSDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    @skip_mps
    def test_save_load_local(self):
        return super().test_save_load_local()

    @skip_mps
    def test_dict_tuple_outputs_equivalent(self):
        return super().test_dict_tuple_outputs_equivalent()

    @skip_mps
    def test_save_load_optional_components(self):
        return super().test_save_load_optional_components()

    @skip_mps
    def test_attention_slicing_forward_pass(self):
        return super().test_attention_slicing_forward_pass()

    @skip_mps
    def test_pt_np_pil_outputs_equivalent(self):
        device = "cpu"
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0]
        output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0]
        output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0]

        assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4
        assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4

    @skip_mps
    def test_image_types_consistent(self):
        device = "cpu"
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0]
        output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0]
        output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0]

        assert np.abs(output_pt - output_np).max() <= 1e-4
        assert np.abs(output_pil - output_np).max() <= 1e-2


@slow
@require_torch_gpu
class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_img2img/sketch-mountains-input.png"
        )
        inputs = {
            "prompt": "a fantasy landscape, concept art, high resolution",
            "image": init_image,
            "generator": generator,
            "num_inference_steps": 3,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    def test_stable_diffusion_img2img_default(self):
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1].flatten()

        assert image.shape == (1, 512, 768, 3)
        expected_slice = np.array([0.4300, 0.4662, 0.4930, 0.3990, 0.4307, 0.4525, 0.3719, 0.4064, 0.3923])

        assert np.abs(expected_slice - image_slice).max() < 1e-3

    def test_stable_diffusion_img2img_k_lms(self):
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1].flatten()

        assert image.shape == (1, 512, 768, 3)
        expected_slice = np.array([0.0389, 0.0346, 0.0415, 0.0290, 0.0218, 0.0210, 0.0408, 0.0567, 0.0271])

        assert np.abs(expected_slice - image_slice).max() < 1e-3

    def test_stable_diffusion_img2img_ddim(self):
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1].flatten()

        assert image.shape == (1, 512, 768, 3)
        expected_slice = np.array([0.0593, 0.0607, 0.0851, 0.0582, 0.0636, 0.0721, 0.0751, 0.0981, 0.0781])

        assert np.abs(expected_slice - image_slice).max() < 1e-3

    def test_stable_diffusion_img2img_intermediate_state(self):
        number_of_steps = 0

        def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
            callback_fn.has_been_called = True
            nonlocal number_of_steps
            number_of_steps += 1
            if step == 1:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 96)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([-0.4958, 0.5107, 1.1045, 2.7539, 4.6680, 3.8320, 1.5049, 1.8633, 2.6523])

                assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
            elif step == 2:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 96)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([-0.4956, 0.5078, 1.0918, 2.7520, 4.6484, 3.8125, 1.5146, 1.8633, 2.6367])

                assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2

        callback_fn.has_been_called = False

        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        pipe(**inputs, callback=callback_fn, callback_steps=1)
        assert callback_fn.has_been_called
        assert number_of_steps == 2

    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        _ = pipe(**inputs)

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.2 GB is allocated
        assert mem_bytes < 2.2 * 10**9

    def test_stable_diffusion_pipeline_with_model_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)

        # Normal inference

        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe(**inputs)
        mem_bytes = torch.cuda.max_memory_allocated()

        # With model offloading

        # Reload but don't move to cuda
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            safety_checker=None,
            torch_dtype=torch.float16,
        )

        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)
        _ = pipe(**inputs)
        mem_bytes_offloaded = torch.cuda.max_memory_allocated()

        assert mem_bytes_offloaded < mem_bytes
        for module in pipe.text_encoder, pipe.unet, pipe.vae:
            assert module.device == torch.device("cpu")

    def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        )
        # resize to resolution that is divisible by 8 but not 16 or 32
        init_image = init_image.resize((760, 504))

        model_id = "CompVis/stable-diffusion-v1-4"
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            safety_checker=None,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "A fantasy landscape, trending on artstation"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        image_slice = image[255:258, 383:386, -1]

        assert image.shape == (504, 760, 3)
        expected_slice = np.array([0.9393, 0.9500, 0.9399, 0.9438, 0.9458, 0.9400, 0.9455, 0.9414, 0.9423])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3


@nightly
@require_torch_gpu
class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_img2img/sketch-mountains-input.png"
        )
        inputs = {
            "prompt": "a fantasy landscape, concept art, high resolution",
            "image": init_image,
            "generator": generator,
            "num_inference_steps": 50,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    def test_img2img_pndm(self):
        sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_img2img/stable_diffusion_1_5_pndm.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_img2img_ddim(self):
        sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_img2img/stable_diffusion_1_5_ddim.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_img2img_lms(self):
        sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_img2img/stable_diffusion_1_5_lms.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_img2img_dpm(self):
        sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 30
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_img2img/stable_diffusion_1_5_dpm.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3
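For orientation, a minimal usage sketch of the img2img call pattern the slow tests above exercise. It reuses the same public checkpoint, init-image URL, prompt, strength, and guidance_scale as get_inputs; the fp16/CUDA setup, the fixed seed, and the output filename are illustrative assumptions rather than anything these tests require.

import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image

# Load the pipeline in half precision (assumption: a CUDA device is available).
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, safety_checker=None
).to("cuda")

# Same sketch image the slow tests download as their init image.
init_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
    "/stable_diffusion_img2img/sketch-mountains-input.png"
)

# strength controls how far the sampler moves away from the init image;
# the seeded generator makes the run reproducible.
image = pipe(
    prompt="a fantasy landscape, concept art, high resolution",
    image=init_image,
    strength=0.75,
    guidance_scale=7.5,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
image.save("fantasy_landscape.png")  # hypothetical output path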