Path: blob/main/tests/pipelines/stable_unclip/test_stable_unclip.py
import gc
import unittest

import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DDPMScheduler,
    PriorTransformer,
    StableUnCLIPPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, slow, torch_device

from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference


class StableUnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableUnCLIPPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

    # TODO(will) Expected attn_bias.stride(1) == 0 to be true, but got false
    test_xformers_attention = False

    def get_dummy_components(self):
        embedder_hidden_size = 32
        embedder_projection_dim = embedder_hidden_size

        # prior components

        torch.manual_seed(0)
        prior_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        torch.manual_seed(0)
        prior_text_encoder = CLIPTextModelWithProjection(
            CLIPTextConfig(
                bos_token_id=0,
                eos_token_id=2,
                hidden_size=embedder_hidden_size,
                projection_dim=embedder_projection_dim,
                intermediate_size=37,
                layer_norm_eps=1e-05,
                num_attention_heads=4,
                num_hidden_layers=5,
                pad_token_id=1,
                vocab_size=1000,
            )
        )

        torch.manual_seed(0)
        prior = PriorTransformer(
            num_attention_heads=2,
            attention_head_dim=12,
            embedding_dim=embedder_projection_dim,
            num_layers=1,
        )

        torch.manual_seed(0)
        prior_scheduler = DDPMScheduler(
            variance_type="fixed_small_log",
            prediction_type="sample",
            num_train_timesteps=1000,
            clip_sample=True,
            clip_sample_range=5.0,
            beta_schedule="squaredcos_cap_v2",
        )

        # regular denoising components

        torch.manual_seed(0)
        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedder_hidden_size)
        image_noising_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2")

        torch.manual_seed(0)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        torch.manual_seed(0)
        text_encoder = CLIPTextModel(
            CLIPTextConfig(
                bos_token_id=0,
                eos_token_id=2,
                hidden_size=embedder_hidden_size,
                projection_dim=32,
                intermediate_size=37,
                layer_norm_eps=1e-05,
                num_attention_heads=4,
                num_hidden_layers=5,
                pad_token_id=1,
                vocab_size=1000,
            )
        )

        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
            up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
            block_out_channels=(32, 64),
            attention_head_dim=(2, 4),
            class_embed_type="projection",
            # The class embeddings are the noise-augmented image embeddings,
            # i.e. the image embeddings concatenated with the noised embeddings of the same dimension.
            projection_class_embeddings_input_dim=embedder_projection_dim * 2,
            cross_attention_dim=embedder_hidden_size,
            layers_per_block=1,
            upcast_attention=True,
            use_linear_projection=True,
        )

        torch.manual_seed(0)
        scheduler = DDIMScheduler(
            beta_schedule="scaled_linear",
            beta_start=0.00085,
            beta_end=0.012,
            prediction_type="v_prediction",
            set_alpha_to_one=False,
            steps_offset=1,
        )

        torch.manual_seed(0)
        vae = AutoencoderKL()

        components = {
            # prior components
            "prior_tokenizer": prior_tokenizer,
            "prior_text_encoder": prior_text_encoder,
            "prior": prior,
            "prior_scheduler": prior_scheduler,
            # image noising components
            "image_normalizer": image_normalizer,
            "image_noising_scheduler": image_noising_scheduler,
            # regular denoising components
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
        }

        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "prior_num_inference_steps": 2,
            "output_type": "numpy",
        }
        return inputs

    # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
    # because UnCLIP GPU non-determinism requires a looser check.
    def test_attention_slicing_forward_pass(self):
        test_max_difference = torch_device == "cpu"

        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)

    # Overriding PipelineTesterMixin::test_inference_batch_single_identical
    # because UnCLIP non-determinism requires a looser check.
    def test_inference_batch_single_identical(self):
        test_max_difference = torch_device in ["cpu", "mps"]

        self._test_inference_batch_single_identical(test_max_difference=test_max_difference)


@slow
@require_torch_gpu
class StableUnCLIPPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_unclip(self):
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/stable_unclip_2_1_l_anime_turtle_fp16.npy"
        )

        pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        # stable unclip will OOM when integration tests are run on a V100,
        # so turn on memory savings
        pipe.enable_attention_slicing()
        pipe.enable_sequential_cpu_offload()

        generator = torch.Generator(device="cpu").manual_seed(0)
        output = pipe("anime turle", generator=generator, output_type="np")

        image = output.images[0]

        assert image.shape == (768, 768, 3)

        assert_mean_pixel_difference(image, expected_image)

    def test_stable_unclip_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()
        pipe.enable_sequential_cpu_offload()

        _ = pipe(
            "anime turtle",
            prior_num_inference_steps=2,
            num_inference_steps=2,
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 7 GB is allocated
        assert mem_bytes < 7 * 10**9