Path: blob/main/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from transformers import XLMRobertaTokenizer

from diffusers import (
    AltDiffusionImg2ImgPipeline,
    AutoencoderKL,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


# disable TF32 matmuls so CUDA results stay numerically reproducible across runs
torch.backends.cuda.matmul.allow_tf32 = False


class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_cond_unet(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = RobertaSeriesConfig(
            hidden_size=32,
            project_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=5006,
        )
        return RobertaSeriesModelWithTransformation(config)

    @property
    def dummy_extractor(self):
        # stub feature extractor; the safety checker is disabled in these tests,
        # so an empty pixel_values tensor is sufficient
        def extract(*args, **kwargs):
            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])

                def to(self, device):
                    # tensor.to() is not in-place, so reassign the result
                    self.pixel_values = self.pixel_values.to(device)
                    return self

            return Out()

        return extract

    def test_stable_diffusion_img2img_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
        tokenizer.model_max_length = 77

        init_image = self.dummy_image.to(device)

        # make sure here that pndm scheduler skips prk
        alt_pipe = AltDiffusionImg2ImgPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=device).manual_seed(0)
        output = alt_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
        )

        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = alt_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4115, 0.3870, 0.4089, 0.4807, 0.4668, 0.4144, 0.4151, 0.4721, 0.4569])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-3

    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
    def test_stable_diffusion_img2img_fp16(self):
        """Test that stable diffusion img2img works with fp16"""
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
        tokenizer.model_max_length = 77

        init_image = self.dummy_image.to(torch_device)

        # put models in fp16
        unet = unet.half()
        vae = vae.half()
        bert = bert.half()

        # make sure here that pndm scheduler skips prk
        alt_pipe = AltDiffusionImg2ImgPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
        alt_pipe = alt_pipe.to(torch_device)
        alt_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = alt_pipe(
            [prompt],
            generator=generator,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
        ).images

        assert image.shape == (1, 32, 32, 3)

    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
    def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        )
        # resize to resolution that is divisible by 8 but not 16 or 32
        init_image = init_image.resize((760, 504))

        model_id = "BAAI/AltDiffusion"
        pipe = AltDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            safety_checker=None,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "A fantasy landscape, trending on artstation"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        image_slice = image[255:258, 383:386, -1]

        assert image.shape == (504, 760, 3)
        expected_slice = np.array([0.9358, 0.9397, 0.9599, 0.9901, 1.0000, 1.0000, 0.9882, 1.0000, 1.0000])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3


@slow
@require_torch_gpu
class AltDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_img2img_pipeline_default(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        )
        init_image = init_image.resize((768, 512))
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_alt.npy"
        )

        model_id = "BAAI/AltDiffusion"
        pipe = AltDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            safety_checker=None,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "A fantasy landscape, trending on artstation"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 768, 3)
        # img2img is flaky across GPUs even in fp32, so compare the max absolute difference here
        assert np.abs(expected_image - image).max() < 1e-3
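
if __name__ == "__main__":
    # Optional convenience entry point (an assumption, not part of the original
    # module): lets the file be run directly, e.g.
    # `python test_alt_diffusion_img2img.py`. In the diffusers CI these tests
    # are normally collected by pytest, and the @slow integration tests above
    # are skipped unless the RUN_SLOW environment variable is set to a truthy
    # value.
    unittest.main()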