GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch

from diffusers import DDPMScheduler, MidiProcessor, SpectrogramDiffusionPipeline
from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder
from diffusers.utils import require_torch_gpu, skip_mps, slow, torch_device
from diffusers.utils.testing_utils import require_note_seq, require_onnxruntime

from ...pipeline_params import TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS, TOKENS_TO_AUDIO_GENERATION_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


MIDI_FILE = "./tests/fixtures/elise_format0.mid"


class SpectrogramDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = SpectrogramDiffusionPipeline
    required_optional_params = PipelineTesterMixin.required_optional_params - {
        "callback",
        "latents",
        "callback_steps",
        "output_type",
        "num_images_per_prompt",
    }
    test_attention_slicing = False
    test_cpu_offload = False
    params = TOKENS_TO_AUDIO_GENERATION_PARAMS
    batch_params = TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS
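    # Note: `required_optional_params` starts from the mixin's generic
    # defaults; kwargs such as `num_images_per_prompt` have no meaning for
    # token-to-audio generation, so they are removed from the set above.
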
    def get_dummy_components(self):
        torch.manual_seed(0)
        # Single-layer, single-head T5-style modules keep the dummy pipeline
        # small enough for quick CPU runs.
        notes_encoder = SpectrogramNotesEncoder(
            max_length=2048,
            vocab_size=1536,
            d_model=768,
            dropout_rate=0.1,
            num_layers=1,
            num_heads=1,
            d_kv=4,
            d_ff=2048,
            feed_forward_proj="gated-gelu",
        )

        continuous_encoder = SpectrogramContEncoder(
            input_dims=128,
            targets_context_length=256,
            d_model=768,
            dropout_rate=0.1,
            num_layers=1,
            num_heads=1,
            d_kv=4,
            d_ff=2048,
            feed_forward_proj="gated-gelu",
        )

        decoder = T5FilmDecoder(
            input_dims=128,
            targets_length=256,
            max_decoder_noise_time=20000.0,
            d_model=768,
            num_layers=1,
            num_heads=1,
            d_kv=4,
            d_ff=2048,
            dropout_rate=0.1,
        )

        scheduler = DDPMScheduler()

        components = {
            "notes_encoder": notes_encoder.eval(),
            "continuous_encoder": continuous_encoder.eval(),
            "decoder": decoder.eval(),
            "scheduler": scheduler,
            "melgan": None,
        }
        return components
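
    # The dict returned above mirrors SpectrogramDiffusionPipeline's
    # constructor, so the fast test below can assemble a pipeline by hand:
    #
    #   pipe = SpectrogramDiffusionPipeline(**self.get_dummy_components())
    #
    # With "melgan" set to None there is no vocoder stage, which is why
    # `get_dummy_inputs` requests `output_type="mel"` rather than decoded audio.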

    def get_dummy_inputs(self, device, seed=0):
        # Device-bound generators are not supported on MPS, so fall back to
        # the global CPU generator there.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            # 15 note tokens zero-padded to the notes encoder's max_length
            # of 2048 (15 + 2033 == 2048).
            "input_tokens": [
                [1134, 90, 1135, 1133, 1080, 112, 1132, 1080, 1133, 1079, 133, 1132, 1079, 1133, 1] + [0] * 2033
            ],
            "generator": generator,
            "num_inference_steps": 4,
            "output_type": "mel",
        }
        return inputs

    def test_spectrogram_diffusion(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = SpectrogramDiffusionPipeline(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = pipe(**inputs)
        mel = output.audios

        # Compare a 3x3 corner of the generated mel spectrogram against a
        # frozen reference slice.
        mel_slice = mel[0, -3:, -3:]

        assert mel_slice.shape == (3, 3)
        expected_slice = np.array(
            [-11.512925, -4.788215, -0.46172905, -2.051715, -10.539147, -10.970963, -9.091634, 4.0, 4.0]
        )
        assert np.abs(mel_slice.flatten() - expected_slice).max() < 1e-2

    @skip_mps
    def test_save_load_local(self):
        return super().test_save_load_local()

    @skip_mps
    def test_dict_tuple_outputs_equivalent(self):
        return super().test_dict_tuple_outputs_equivalent()

    @skip_mps
    def test_save_load_optional_components(self):
        return super().test_save_load_optional_components()

    @skip_mps
    def test_attention_slicing_forward_pass(self):
        return super().test_attention_slicing_forward_pass()

    def test_inference_batch_single_identical(self):
        # Intentionally a no-op: batched inference is not exercised for this pipeline.
        pass

    def test_inference_batch_consistent(self):
        # Intentionally a no-op: batched inference is not exercised for this pipeline.
        pass

    @skip_mps
    def test_progress_bar(self):
        return super().test_progress_bar()
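

# The integration tests below need a GPU plus two optional dependencies:
# note_seq, which MidiProcessor uses to parse MIDI files, and onnxruntime,
# since the checkpoint's MelGAN vocoder is shipped as an ONNX model.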
@slow
@require_torch_gpu
@require_onnxruntime
@require_note_seq
class PipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_callback(self):
        # TODO - test that pipeline can decode tokens in a callback
        # so that music can be played live
        device = torch_device

        pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
        # Detach the vocoder so the pipeline hands raw mel frames to the
        # callback, which decodes them incrementally instead.
        melgan = pipe.melgan
        pipe.melgan = None

        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        def callback(step, mel_output):
            # decode mel to audio
            audio = melgan(input_features=mel_output.astype(np.float32))[0]
            # each token group contributes a fixed-length chunk of 81920 samples
            assert len(audio[0]) == 81920 * (step + 1)
            # simulate that audio is played
            return audio

        processor = MidiProcessor()
        input_tokens = processor(MIDI_FILE)

        # only the first three token groups (one denoising loop each)
        input_tokens = input_tokens[:3]
        generator = torch.manual_seed(0)
        pipe(input_tokens, num_inference_steps=5, generator=generator, callback=callback, output_type="mel")
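
        # A live-playback callback could stream each new chunk to the sound
        # card instead. A minimal sketch, assuming the third-party
        # `sounddevice` package and the model's 16 kHz output rate (neither is
        # used by this test):
        #
        #   import sounddevice as sd
        #
        #   def playback_callback(step, mel_output):
        #       audio = melgan(input_features=mel_output.astype(np.float32))[0]
        #       sd.play(audio[0][-81920:], samplerate=16000)  # newest chunk only
        #       return audio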

    def test_spectrogram_fast(self):
        device = torch_device

        pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)
        processor = MidiProcessor()

        input_tokens = processor(MIDI_FILE)
        # just run two denoising loops
        input_tokens = input_tokens[:2]

        generator = torch.manual_seed(0)
        output = pipe(input_tokens, num_inference_steps=2, generator=generator)

        audio = output.audios[0]

        assert abs(np.abs(audio).sum() - 3612.841) < 1e-1

    def test_spectrogram(self):
        device = torch_device

        pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        processor = MidiProcessor()

        input_tokens = processor(MIDI_FILE)

        # just run 4 denoising loops
        input_tokens = input_tokens[:4]

        generator = torch.manual_seed(0)
        output = pipe(input_tokens, num_inference_steps=100, generator=generator)

        audio = output.audios[0]
        assert abs(np.abs(audio).sum() - 9389.1111) < 5e-2
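

# How to run (a sketch; exact paths depend on the checkout):
#
#   # fast, CPU-only dummy-component tests
#   pytest tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py -k FastTests
#
#   # slow integration tests: need a CUDA GPU, onnxruntime and note_seq
#   RUN_SLOW=1 pytest tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py -k Integration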