Path: blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
1448 views
# coding=utf-81# Copyright 2023 HuggingFace Inc.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7# http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.1415import gc16import random17import unittest1819import numpy as np20import torch21from PIL import Image22from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer2324from diffusers import (25AutoencoderKL,26DDIMScheduler,27EulerAncestralDiscreteScheduler,28LMSDiscreteScheduler,29PNDMScheduler,30StableDiffusionInstructPix2PixPipeline,31UNet2DConditionModel,32)33from diffusers.utils import floats_tensor, load_image, slow, torch_device34from diffusers.utils.testing_utils import require_torch_gpu3536from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS37from ...test_pipelines_common import PipelineTesterMixin383940torch.backends.cuda.matmul.allow_tf32 = False414243class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unittest.TestCase):44pipeline_class = StableDiffusionInstructPix2PixPipeline45params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}46batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS4748def get_dummy_components(self):49torch.manual_seed(0)50unet = UNet2DConditionModel(51block_out_channels=(32, 64),52layers_per_block=2,53sample_size=32,54in_channels=8,55out_channels=4,56down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),57up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),58cross_attention_dim=32,59)60scheduler = PNDMScheduler(skip_prk_steps=True)61torch.manual_seed(0)62vae = AutoencoderKL(63block_out_channels=[32, 64],64in_channels=3,65out_channels=3,66down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],67up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],68latent_channels=4,69)70torch.manual_seed(0)71text_encoder_config = CLIPTextConfig(72bos_token_id=0,73eos_token_id=2,74hidden_size=32,75intermediate_size=37,76layer_norm_eps=1e-05,77num_attention_heads=4,78num_hidden_layers=5,79pad_token_id=1,80vocab_size=1000,81)82text_encoder = CLIPTextModel(text_encoder_config)83tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")8485components = {86"unet": unet,87"scheduler": scheduler,88"vae": vae,89"text_encoder": text_encoder,90"tokenizer": tokenizer,91"safety_checker": None,92"feature_extractor": None,93}94return components9596def get_dummy_inputs(self, device, seed=0):97image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)98image = image.cpu().permute(0, 2, 3, 1)[0]99image = Image.fromarray(np.uint8(image)).convert("RGB")100if str(device).startswith("mps"):101generator = torch.manual_seed(seed)102else:103generator = torch.Generator(device=device).manual_seed(seed)104inputs = {105"prompt": "A painting of a squirrel eating a burger",106"image": image,107"generator": generator,108"num_inference_steps": 2,109"guidance_scale": 6.0,110"image_guidance_scale": 1,111"output_type": "numpy",112}113return inputs114115def test_stable_diffusion_pix2pix_default_case(self):116device = "cpu" # ensure determinism for the device-dependent torch.Generator117components = self.get_dummy_components()118sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)119sd_pipe = sd_pipe.to(device)120sd_pipe.set_progress_bar_config(disable=None)121122inputs = self.get_dummy_inputs(device)123image = sd_pipe(**inputs).images124image_slice = image[0, -3:, -3:, -1]125assert image.shape == (1, 32, 32, 3)126expected_slice = np.array([0.7318, 0.3723, 0.4662, 0.623, 0.5770, 0.5014, 0.4281, 0.5550, 0.4813])127128assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3129130def test_stable_diffusion_pix2pix_negative_prompt(self):131device = "cpu" # ensure determinism for the device-dependent torch.Generator132components = self.get_dummy_components()133sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)134sd_pipe = sd_pipe.to(device)135sd_pipe.set_progress_bar_config(disable=None)136137inputs = self.get_dummy_inputs(device)138negative_prompt = "french fries"139output = sd_pipe(**inputs, negative_prompt=negative_prompt)140image = output.images141image_slice = image[0, -3:, -3:, -1]142143assert image.shape == (1, 32, 32, 3)144expected_slice = np.array([0.7323, 0.3688, 0.4611, 0.6255, 0.5746, 0.5017, 0.433, 0.5553, 0.4827])145146assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3147148def test_stable_diffusion_pix2pix_multiple_init_images(self):149device = "cpu" # ensure determinism for the device-dependent torch.Generator150components = self.get_dummy_components()151sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)152sd_pipe = sd_pipe.to(device)153sd_pipe.set_progress_bar_config(disable=None)154155inputs = self.get_dummy_inputs(device)156inputs["prompt"] = [inputs["prompt"]] * 2157158image = np.array(inputs["image"]).astype(np.float32) / 255.0159image = torch.from_numpy(image).unsqueeze(0).to(device)160image = image.permute(0, 3, 1, 2)161inputs["image"] = image.repeat(2, 1, 1, 1)162163image = sd_pipe(**inputs).images164image_slice = image[-1, -3:, -3:, -1]165166assert image.shape == (2, 32, 32, 3)167expected_slice = np.array([0.606, 0.5712, 0.5099, 0.598, 0.5805, 0.7205, 0.6793, 0.554, 0.5607])168169assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3170171def test_stable_diffusion_pix2pix_euler(self):172device = "cpu" # ensure determinism for the device-dependent torch.Generator173components = self.get_dummy_components()174components["scheduler"] = EulerAncestralDiscreteScheduler(175beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"176)177sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)178sd_pipe = sd_pipe.to(device)179sd_pipe.set_progress_bar_config(disable=None)180181inputs = self.get_dummy_inputs(device)182image = sd_pipe(**inputs).images183image_slice = image[0, -3:, -3:, -1]184185slice = [round(x, 4) for x in image_slice.flatten().tolist()]186print(",".join([str(x) for x in slice]))187188assert image.shape == (1, 32, 32, 3)189expected_slice = np.array([0.726, 0.3902, 0.4868, 0.585, 0.5672, 0.511, 0.3906, 0.551, 0.4846])190191assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3192193194@slow195@require_torch_gpu196class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase):197def tearDown(self):198super().tearDown()199gc.collect()200torch.cuda.empty_cache()201202def get_inputs(self, seed=0):203generator = torch.manual_seed(seed)204image = load_image(205"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg"206)207inputs = {208"prompt": "turn him into a cyborg",209"image": image,210"generator": generator,211"num_inference_steps": 3,212"guidance_scale": 7.5,213"image_guidance_scale": 1.0,214"output_type": "numpy",215}216return inputs217218def test_stable_diffusion_pix2pix_default(self):219pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(220"timbrooks/instruct-pix2pix", safety_checker=None221)222pipe.to(torch_device)223pipe.set_progress_bar_config(disable=None)224pipe.enable_attention_slicing()225226inputs = self.get_inputs()227image = pipe(**inputs).images228image_slice = image[0, -3:, -3:, -1].flatten()229230assert image.shape == (1, 512, 512, 3)231expected_slice = np.array([0.5902, 0.6015, 0.6027, 0.5983, 0.6092, 0.6061, 0.5765, 0.5785, 0.5555])232233assert np.abs(expected_slice - image_slice).max() < 1e-3234235def test_stable_diffusion_pix2pix_k_lms(self):236pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(237"timbrooks/instruct-pix2pix", safety_checker=None238)239pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)240pipe.to(torch_device)241pipe.set_progress_bar_config(disable=None)242pipe.enable_attention_slicing()243244inputs = self.get_inputs()245image = pipe(**inputs).images246image_slice = image[0, -3:, -3:, -1].flatten()247248assert image.shape == (1, 512, 512, 3)249expected_slice = np.array([0.6578, 0.6817, 0.6972, 0.6761, 0.6856, 0.6916, 0.6428, 0.6516, 0.6301])250251assert np.abs(expected_slice - image_slice).max() < 1e-3252253def test_stable_diffusion_pix2pix_ddim(self):254pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(255"timbrooks/instruct-pix2pix", safety_checker=None256)257pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)258pipe.to(torch_device)259pipe.set_progress_bar_config(disable=None)260pipe.enable_attention_slicing()261262inputs = self.get_inputs()263image = pipe(**inputs).images264image_slice = image[0, -3:, -3:, -1].flatten()265266assert image.shape == (1, 512, 512, 3)267expected_slice = np.array([0.3828, 0.3834, 0.3818, 0.3792, 0.3865, 0.3752, 0.3792, 0.3847, 0.3753])268269assert np.abs(expected_slice - image_slice).max() < 1e-3270271def test_stable_diffusion_pix2pix_intermediate_state(self):272number_of_steps = 0273274def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:275callback_fn.has_been_called = True276nonlocal number_of_steps277number_of_steps += 1278if step == 1:279latents = latents.detach().cpu().numpy()280assert latents.shape == (1, 4, 64, 64)281latents_slice = latents[0, -3:, -3:, -1]282expected_slice = np.array([-0.2463, -0.4644, -0.9756, 1.5176, 1.4414, 0.7866, 0.9897, 0.8521, 0.7983])283284assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2285elif step == 2:286latents = latents.detach().cpu().numpy()287assert latents.shape == (1, 4, 64, 64)288latents_slice = latents[0, -3:, -3:, -1]289expected_slice = np.array([-0.2644, -0.4626, -0.9653, 1.5176, 1.4551, 0.7686, 0.9805, 0.8452, 0.8115])290291assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2292293callback_fn.has_been_called = False294295pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(296"timbrooks/instruct-pix2pix", safety_checker=None, torch_dtype=torch.float16297)298pipe = pipe.to(torch_device)299pipe.set_progress_bar_config(disable=None)300pipe.enable_attention_slicing()301302inputs = self.get_inputs()303pipe(**inputs, callback=callback_fn, callback_steps=1)304assert callback_fn.has_been_called305assert number_of_steps == 3306307def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):308torch.cuda.empty_cache()309torch.cuda.reset_max_memory_allocated()310torch.cuda.reset_peak_memory_stats()311312pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(313"timbrooks/instruct-pix2pix", safety_checker=None, torch_dtype=torch.float16314)315pipe = pipe.to(torch_device)316pipe.set_progress_bar_config(disable=None)317pipe.enable_attention_slicing(1)318pipe.enable_sequential_cpu_offload()319320inputs = self.get_inputs()321_ = pipe(**inputs)322323mem_bytes = torch.cuda.max_memory_allocated()324# make sure that less than 2.2 GB is allocated325assert mem_bytes < 2.2 * 10**9326327def test_stable_diffusion_pix2pix_pipeline_multiple_of_8(self):328inputs = self.get_inputs()329# resize to resolution that is divisible by 8 but not 16 or 32330inputs["image"] = inputs["image"].resize((504, 504))331332model_id = "timbrooks/instruct-pix2pix"333pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(334model_id,335safety_checker=None,336)337pipe.to(torch_device)338pipe.set_progress_bar_config(disable=None)339pipe.enable_attention_slicing()340341output = pipe(**inputs)342image = output.images[0]343344image_slice = image[255:258, 383:386, -1]345346assert image.shape == (504, 504, 3)347expected_slice = np.array([0.2726, 0.2529, 0.2664, 0.2655, 0.2641, 0.2642, 0.2591, 0.2649, 0.2590])348349assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3350351352