Path: tests/pipelines/text_to_video/test_text_to_video.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    TextToVideoSDPipeline,
    UNet3DConditionModel,
)
from diffusers.utils import load_numpy, skip_mps, slow

from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


@skip_mps
class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = TextToVideoSDPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    # No `output_type`.
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback",
            "callback_steps",
        ]
    )

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet3DConditionModel(
            block_out_channels=(32, 64, 64, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"),
            up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
            cross_attention_dim=32,
            attention_head_dim=4,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=512,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "pt",
        }
        return inputs

    def test_text_to_video_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = TextToVideoSDPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["output_type"] = "np"
        frames = sd_pipe(**inputs).frames
        image_slice = frames[0][-3:, -3:, -1]

        assert frames[0].shape == (64, 64, 3)
        expected_slice = np.array([166, 184, 167, 118, 102, 123, 108, 93, 114])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)

    # (todo): sayakpaul
    @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
    def test_inference_batch_consistent(self):
        pass

    # (todo): sayakpaul
    @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
    def test_inference_batch_single_identical(self):
        pass

    @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
    def test_num_images_per_prompt(self):
        pass

    def test_progress_bar(self):
        return super().test_progress_bar()


@slow
@skip_mps
class TextToVideoSDPipelineSlowTests(unittest.TestCase):
    def test_full_model(self):
        expected_video = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video.npy"
        )

        pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to("cuda")

        prompt = "Spiderman is surfing"
        generator = torch.Generator(device="cpu").manual_seed(0)

        video_frames = pipe(prompt, generator=generator, num_inference_steps=25, output_type="pt").frames
        video = video_frames.cpu().numpy()

        assert np.abs(expected_video - video).mean() < 5e-2

    def test_two_step_model(self):
        expected_video = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_to_video/video_2step.npy"
        )

        pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
        pipe = pipe.to("cuda")

        prompt = "Spiderman is surfing"
        generator = torch.Generator(device="cpu").manual_seed(0)

        video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="pt").frames
        video = video_frames.cpu().numpy()

        assert np.abs(expected_video - video).mean() < 5e-2
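

# A minimal standalone usage sketch mirroring `test_two_step_model` above, so the
# module can be run directly to exercise the pipeline outside the test runner.
# Illustrative only: it assumes a CUDA device and network access to download the
# `damo-vilab/text-to-video-ms-1.7b` checkpoint, and is not collected by unittest.
if __name__ == "__main__":
    pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b").to("cuda")
    generator = torch.Generator(device="cpu").manual_seed(0)
    # `output_type="np"` returns decoded frames as numpy arrays, as in the fast test.
    frames = pipe("Spiderman is surfing", generator=generator, num_inference_steps=2, output_type="np").frames
    print(f"Generated {len(frames)} frames of shape {frames[0].shape}")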