Path: blob/main/tests/pipelines/altdiffusion/test_alt_diffusion.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, XLMRobertaTokenizer

from diffusers import AltDiffusionPipeline, AutoencoderKL, DDIMScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = AltDiffusionPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )

        # TODO: address the non-deterministic text encoder (fails for save-load tests)
        # torch.manual_seed(0)
        # text_encoder_config = RobertaSeriesConfig(
        #     hidden_size=32,
        #     project_dim=32,
        #     intermediate_size=37,
        #     layer_norm_eps=1e-05,
        #     num_attention_heads=4,
        #     num_hidden_layers=5,
        #     vocab_size=5002,
        # )
        # text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)

        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            projection_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=5002,
        )
        text_encoder = CLIPTextModel(text_encoder_config)

        tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
        tokenizer.model_max_length = 77

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

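    # Note: get_dummy_components builds a tiny CLIP text encoder as a stand-in; the
    # DDIM/PNDM tests below swap in RobertaSeriesModelWithTransformation so the
    # AltDiffusion-specific text encoder path is still exercised (see the TODO above).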
generator,121"num_inference_steps": 2,122"guidance_scale": 6.0,123"output_type": "numpy",124}125return inputs126127def test_alt_diffusion_ddim(self):128device = "cpu" # ensure determinism for the device-dependent torch.Generator129130components = self.get_dummy_components()131torch.manual_seed(0)132text_encoder_config = RobertaSeriesConfig(133hidden_size=32,134project_dim=32,135intermediate_size=37,136layer_norm_eps=1e-05,137num_attention_heads=4,138num_hidden_layers=5,139vocab_size=5002,140)141# TODO: remove after fixing the non-deterministic text encoder142text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)143components["text_encoder"] = text_encoder144145alt_pipe = AltDiffusionPipeline(**components)146alt_pipe = alt_pipe.to(device)147alt_pipe.set_progress_bar_config(disable=None)148149inputs = self.get_dummy_inputs(device)150inputs["prompt"] = "A photo of an astronaut"151output = alt_pipe(**inputs)152image = output.images153image_slice = image[0, -3:, -3:, -1]154155assert image.shape == (1, 64, 64, 3)156expected_slice = np.array(157[0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]158)159160assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2161162def test_alt_diffusion_pndm(self):163device = "cpu" # ensure determinism for the device-dependent torch.Generator164165components = self.get_dummy_components()166components["scheduler"] = PNDMScheduler(skip_prk_steps=True)167torch.manual_seed(0)168text_encoder_config = RobertaSeriesConfig(169hidden_size=32,170project_dim=32,171intermediate_size=37,172layer_norm_eps=1e-05,173num_attention_heads=4,174num_hidden_layers=5,175vocab_size=5002,176)177# TODO: remove after fixing the non-deterministic text encoder178text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)179components["text_encoder"] = text_encoder180alt_pipe = AltDiffusionPipeline(**components)181alt_pipe = alt_pipe.to(device)182alt_pipe.set_progress_bar_config(disable=None)183184inputs = self.get_dummy_inputs(device)185output = alt_pipe(**inputs)186image = output.images187image_slice = image[0, -3:, -3:, -1]188189assert image.shape == (1, 64, 64, 3)190expected_slice = np.array(191[0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]192)193194assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2195196197@slow198@require_torch_gpu199class AltDiffusionPipelineIntegrationTests(unittest.TestCase):200def tearDown(self):201# clean up the VRAM after each test202super().tearDown()203gc.collect()204torch.cuda.empty_cache()205206def test_alt_diffusion(self):207# make sure here that pndm scheduler skips prk208alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", safety_checker=None)209alt_pipe = alt_pipe.to(torch_device)210alt_pipe.set_progress_bar_config(disable=None)211212prompt = "A painting of a squirrel eating a burger"213generator = torch.manual_seed(0)214output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np")215216image = output.images217218image_slice = image[0, -3:, -3:, -1]219220assert image.shape == (1, 512, 512, 3)221expected_slice = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586])222223assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2224225def test_alt_diffusion_fast_ddim(self):226scheduler = DDIMScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler")227228alt_pipe = 
AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=scheduler, safety_checker=None)229alt_pipe = alt_pipe.to(torch_device)230alt_pipe.set_progress_bar_config(disable=None)231232prompt = "A painting of a squirrel eating a burger"233generator = torch.manual_seed(0)234235output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")236image = output.images237238image_slice = image[0, -3:, -3:, -1]239240assert image.shape == (1, 512, 512, 3)241expected_slice = np.array([0.4019, 0.4052, 0.3810, 0.4119, 0.3916, 0.3982, 0.4651, 0.4195, 0.5323])242243assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2244245246