Path: blob/main/tests/pipelines/dit/test_dit.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch

from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel
from diffusers.utils import is_xformers_available, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ...pipeline_params import (
    CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS,
    CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS,
)
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = DiTPipeline
    params = CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params - {
        "latents",
        "num_images_per_prompt",
        "callback",
        "callback_steps",
    }
    batch_params = CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS
    test_cpu_offload = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        # Tiny DiT transformer configuration so the fast tests stay cheap enough to run on CPU.
        transformer = Transformer2DModel(
            sample_size=16,
            num_layers=2,
            patch_size=4,
            attention_head_dim=8,
            num_attention_heads=2,
            in_channels=4,
            out_channels=8,
            attention_bias=True,
            activation_fn="gelu-approximate",
            num_embeds_ada_norm=1000,
            norm_type="ada_norm_zero",
            norm_elementwise_affine=False,
        )
        vae = AutoencoderKL()
        scheduler = DDIMScheduler()
        components = {"transformer": transformer.eval(), "vae": vae.eval(), "scheduler": scheduler}
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            # Device-placed generators are not supported on MPS, so seed the global RNG instead.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "class_labels": [1],
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        self.assertEqual(image.shape, (1, 16, 16, 3))
        expected_slice = np.array([0.4380, 0.4141, 0.5159, 0.0000, 0.4282, 0.6680, 0.5485, 0.2545, 0.6719])
        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(relax_max_difference=True, expected_max_diff=1e-3)

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_attention_forwardGenerator_pass(self):
        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)


@require_torch_gpu
@slow
class DiTPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # Free GPU memory between the slow integration tests.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_dit_256(self):
        generator = torch.manual_seed(0)

        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
        pipe.to("cuda")

        words = ["vase", "umbrella", "white shark", "white wolf"]
        ids = pipe.get_label_ids(words)

        images = pipe(ids, generator=generator, num_inference_steps=40, output_type="np").images

        for word, image in zip(words, images):
            # Compare each generated sample against a reference image stored on the Hub.
            expected_image = load_numpy(
                f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}.npy"
            )
            assert np.abs((expected_image - image).max()) < 1e-2

    def test_dit_512(self):
        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512")
        # Run the 512x512 checkpoint with DPMSolverMultistepScheduler instead of the default scheduler.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        words = ["vase", "umbrella"]
        ids = pipe.get_label_ids(words)

        generator = torch.manual_seed(0)
        images = pipe(ids, generator=generator, num_inference_steps=25, output_type="np").images

        for word, image in zip(words, images):
            expected_image = load_numpy(
                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
                f"/dit/{word}_512.npy"
            )

            assert np.abs((expected_image - image).max()) < 1e-1