Path: blob/main/tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py
1450 views
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for ``SpectrogramDiffusionPipeline``.

The fast tests build a small, single-layer pipeline from randomly initialised
components. The slow integration tests load the pretrained
``google/music-spectrogram-diffusion`` checkpoint and additionally require a
CUDA GPU, ``onnxruntime`` and ``note_seq``.
"""

import gc
import unittest

import numpy as np
import torch

from diffusers import DDPMScheduler, MidiProcessor, SpectrogramDiffusionPipeline
from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder
from diffusers.utils import require_torch_gpu, skip_mps, slow, torch_device
from diffusers.utils.testing_utils import require_note_seq, require_onnxruntime

from ...pipeline_params import TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS, TOKENS_TO_AUDIO_GENERATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


# Disable TF32 matmuls on CUDA so reduced-precision matmuls do not perturb the
# numeric comparisons against hard-coded reference values below.
torch.backends.cuda.matmul.allow_tf32 = False


MIDI_FILE = "./tests/fixtures/elise_format0.mid"


class SpectrogramDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast tests for ``SpectrogramDiffusionPipeline`` built from dummy components."""

    pipeline_class = SpectrogramDiffusionPipeline
    # Optional call params that the shared mixin would otherwise require;
    # this suite excludes them for this pipeline.
    required_optional_params = PipelineTesterMixin.required_optional_params - {
        "callback",
        "latents",
        "callback_steps",
        "output_type",
        "num_images_per_prompt",
    }
    test_attention_slicing = False
    test_cpu_offload = False
    # Pair each mixin attribute with the constant whose name matches it.
    params = TOKENS_TO_AUDIO_GENERATION_PARAMS
    batch_params = TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS

    def get_dummy_components(self):
        """Return freshly seeded single-layer encoders/decoder, a DDPM scheduler and no vocoder.

        ``melgan`` is ``None``; the dummy inputs request ``output_type="mel"``.
        """
        torch.manual_seed(0)
        notes_encoder = SpectrogramNotesEncoder(
            max_length=2048,
            vocab_size=1536,
            d_model=768,
            dropout_rate=0.1,
            num_layers=1,
            num_heads=1,
            d_kv=4,
            d_ff=2048,
            feed_forward_proj="gated-gelu",
        )

        continuous_encoder = SpectrogramContEncoder(
            input_dims=128,
            targets_context_length=256,
            d_model=768,
            dropout_rate=0.1,
            num_layers=1,
            num_heads=1,
            d_kv=4,
            d_ff=2048,
            feed_forward_proj="gated-gelu",
        )

        decoder = T5FilmDecoder(
            input_dims=128,
            targets_length=256,
            max_decoder_noise_time=20000.0,
            d_model=768,
            num_layers=1,
            num_heads=1,
            d_kv=4,
            d_ff=2048,
            dropout_rate=0.1,
        )

        scheduler = DDPMScheduler()

        components = {
            # eval() disables dropout so outputs are deterministic.
            "notes_encoder": notes_encoder.eval(),
            "continuous_encoder": continuous_encoder.eval(),
            "decoder": decoder.eval(),
            "scheduler": scheduler,
            "melgan": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Return call kwargs: one zero-padded token segment, 4 inference steps, mel output.

        On MPS the global generator from ``torch.manual_seed`` is used; other
        devices get a device-local ``torch.Generator``.
        """
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        # Zero-pad the short token sequence up to the notes encoder's
        # max_length (2048, set in get_dummy_components).
        max_length = 2048
        tokens = [1134, 90, 1135, 1133, 1080, 112, 1132, 1080, 1133, 1079, 133, 1132, 1079, 1133, 1]
        inputs = {
            "input_tokens": [tokens + [0] * (max_length - len(tokens))],
            "generator": generator,
            "num_inference_steps": 4,
            "output_type": "mel",
        }
        return inputs

    def test_spectrogram_diffusion(self):
        """Check a 3x3 corner of the generated mel against reference values."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = SpectrogramDiffusionPipeline(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = pipe(**inputs)
        mel = output.audios

        mel_slice = mel[0, -3:, -3:]

        assert mel_slice.shape == (3, 3)
        expected_slice = np.array(
            [-11.512925, -4.788215, -0.46172905, -2.051715, -10.539147, -10.970963, -9.091634, 4.0, 4.0]
        )
        assert np.abs(mel_slice.flatten() - expected_slice).max() < 1e-2

    # The following mixin tests are re-declared only to add ``skip_mps``.

    @skip_mps
    def test_save_load_local(self):
        return super().test_save_load_local()

    @skip_mps
    def test_dict_tuple_outputs_equivalent(self):
        return super().test_dict_tuple_outputs_equivalent()

    @skip_mps
    def test_save_load_optional_components(self):
        return super().test_save_load_optional_components()

    @skip_mps
    def test_attention_slicing_forward_pass(self):
        return super().test_attention_slicing_forward_pass()

    # Batch tests are disabled for this pipeline. Skip them explicitly so they
    # are reported as skipped rather than silently passing.

    @unittest.skip("Batched inference tests are disabled for SpectrogramDiffusionPipeline.")
    def test_inference_batch_single_identical(self):
        pass

    @unittest.skip("Batched inference tests are disabled for SpectrogramDiffusionPipeline.")
    def test_inference_batch_consistent(self):
        pass

    @skip_mps
    def test_progress_bar(self):
        return super().test_progress_bar()


@slow
@require_torch_gpu
@require_onnxruntime
@require_note_seq
class PipelineIntegrationTests(unittest.TestCase):
    """End-to-end tests against the pretrained ``google/music-spectrogram-diffusion`` checkpoint."""

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_callback(self):
        """Decode the intermediate mel with the checkpoint's MelGAN inside a per-step callback."""
        # TODO - test that pipeline can decode tokens in a callback
        # so that music can be played live
        device = torch_device

        pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
        # Detach the vocoder from the pipeline: the callback decodes mel itself
        # and the pipeline call below requests output_type="mel".
        melgan = pipe.melgan
        pipe.melgan = None

        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        def callback(step, mel_output):
            # decode mel to audio
            audio = melgan(input_features=mel_output.astype(np.float32))[0]
            # The decoded audio is expected to grow by 81920 samples per step.
            assert len(audio[0]) == 81920 * (step + 1)
            # simulate that audio is played
            return audio

        processor = MidiProcessor()
        input_tokens = processor(MIDI_FILE)

        # Only use the first 3 token segments to bound runtime.
        input_tokens = input_tokens[:3]
        generator = torch.manual_seed(0)
        pipe(input_tokens, num_inference_steps=5, generator=generator, callback=callback, output_type="mel")

    def test_spectrogram_fast(self):
        """Short run (2 token segments, 2 inference steps) checked against a reference audio sum."""
        device = torch_device

        pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)
        processor = MidiProcessor()

        input_tokens = processor(MIDI_FILE)
        # Only use the first 2 token segments to bound runtime.
        input_tokens = input_tokens[:2]

        generator = torch.manual_seed(0)
        output = pipe(input_tokens, num_inference_steps=2, generator=generator)

        audio = output.audios[0]

        assert abs(np.abs(audio).sum() - 3612.841) < 1e-1

    def test_spectrogram(self):
        """Longer run (4 token segments, 100 inference steps) checked against a reference audio sum."""
        device = torch_device

        pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        processor = MidiProcessor()

        input_tokens = processor(MIDI_FILE)

        # Only use the first 4 token segments to bound runtime.
        input_tokens = input_tokens[:4]

        generator = torch.manual_seed(0)
        output = pipe(input_tokens, num_inference_steps=100, generator=generator)

        audio = output.audios[0]
        assert abs(np.abs(audio).sum() - 9389.1111) < 5e-2