Path: blob/main/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
import torch

from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
from diffusers.utils import PIL_INTERPOLATION, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch


torch.backends.cuda.matmul.allow_tf32 = False


class LDMSuperResolutionPipelineFastTests(unittest.TestCase):
    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_uncond_unet(self):
        torch.manual_seed(0)
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=6,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        return model

    @property
    def dummy_vq_model(self):
        torch.manual_seed(0)
        model = VQModel(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=3,
        )
        return model

    def test_inference_superresolution(self):
        device = "cpu"
        unet = self.dummy_uncond_unet
        scheduler = DDIMScheduler()
        vqvae = self.dummy_vq_model

        ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
        ldm.to(device)
        ldm.set_progress_bar_config(disable=None)

        init_image = self.dummy_image.to(device)

        generator = torch.Generator(device=device).manual_seed(0)
        image = ldm(image=init_image, generator=generator, num_inference_steps=2, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
    def test_inference_superresolution_fp16(self):
        unet = self.dummy_uncond_unet
        scheduler = DDIMScheduler()
        vqvae = self.dummy_vq_model

        # put models in fp16
        unet = unet.half()
        vqvae = vqvae.half()

        ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
        ldm.to(torch_device)
        ldm.set_progress_bar_config(disable=None)

        init_image = self.dummy_image.to(torch_device)

        image = ldm(init_image, num_inference_steps=2, output_type="numpy").images

        assert image.shape == (1, 64, 64, 3)


@slow
@require_torch
class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
    def test_inference_superresolution(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/vq_diffusion/teddy_bear_pool.png"
        )
        init_image = init_image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"])

        ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
        ldm.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = ldm(image=init_image, generator=generator, num_inference_steps=20, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.7644, 0.7679, 0.7642, 0.7633, 0.7666, 0.7560, 0.7425, 0.7257, 0.6907])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
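
For reference, a minimal standalone sketch of the usage the integration test above exercises, assuming the same "duongna/ldm-super-resolution" checkpoint and test image; the output filename is illustrative, not part of the test suite:

# Standalone usage sketch (not part of the test file): mirrors the
# integration test above with the pipeline's default PIL output.
import torch
from diffusers import LDMSuperResolutionPipeline
from diffusers.utils import load_image

pipe = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Same 64x64 low-resolution input the integration test prepares.
low_res = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/vq_diffusion/teddy_bear_pool.png"
).resize((64, 64))

# Default output_type is "pil"; for this checkpoint the test asserts
# that a 64x64 input comes back as a 256x256 image.
upscaled = pipe(image=low_res, num_inference_steps=20).images[0]
upscaled.save("ldm_superres_teddy_bear.png")  # illustrative output path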