Path: blob/main/tests/pipelines/dance_diffusion/test_dance_diffusion.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch

from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

from ...pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = DanceDiffusionPipeline
    params = UNCONDITIONAL_AUDIO_GENERATION_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params - {
        "callback",
        "latents",
        "callback_steps",
        "output_type",
        "num_images_per_prompt",
    }
    batch_params = UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS
    test_attention_slicing = False
    test_cpu_offload = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet1DModel(
            block_out_channels=(32, 32, 64),
            extra_in_channels=16,
            sample_size=512,
            sample_rate=16_000,
            in_channels=2,
            out_channels=2,
            flip_sin_to_cos=True,
            use_timestep_embedding=False,
            time_embedding_type="fourier",
            mid_block_type="UNetMidBlock1D",
            down_block_types=("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
            up_block_types=("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        )
        scheduler = IPNDMScheduler()

        components = {
            "unet": unet,
            "scheduler": scheduler,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "batch_size": 1,
            "generator": generator,
            "num_inference_steps": 4,
        }
        return inputs

    def test_dance_diffusion(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = DanceDiffusionPipeline(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = pipe(**inputs)
        audio = output.audios

        audio_slice = audio[0, -3:, -3:]

        assert audio.shape == (1, 2, components["unet"].sample_size)
        expected_slice = np.array([-0.7265, 1.0000, -0.8388, 0.1175, 0.9498, -1.0000])
        assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2

    @skip_mps
    def test_save_load_local(self):
        return super().test_save_load_local()

    @skip_mps
    def test_dict_tuple_outputs_equivalent(self):
        return super().test_dict_tuple_outputs_equivalent()

    @skip_mps
    def test_save_load_optional_components(self):
        return super().test_save_load_optional_components()

    @skip_mps
    def test_attention_slicing_forward_pass(self):
        return super().test_attention_slicing_forward_pass()


@slow
@require_torch_gpu
class PipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_dance_diffusion(self):
        device = torch_device

        pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
        audio = output.audios

        audio_slice = audio[0, -3:, -3:]

        assert audio.shape == (1, 2, pipe.unet.sample_size)
        expected_slice = np.array([-0.0192, -0.0231, -0.0318, -0.0059, 0.0002, -0.0020])

        assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2

    def test_dance_diffusion_fp16(self):
        device = torch_device

        pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
        audio = output.audios

        audio_slice = audio[0, -3:, -3:]

        assert audio.shape == (1, 2, pipe.unet.sample_size)
        expected_slice = np.array([-0.0367, -0.0488, -0.0771, -0.0525, -0.0444, -0.0341])

        assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
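
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original test file: it drives the same
# pipeline the integration tests exercise and writes the result to disk, which
# shows what `output.audios` actually contains. Assumes `scipy` is installed;
# the output filename is only illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from scipy.io import wavfile

    pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
    pipe = pipe.to(torch_device)

    generator = torch.manual_seed(0)
    output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)

    # `output.audios` is a float numpy array of shape (batch, channels, samples);
    # transpose the first sample to (samples, channels) for `wavfile.write`.
    audio = output.audios[0].T
    wavfile.write("dance_diffusion_sample.wav", pipe.unet.config.sample_rate, audio)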