Path: blob/main/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
import gc
import random
import unittest

import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionConfig,
    CLIPVisionModelWithProjection,
)

from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableUnCLIPImg2ImgPipeline, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import floats_tensor, load_image, load_numpy, require_torch_gpu, slow, torch_device

from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ...test_pipelines_common import (
    PipelineTesterMixin,
    assert_mean_pixel_difference,
)


class StableUnCLIPImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableUnCLIPImg2ImgPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS

    def get_dummy_components(self):
        embedder_hidden_size = 32
        embedder_projection_dim = embedder_hidden_size

        # image encoding components

        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

        image_encoder = CLIPVisionModelWithProjection(
            CLIPVisionConfig(
                hidden_size=embedder_hidden_size,
                projection_dim=embedder_projection_dim,
                num_hidden_layers=5,
                num_attention_heads=4,
                image_size=32,
                intermediate_size=37,
                patch_size=1,
            )
        )

        # image noising components

        torch.manual_seed(0)
        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedder_hidden_size)
        image_noising_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2")

        # regular denoising components

        torch.manual_seed(0)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        torch.manual_seed(0)
        text_encoder = CLIPTextModel(
            CLIPTextConfig(
                bos_token_id=0,
                eos_token_id=2,
                hidden_size=embedder_hidden_size,
                projection_dim=32,
                intermediate_size=37,
                layer_norm_eps=1e-05,
                num_attention_heads=4,
                num_hidden_layers=5,
                pad_token_id=1,
                vocab_size=1000,
            )
        )

        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
            up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
            block_out_channels=(32, 64),
            attention_head_dim=(2, 4),
            class_embed_type="projection",
            # The class embeddings are the noise-augmented image embeddings,
            # i.e. the image embeddings concatenated with the noised embeddings of the same dimension.
            projection_class_embeddings_input_dim=embedder_projection_dim * 2,
            cross_attention_dim=embedder_hidden_size,
            layers_per_block=1,
            upcast_attention=True,
            use_linear_projection=True,
        )

        torch.manual_seed(0)
        scheduler = DDIMScheduler(
            beta_schedule="scaled_linear",
            beta_start=0.00085,
            beta_end=0.012,
            prediction_type="v_prediction",
            set_alpha_to_one=False,
            steps_offset=1,
        )

        torch.manual_seed(0)
        vae = AutoencoderKL()

        components = {
            # image encoding components
            "feature_extractor": feature_extractor,
            "image_encoder": image_encoder,
            # image noising components
            "image_normalizer": image_normalizer,
            "image_noising_scheduler": image_noising_scheduler,
            # regular denoising components
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
        }

        return components

    def get_dummy_inputs(self, device, seed=0, pil_image=True):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)

        if pil_image:
            input_image = input_image * 0.5 + 0.5
            input_image = input_image.clamp(0, 1)
            input_image = input_image.cpu().permute(0, 2, 3, 1).float().numpy()
            input_image = DiffusionPipeline.numpy_to_pil(input_image)[0]

        return {
            "prompt": "An anime raccoon running a marathon",
            "image": input_image,
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "np",
        }

    # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
    # because GPU non-determinism requires a looser check.
    def test_attention_slicing_forward_pass(self):
        test_max_difference = torch_device in ["cpu", "mps"]

        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)

    # Overriding PipelineTesterMixin::test_inference_batch_single_identical
    # because non-determinism requires a looser check.
    def test_inference_batch_single_identical(self):
        test_max_difference = torch_device in ["cpu", "mps"]

        self._test_inference_batch_single_identical(test_max_difference=test_max_difference)

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_attention_forwardGenerator_pass(self):
        self._test_xformers_attention_forwardGenerator_pass(test_max_difference=False)


@slow
@require_torch_gpu
class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_unclip_l_img2img(self):
        input_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/turtle.png"
        )

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/stable_unclip_2_1_l_img2img_anime_turtle_fp16.npy"
        )

        pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
            "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        # Stable unCLIP will OOM when the integration tests are run on a V100,
        # so turn on memory savings.
        pipe.enable_attention_slicing()
        pipe.enable_sequential_cpu_offload()

        generator = torch.Generator(device="cpu").manual_seed(0)
        output = pipe("anime turle", image=input_image, generator=generator, output_type="np")

        image = output.images[0]

        assert image.shape == (768, 768, 3)

        assert_mean_pixel_difference(image, expected_image)

    def test_stable_unclip_h_img2img(self):
        input_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/turtle.png"
        )

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/stable_unclip_2_1_h_img2img_anime_turtle_fp16.npy"
        )

        pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
            "fusing/stable-unclip-2-1-h-img2img", torch_dtype=torch.float16
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        # Stable unCLIP will OOM when the integration tests are run on a V100,
        # so turn on memory savings.
        pipe.enable_attention_slicing()
        pipe.enable_sequential_cpu_offload()

        generator = torch.Generator(device="cpu").manual_seed(0)
        output = pipe("anime turle", image=input_image, generator=generator, output_type="np")

        image = output.images[0]

        assert image.shape == (768, 768, 3)

        assert_mean_pixel_difference(image, expected_image)

    def test_stable_unclip_img2img_pipeline_with_sequential_cpu_offloading(self):
        input_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/turtle.png"
        )

        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
            "fusing/stable-unclip-2-1-h-img2img", torch_dtype=torch.float16
        )
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()
        pipe.enable_sequential_cpu_offload()

        _ = pipe(
            "anime turtle",
            image=input_image,
            num_inference_steps=2,
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 7 GB is allocated
        assert mem_bytes < 7 * 10**9
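

# A minimal sketch (not part of the upstream test module) of how the dummy components above can be
# exercised by hand when debugging the fast tests. It assumes a CPU-only run; the `__main__` entry
# point, variable names, and the printed shape check below are illustrative additions, not part of
# the test suite.
if __name__ == "__main__":
    fast_tests = StableUnCLIPImg2ImgPipelineFastTests()

    # Build the pipeline from the same tiny CLIP / UNet / VAE components the fast tests use.
    pipe = StableUnCLIPImg2ImgPipeline(**fast_tests.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    # Run the two denoising steps defined by get_dummy_inputs on the 32x32 dummy input image.
    inputs = fast_tests.get_dummy_inputs("cpu")
    image = pipe(**inputs).images[0]

    # With these dummy components the output should be a small RGB numpy array (on the order of
    # 32x32x3), which is enough to confirm the pipeline wiring end to end.
    print(image.shape)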