Path: blob/main/tests/pipelines/ddpm/test_ddpm.py
1450 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device


# Disable TF32 matmuls so CUDA results stay bit-comparable with the pinned
# reference slices below.
torch.backends.cuda.matmul.allow_tf32 = False


class DDPMPipelineFastTests(unittest.TestCase):
    """Fast smoke tests for ``DDPMPipeline`` using a tiny randomly-initialized UNet."""

    @property
    def dummy_uncond_unet(self):
        """Return a small unconditional UNet, seeded for deterministic weights."""
        torch.manual_seed(0)
        return UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )

    def test_fast_inference(self):
        """Two-step CPU sampling matches a pinned slice; dict and tuple returns agree."""
        device = "cpu"
        pipe = DDPMPipeline(unet=self.dummy_uncond_unet, scheduler=DDPMScheduler())
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        # Same seed for both runs so the two call styles produce the same image.
        generator = torch.Generator(device=device).manual_seed(0)
        image = pipe(generator=generator, num_inference_steps=2, output_type="numpy").images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = pipe(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array(
            [9.956e-01, 5.785e-01, 4.675e-01, 9.930e-01, 0.0, 1.000, 1.199e-03, 2.648e-04, 5.101e-04]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_inference_predict_sample(self):
        """With ``prediction_type="sample"``, dict and tuple outputs stay consistent."""
        pipe = DDPMPipeline(unet=self.dummy_uncond_unet, scheduler=DDPMScheduler(prediction_type="sample"))
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = pipe(generator=generator, num_inference_steps=2, output_type="numpy").images

        generator = torch.manual_seed(0)
        image_eps = pipe(generator=generator, num_inference_steps=2, output_type="numpy")[0]

        image_slice = image[0, -3:, -3:, -1]
        image_eps_slice = image_eps[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        # MPS float math is noisier; relax the tolerance there.
        tolerance = 1e-2 if torch_device != "mps" else 3e-2
        assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance


@slow
@require_torch_gpu
class DDPMPipelineIntegrationTests(unittest.TestCase):
    """Slow GPU tests against the pretrained ``google/ddpm-cifar10-32`` checkpoint."""

    def test_inference_cifar10(self):
        """Full sampling from the pretrained CIFAR-10 model matches a pinned slice."""
        model_id = "google/ddpm-cifar10-32"

        pipe = DDPMPipeline(
            unet=UNet2DModel.from_pretrained(model_id),
            scheduler=DDPMScheduler.from_pretrained(model_id),
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = pipe(generator=generator, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4200, 0.3588, 0.1939, 0.3847, 0.3382, 0.2647, 0.4155, 0.3582, 0.3385])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2