Path: tests/pipelines/audioldm/test_audioldm.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import gc
import unittest

import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
    ClapTextConfig,
    ClapTextModelWithProjection,
    RobertaTokenizer,
    SpeechT5HifiGan,
    SpeechT5HifiGanConfig,
)

from diffusers import (
    AudioLDMPipeline,
    AutoencoderKL,
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device

from ...pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = AudioLDMPipeline
    params = TEXT_TO_AUDIO_PARAMS
    batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "num_waveforms_per_prompt",
            "generator",
            "latents",
            "output_type",
            "return_dict",
            "callback",
            "callback_steps",
        ]
    )

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=(32, 64),
            class_embed_type="simple_projection",
            projection_class_embeddings_input_dim=32,
            class_embeddings_concat=True,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=1,
            out_channels=1,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = ClapTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            projection_dim=32,
        )
        text_encoder = ClapTextModelWithProjection(text_encoder_config)
        tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)

        vocoder_config = SpeechT5HifiGanConfig(
            model_in_dim=8,
            sampling_rate=16000,
            upsample_initial_channel=16,
            upsample_rates=[2, 2],
            upsample_kernel_sizes=[4, 4],
            resblock_kernel_sizes=[3, 7],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
            normalize_before=False,
        )

        vocoder = SpeechT5HifiGan(vocoder_config)

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "vocoder": vocoder,
        }
        return components
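    # Illustrative sketch (not part of the upstream suite): the dummy components
    # above can be assembled into a runnable pipeline directly, which is handy
    # when debugging a single test interactively. The tiny sizes (32/64 channels,
    # a 5-layer text encoder, 2x2 vocoder upsampling) keep a full forward pass
    # fast enough to run on CPU.
    #
    #     components = AudioLDMPipelineFastTests("test_audioldm_ddim").get_dummy_components()
    #     pipe = AudioLDMPipeline(**components).to("cpu")
    #     audio = pipe("A hammer hitting a wooden surface", num_inference_steps=2).audios[0]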
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A hammer hitting a wooden surface",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
        }
        return inputs

    def test_audioldm_ddim(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
        audioldm_pipe = AudioLDMPipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = audioldm_pipe(**inputs)
        audio = output.audios[0]

        assert audio.ndim == 1
        assert len(audio) == 256

        audio_slice = audio[:10]
        expected_slice = np.array(
            [-0.0050, 0.0050, -0.0060, 0.0033, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0033]
        )

        assert np.abs(audio_slice - expected_slice).max() < 1e-2

    def test_audioldm_prompt_embeds(self):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDMPipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = 3 * [inputs["prompt"]]

        # forward
        output = audioldm_pipe(**inputs)
        audio_1 = output.audios[0]

        inputs = self.get_dummy_inputs(torch_device)
        prompt = 3 * [inputs.pop("prompt")]

        text_inputs = audioldm_pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=audioldm_pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_inputs = text_inputs["input_ids"].to(torch_device)

        prompt_embeds = audioldm_pipe.text_encoder(
            text_inputs,
        )
        prompt_embeds = prompt_embeds.text_embeds
        # additional L_2 normalization over each hidden-state
        prompt_embeds = F.normalize(prompt_embeds, dim=-1)

        inputs["prompt_embeds"] = prompt_embeds

        # forward
        output = audioldm_pipe(**inputs)
        audio_2 = output.audios[0]

        assert np.abs(audio_1 - audio_2).max() < 1e-2

    def test_audioldm_negative_prompt_embeds(self):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDMPipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        negative_prompt = 3 * ["this is a negative prompt"]
        inputs["negative_prompt"] = negative_prompt
        inputs["prompt"] = 3 * [inputs["prompt"]]

        # forward
        output = audioldm_pipe(**inputs)
        audio_1 = output.audios[0]

        inputs = self.get_dummy_inputs(torch_device)
        prompt = 3 * [inputs.pop("prompt")]

        embeds = []
        for p in [prompt, negative_prompt]:
            text_inputs = audioldm_pipe.tokenizer(
                p,
                padding="max_length",
                max_length=audioldm_pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_inputs = text_inputs["input_ids"].to(torch_device)

            text_embeds = audioldm_pipe.text_encoder(
                text_inputs,
            )
            text_embeds = text_embeds.text_embeds
            # additional L_2 normalization over each hidden-state
            text_embeds = F.normalize(text_embeds, dim=-1)

            embeds.append(text_embeds)

        inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds

        # forward
        output = audioldm_pipe(**inputs)
        audio_2 = output.audios[0]

        assert np.abs(audio_1 - audio_2).max() < 1e-2
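    # Note on the two prompt-embeds tests above: AudioLDM conditions on the
    # pooled, projected CLAP text embedding, and the pipeline applies the same
    # L2 normalization internally when it encodes prompts itself (see the
    # "additional L_2 normalization" comments). Precomputed embeddings passed in
    # via `prompt_embeds`/`negative_prompt_embeds` therefore have to be run
    # through F.normalize(..., dim=-1) as well, or the two code paths would not
    # produce matching audio.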
    def test_audioldm_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        audioldm_pipe = AudioLDMPipeline(**components)
        audioldm_pipe = audioldm_pipe.to(device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        negative_prompt = "egg cracking"
        output = audioldm_pipe(**inputs, negative_prompt=negative_prompt)
        audio = output.audios[0]

        assert audio.ndim == 1
        assert len(audio) == 256

        audio_slice = audio[:10]
        expected_slice = np.array(
            [-0.0051, 0.0050, -0.0060, 0.0034, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0032]
        )

        assert np.abs(audio_slice - expected_slice).max() < 1e-2

    def test_audioldm_num_waveforms_per_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        audioldm_pipe = AudioLDMPipeline(**components)
        audioldm_pipe = audioldm_pipe.to(device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        prompt = "A hammer hitting a wooden surface"

        # test num_waveforms_per_prompt=1 (default)
        audios = audioldm_pipe(prompt, num_inference_steps=2).audios

        assert audios.shape == (1, 256)

        # test num_waveforms_per_prompt=1 (default) for batch of prompts
        batch_size = 2
        audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios

        assert audios.shape == (batch_size, 256)

        # test num_waveforms_per_prompt for single prompt
        num_waveforms_per_prompt = 2
        audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios

        assert audios.shape == (num_waveforms_per_prompt, 256)

        # test num_waveforms_per_prompt for batch of prompts
        batch_size = 2
        audios = audioldm_pipe(
            [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
        ).audios

        assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)

    def test_audioldm_audio_length_in_s(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDMPipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate

        inputs = self.get_dummy_inputs(device)
        output = audioldm_pipe(audio_length_in_s=0.016, **inputs)
        audio = output.audios[0]

        assert audio.ndim == 1
        assert len(audio) / vocoder_sampling_rate == 0.016

        output = audioldm_pipe(audio_length_in_s=0.032, **inputs)
        audio = output.audios[0]

        assert audio.ndim == 1
        assert len(audio) / vocoder_sampling_rate == 0.032

    def test_audioldm_vocoder_model_in_dim(self):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDMPipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        prompt = ["hey"]

        output = audioldm_pipe(prompt, num_inference_steps=1)
        audio_shape = output.audios.shape
        assert audio_shape == (1, 256)

        config = audioldm_pipe.vocoder.config
        config.model_in_dim *= 2
        audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
        output = audioldm_pipe(prompt, num_inference_steps=1)
        audio_shape = output.audios.shape
        # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
        assert audio_shape == (1, 256)
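    # Back-of-the-envelope check for the 256-sample expectation used throughout
    # the fast tests: the dummy vocoder upsamples each spectrogram frame by
    # prod(upsample_rates) = 2 * 2 = 4, so a 64-frame mel spectrogram yields
    # 64 * 4 = 256 waveform samples. At the dummy 16 kHz sampling rate that is
    # 256 / 16000 = 0.016 s, matching the first audio_length_in_s case above
    # (and 0.032 s doubles it to 512 samples).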
    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(test_mean_pixel_difference=False)


@slow
# @require_torch_gpu
class AudioLDMPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
        inputs = {
            "prompt": "A hammer hitting a wooden surface",
            "latents": latents,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 2.5,
        }
        return inputs

    def test_audioldm(self):
        audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 25
        audio = audioldm_pipe(**inputs).audios[0]

        assert audio.ndim == 1
        assert len(audio) == 81920

        audio_slice = audio[77230:77240]
        expected_slice = np.array(
            [-0.4884, -0.4607, 0.0023, 0.5007, 0.5896, 0.5151, 0.3813, -0.0208, -0.3687, -0.4315]
        )
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-2

    def test_audioldm_lms(self):
        audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
        audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        audio = audioldm_pipe(**inputs).audios[0]

        assert audio.ndim == 1
        assert len(audio) == 81920

        audio_slice = audio[27780:27790]
        expected_slice = np.array([-0.2131, -0.0873, -0.0124, -0.0189, 0.0569, 0.1373, 0.1883, 0.2886, 0.3297, 0.2212])
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-2
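
# Minimal usage sketch (illustrative only; the slow tests above are the
# authoritative reference). Generating and saving a clip with the full
# "cvssp/audioldm" checkpoint might look like this:
#
#     import scipy.io.wavfile
#
#     pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm").to("cuda")
#     audio = pipe("A hammer hitting a wooden surface", num_inference_steps=25).audios[0]
#     scipy.io.wavfile.write("hammer.wav", rate=16000, data=audio)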