Path: blob/main/tests/pipelines/vq_diffusion/test_vq_diffusion.py
1448 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


# TF32 matmuls would perturb the reference slices below, so force full precision.
torch.backends.cuda.matmul.allow_tf32 = False


class VQDiffusionPipelineFastTests(unittest.TestCase):
    """CPU-only smoke tests for VQDiffusionPipeline built from tiny dummy components."""

    def tearDown(self):
        # Free VRAM between tests so later tests are not starved of GPU memory.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def num_embed(self):
        # Size of the VQ codebook shared by the VQ-VAE and the transformer.
        return 12

    @property
    def num_embeds_ada_norm(self):
        # Number of diffusion timesteps the AdaLayerNorm embedding covers.
        return 12

    @property
    def text_embedder_hidden_size(self):
        # Hidden size of the tiny CLIP text encoder (matches cross_attention_dim).
        return 32

    @property
    def dummy_vqvae(self):
        """A tiny VQModel with a deterministic (seed 0) initialization."""
        torch.manual_seed(0)
        return VQModel(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=3,
            num_vq_embeddings=self.num_embed,
            vq_embed_dim=3,
        )

    @property
    def dummy_tokenizer(self):
        """The tiny random CLIP tokenizer used across diffusers fast tests."""
        return CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

    @property
    def dummy_text_encoder(self):
        """A tiny CLIPTextModel with a deterministic (seed 0) initialization."""
        torch.manual_seed(0)
        text_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=self.text_embedder_hidden_size,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(text_config)

    @property
    def dummy_transformer(self):
        """A tiny Transformer2DModel operating on a 12x12 latent grid (seed 0)."""
        torch.manual_seed(0)

        height = 12
        width = 12

        return Transformer2DModel(
            attention_bias=True,
            cross_attention_dim=32,
            attention_head_dim=height * width,
            num_attention_heads=1,
            num_vector_embeds=self.num_embed,
            num_embeds_ada_norm=self.num_embeds_ada_norm,
            norm_num_groups=32,
            sample_size=width,
            activation_fn="geglu-approximate",
        )

    def test_vq_diffusion(self):
        """End-to-end run without learned classifier-free sampling embeddings."""
        device = "cpu"

        pipeline = VQDiffusionPipeline(
            vqvae=self.dummy_vqvae,
            text_encoder=self.dummy_text_encoder,
            tokenizer=self.dummy_tokenizer,
            transformer=self.dummy_transformer,
            scheduler=VQDiffusionScheduler(self.num_embed),
            learned_classifier_free_sampling_embeddings=LearnedClassifierFreeSamplingEmbeddings(learnable=False),
        ).to(device)
        pipeline.set_progress_bar_config(disable=None)

        prompt = "teddy bear playing in the pool"

        # Run once returning the output dataclass ...
        gen = torch.Generator(device=device).manual_seed(0)
        images = pipeline([prompt], generator=gen, num_inference_steps=2, output_type="np").images

        # ... and once more (same seed) returning a plain tuple; both paths must agree.
        gen = torch.Generator(device=device).manual_seed(0)
        tuple_images = pipeline(
            [prompt], generator=gen, output_type="np", return_dict=False, num_inference_steps=2
        )[0]

        slice_dict = images[0, -3:, -3:, -1]
        slice_tuple = tuple_images[0, -3:, -3:, -1]

        assert images.shape == (1, 24, 24, 3)

        expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880])

        assert np.abs(slice_dict.flatten() - expected_slice).max() < 1e-2
        assert np.abs(slice_tuple.flatten() - expected_slice).max() < 1e-2

    def test_vq_diffusion_classifier_free_sampling(self):
        """End-to-end run with learnable classifier-free sampling embeddings."""
        device = "cpu"

        tokenizer = self.dummy_tokenizer
        cf_embeddings = LearnedClassifierFreeSamplingEmbeddings(
            learnable=True, hidden_size=self.text_embedder_hidden_size, length=tokenizer.model_max_length
        )

        pipeline = VQDiffusionPipeline(
            vqvae=self.dummy_vqvae,
            text_encoder=self.dummy_text_encoder,
            tokenizer=tokenizer,
            transformer=self.dummy_transformer,
            scheduler=VQDiffusionScheduler(self.num_embed),
            learned_classifier_free_sampling_embeddings=cf_embeddings,
        ).to(device)
        pipeline.set_progress_bar_config(disable=None)

        prompt = "teddy bear playing in the pool"

        # Dataclass return path ...
        gen = torch.Generator(device=device).manual_seed(0)
        images = pipeline([prompt], generator=gen, num_inference_steps=2, output_type="np").images

        # ... and tuple return path with the identical seed; both must match the reference.
        gen = torch.Generator(device=device).manual_seed(0)
        tuple_images = pipeline(
            [prompt], generator=gen, output_type="np", return_dict=False, num_inference_steps=2
        )[0]

        slice_dict = images[0, -3:, -3:, -1]
        slice_tuple = tuple_images[0, -3:, -3:, -1]

        assert images.shape == (1, 24, 24, 3)

        expected_slice = np.array([0.6647, 0.6531, 0.5303, 0.5891, 0.5726, 0.4439, 0.6304, 0.5564, 0.4912])

        assert np.abs(slice_dict.flatten() - expected_slice).max() < 1e-2
        assert np.abs(slice_tuple.flatten() - expected_slice).max() < 1e-2


@slow
@require_torch_gpu
class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
    """GPU integration test against the released microsoft/vq-diffusion-ithq checkpoint."""

    def tearDown(self):
        # Free VRAM between tests so later tests are not starved of GPU memory.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_vq_diffusion_classifier_free_sampling(self):
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.npy"
        )

        pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        # requires GPU generator for gumbel softmax
        # don't use GPU generator in tests though
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipeline(
            "teddy bear playing in the pool",
            num_images_per_prompt=1,
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)
        assert np.abs(expected_image - image).max() < 1e-2