Path: blob/main/tests/pipelines/stable_diffusion/test_cycle_diffusion.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = CycleDiffusionPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {
        "negative_prompt",
        "height",
        "width",
        "negative_prompt_embeds",
    }
    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})

    def get_dummy_components(self):
        # Tiny UNet, VAE, and CLIP text encoder keep the fast tests lightweight; seeding makes them deterministic.
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        # "mps" does not support device-bound generators, so fall back to the default generator there.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "An astronaut riding an elephant",
            "source_prompt": "An astronaut riding a horse",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "eta": 0.1,
            "strength": 0.8,
            "guidance_scale": 3,
            "source_guidance_scale": 1,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_cycle(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
        pipe = CycleDiffusionPipeline(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = pipe(**inputs)
        images = output.images

        image_slice = images[0, -3:, -3:, -1]

        assert images.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4459, 0.4943, 0.4544, 0.6643, 0.5474, 0.4327, 0.5701, 0.5959, 0.5179])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
    def test_stable_diffusion_cycle_fp16(self):
        components = self.get_dummy_components()
        for name, module in components.items():
            if hasattr(module, "half"):
                components[name] = module.half()
        pipe = CycleDiffusionPipeline(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)
        images = output.images

        image_slice = images[0, -3:, -3:, -1]

        assert images.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.3506, 0.4543, 0.446, 0.4575, 0.5195, 0.4155, 0.5273, 0.518, 0.4116])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @skip_mps
    def test_save_load_local(self):
        return super().test_save_load_local()

    @unittest.skip("non-deterministic pipeline")
    def test_inference_batch_single_identical(self):
        return super().test_inference_batch_single_identical()

    @skip_mps
    def test_dict_tuple_outputs_equivalent(self):
        return super().test_dict_tuple_outputs_equivalent()

    @skip_mps
    def test_save_load_optional_components(self):
        return super().test_save_load_optional_components()

    @skip_mps
    def test_attention_slicing_forward_pass(self):
        return super().test_attention_slicing_forward_pass()


@slow
@require_torch_gpu
class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_cycle_diffusion_pipeline_fp16(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/cycle-diffusion/black_colored_car.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car_fp16.npy"
        )
        init_image = init_image.resize((512, 512))

        model_id = "CompVis/stable-diffusion-v1-4"
        scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
        pipe = CycleDiffusionPipeline.from_pretrained(
            model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
        )

        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        source_prompt = "A black colored car"
        prompt = "A blue colored car"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            source_prompt=source_prompt,
            image=init_image,
            num_inference_steps=100,
            eta=0.1,
            strength=0.85,
            guidance_scale=3,
            source_guidance_scale=1,
            generator=generator,
            output_type="np",
        )
        image = output.images

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(image - expected_image).max() < 5e-1

    def test_cycle_diffusion_pipeline(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/cycle-diffusion/black_colored_car.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car.npy"
        )
        init_image = init_image.resize((512, 512))

        model_id = "CompVis/stable-diffusion-v1-4"
        scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
        pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)

        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        source_prompt = "A black colored car"
        prompt = "A blue colored car"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            source_prompt=source_prompt,
            image=init_image,
            num_inference_steps=100,
            eta=0.1,
            strength=0.85,
            guidance_scale=3,
            source_guidance_scale=1,
            generator=generator,
            output_type="np",
        )
        image = output.images

        assert np.abs(image - expected_image).max() < 1e-2