Path: blob/main/tests/pipelines/paint_by_example/test_paint_by_example.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionConfig

from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ...pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = PaintByExamplePipeline
    params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=9,  # 4 image latents + 4 masked-image latents + 1 mask channel
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        config = CLIPVisionConfig(
            hidden_size=32,
            projection_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            image_size=32,
            patch_size=4,
        )
        image_encoder = PaintByExampleImageEncoder(config, proj_size=32)
        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "image_encoder": image_encoder,
            "safety_checker": None,
            "feature_extractor": feature_extractor,
        }
        return components

    def convert_to_pt(self, image):
        # PIL (H, W, C) uint8 -> torch (1, C, H, W) float32 in [-1, 1]
        image = np.array(image.convert("RGB"))
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
        return image

    def get_dummy_inputs(self, device="cpu", seed=0):
        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
        # offset the pixel values slightly so the mask differs from the init image
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
        example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))

        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "example_image": example_image,
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

    def test_paint_by_example_inpaint(self):
        components = self.get_dummy_components()

        # make sure here that pndm scheduler skips prk
        pipe = PaintByExamplePipeline(**components)
        pipe = pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs()
        output = pipe(**inputs)
        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4701, 0.5555, 0.3994, 0.5107, 0.5691, 0.4517, 0.5125, 0.4769, 0.4539])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_paint_by_example_image_tensor(self):
        device = "cpu"
        inputs = self.get_dummy_inputs()
        inputs.pop("mask_image")
        image = self.convert_to_pt(inputs.pop("image"))
        mask_image = image.clamp(0, 1) / 2  # values in [0, 0.5]

        # make sure here that pndm scheduler skips prk
        pipe = PaintByExamplePipeline(**self.get_dummy_components())
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        output = pipe(image=image, mask_image=mask_image[:, 0], **inputs)
        out_1 = output.images

        image = image.cpu().permute(0, 2, 3, 1)[0]
        mask_image = mask_image.cpu().permute(0, 2, 3, 1)[0]

        image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB")

        output = pipe(**self.get_dummy_inputs())
        out_2 = output.images

        # tensor and PIL inputs should produce near-identical results
        assert out_1.shape == (1, 64, 64, 3)
        assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2


@slow
@require_torch_gpu
class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_paint_by_example(self):
        # make sure here that pndm scheduler skips prk
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/paint_by_example/dog_in_bucket.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/paint_by_example/mask.png"
        )
        example_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/paint_by_example/panda.jpg"
        )

        pipe = PaintByExamplePipeline.from_pretrained("Fantasy-Studio/Paint-by-Example")
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(321)
        output = pipe(
            image=init_image,
            mask_image=mask_image,
            example_image=example_image,
            generator=generator,
            guidance_scale=5.0,
            num_inference_steps=50,
            output_type="np",
        )

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.4834, 0.4811, 0.4874, 0.5122, 0.5081, 0.5144, 0.5291, 0.5290, 0.5374])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2