Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/multiband_melgan/train_multiband_melgan.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Train Multi-Band MelGAN."""
16
17
import tensorflow as tf

# Enable on-demand GPU memory growth BEFORE any other TF work, so TensorFlow
# allocates VRAM incrementally instead of grabbing each whole device up front.
physical_devices = tf.config.list_physical_devices("GPU")
for i in range(len(physical_devices)):
    tf.config.experimental.set_memory_growth(physical_devices[i], True)

import sys

# Make `examples.*` and `tensorflow_tts.*` importable when the script is run
# from the repository root rather than as an installed package.
sys.path.append(".")
27
import argparse
28
import logging
29
import os
30
31
import numpy as np
32
import soundfile as sf
33
import yaml
34
from tensorflow.keras.mixed_precision import experimental as mixed_precision
35
36
import tensorflow_tts
37
from examples.melgan.audio_mel_dataset import AudioMelDataset
38
from examples.melgan.train_melgan import MelganTrainer, collater
39
from tensorflow_tts.configs import (
40
MultiBandMelGANDiscriminatorConfig,
41
MultiBandMelGANGeneratorConfig,
42
)
43
from tensorflow_tts.losses import TFMultiResolutionSTFT
44
from tensorflow_tts.models import (
45
TFPQMF,
46
TFMelGANGenerator,
47
TFMelGANMultiScaleDiscriminator,
48
)
49
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
50
51
52
class MultiBandMelganTrainer(MelganTrainer):
    """Multi-Band MelGAN Trainer class based on MelganTrainer.

    Extends the single-band MelGAN trainer with a PQMF analysis/synthesis
    module and sub-band + full-band multi-resolution STFT losses.
    """

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy: tf.distribute strategy used for (possibly multi-GPU) training.
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_generator_mixed_precision (bool): Use mixed precision for generator or not.
            is_discriminator_mixed_precision (bool): Use mixed precision for discriminator or not.

        """
        super(MultiBandMelganTrainer, self).__init__(
            config=config,
            steps=steps,
            epochs=epochs,
            strategy=strategy,
            is_generator_mixed_precision=is_generator_mixed_precision,
            is_discriminator_mixed_precision=is_discriminator_mixed_precision,
        )

        # Metric names aggregated per step and written out via tf.summary.
        # Keys produced by compute_per_example_*_losses MUST come from this list.
        self.list_metrics_name = [
            "adversarial_loss",
            "subband_spectral_convergence_loss",
            "subband_log_magnitude_loss",
            "fullband_spectral_convergence_loss",
            "fullband_log_magnitude_loss",
            "gen_loss",
            "real_loss",
            "fake_loss",
            "dis_loss",
        ]

        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer, pqmf):
        """Attach models/optimizers and build STFT losses and the PQMF module.

        Args:
            gen_model: Generator model (outputs sub-band waveforms).
            dis_model: Multi-scale discriminator model.
            gen_optimizer: Optimizer for the generator.
            dis_optimizer: Optimizer for the discriminator.
            pqmf: TFPQMF filter bank used for sub-band analysis/synthesis.
        """
        super().compile(gen_model, dis_model, gen_optimizer, dis_optimizer)
        # Multi-resolution STFT losses: one over sub-band signals, one over the
        # synthesized full-band waveform.
        self.sub_band_stft_loss = TFMultiResolutionSTFT(
            **self.config["subband_stft_loss_params"]
        )
        self.full_band_stft_loss = TFMultiResolutionSTFT(
            **self.config["stft_loss_params"]
        )

        # PQMF module used to go between full-band and sub-band domains.
        self.pqmf = pqmf

    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per example generator losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        dict_metrics_losses = {}
        per_example_losses = 0.0

        audios = batch["audios"]
        y_mb_hat = outputs  # generated sub-band signals [B, T//subbands, subbands]
        y_hat = self.pqmf.synthesis(y_mb_hat)  # full-band reconstruction [B, T, 1]

        # Ground-truth sub-band decomposition, flattened so each sub-band is a
        # separate row for the 2D STFT loss.
        y_mb = self.pqmf.analysis(tf.expand_dims(audios, -1))
        y_mb = tf.transpose(y_mb, (0, 2, 1))  # [B, subbands, T//subbands]
        y_mb = tf.reshape(y_mb, (-1, tf.shape(y_mb)[-1]))  # [B * subbands, T']

        y_mb_hat = tf.transpose(y_mb_hat, (0, 2, 1))  # [B, subbands, T//subbands]
        y_mb_hat = tf.reshape(
            y_mb_hat, (-1, tf.shape(y_mb_hat)[-1])
        )  # [B * subbands, T']

        # Sub-band and full-band spectral-convergence / log-magnitude losses.
        sub_sc_loss, sub_mag_loss = calculate_2d_loss(
            y_mb, y_mb_hat, self.sub_band_stft_loss
        )
        # Fold the [B * subbands] losses back to per-example [B] by averaging
        # over the sub-band axis.
        sub_sc_loss = tf.reduce_mean(
            tf.reshape(sub_sc_loss, [-1, self.pqmf.subbands]), -1
        )
        sub_mag_loss = tf.reduce_mean(
            tf.reshape(sub_mag_loss, [-1, self.pqmf.subbands]), -1
        )
        full_sc_loss, full_mag_loss = calculate_2d_loss(
            audios, tf.squeeze(y_hat, -1), self.full_band_stft_loss
        )

        # Generator reconstruction loss: equal weight on sub-band and full-band
        # STFT terms.
        gen_loss = 0.5 * (sub_sc_loss + sub_mag_loss) + 0.5 * (
            full_sc_loss + full_mag_loss
        )

        # The adversarial term only kicks in once the discriminator has started
        # training.
        if self.steps >= self.config["discriminator_train_start_steps"]:
            p_hat = self._discriminator(y_hat)
            # NOTE(review): `p` is computed but never used here; presumably kept
            # for parity with the parent trainer — confirm before removing.
            p = self._discriminator(tf.expand_dims(audios, 2))
            adv_loss = 0.0
            for i in range(len(p_hat)):
                adv_loss += calculate_3d_loss(
                    tf.ones_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
                )
            # FIX: average over the number of discriminator scales explicitly.
            # The original divided by the leaked loop variable (`i + 1`), which
            # is obscure and raises NameError if p_hat is ever empty.
            adv_loss /= len(p_hat)
            gen_loss += self.config["lambda_adv"] * adv_loss

            dict_metrics_losses.update({"adversarial_loss": adv_loss})

        dict_metrics_losses.update({"gen_loss": gen_loss})
        dict_metrics_losses.update({"subband_spectral_convergence_loss": sub_sc_loss})
        dict_metrics_losses.update({"subband_log_magnitude_loss": sub_mag_loss})
        dict_metrics_losses.update({"fullband_spectral_convergence_loss": full_sc_loss})
        dict_metrics_losses.update({"fullband_log_magnitude_loss": full_mag_loss})

        per_example_losses = gen_loss
        return per_example_losses, dict_metrics_losses

    def compute_per_example_discriminator_losses(self, batch, gen_outputs):
        """Compute per example discriminator losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            gen_outputs: sub-band outputs of the generator

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        # Synthesize the full-band waveform, then reuse the parent (single-band
        # MelGAN) discriminator loss computation.
        y_mb_hat = gen_outputs
        y_hat = self.pqmf.synthesis(y_mb_hat)
        (
            per_example_losses,
            dict_metrics_losses,
        ) = super().compute_per_example_discriminator_losses(batch, y_hat)
        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate audio for one batch and save waveform plots + wav files."""
        import matplotlib.pyplot as plt

        y_mb_batch_ = self.one_step_predict(batch)  # [B, T // subbands, subbands]
        y_batch = batch["audios"]
        utt_ids = batch["utt_ids"]

        # Convert to numpy. Under a multi-replica strategy the tensors are
        # PerReplica values; take the sample from the first replica. The
        # except-branch handles the single-device case (plain tensors).
        try:
            y_mb_batch_ = y_mb_batch_.values[0].numpy()
            y_batch = y_batch.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            y_mb_batch_ = y_mb_batch_.numpy()
            y_batch = y_batch.numpy()
            utt_ids = utt_ids.numpy()

        y_batch_ = self.pqmf.synthesis(y_mb_batch_).numpy()  # [B, T, 1]

        # Ensure the output directory exists (exist_ok avoids a race with the
        # original exists()-then-makedirs pattern).
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        os.makedirs(dirname, exist_ok=True)

        for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 0):
            # convert to ndarray
            y, y_ = tf.reshape(y, [-1]).numpy(), tf.reshape(y_, [-1]).numpy()

            # plot figure and save it
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            plt.subplot(2, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(2, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

        # save as wavefile (clip to the valid [-1, 1] PCM range first)
        y = np.clip(y, -1, 1)
        y_ = np.clip(y_, -1, 1)
        sf.write(
            figname.replace(".png", "_ref.wav"),
            y,
            self.config["sampling_rate"],
            "PCM_16",
        )
        sf.write(
            figname.replace(".png", "_gen.wav"),
            y_,
            self.config["sampling_rate"],
            "PCM_16",
        )
264
def main():
    """Run the Multi-Band MelGAN training process.

    Parses CLI arguments, builds datasets, models, optimizers and the trainer,
    then runs the training loop (saving a checkpoint on KeyboardInterrupt).

    Raises:
        ValueError: If --train-dir / --dev-dir are missing, or the feature
            format is not "npy".
    """
    parser = argparse.ArgumentParser(
        description="Train MultiBand MelGAN (See detail in examples/multiband_melgan/train_multiband_melgan.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm mels for training or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for discriminator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="path of .h5 mb-melgan generator to load weights from",
    )
    args = parser.parse_args()

    # Pick the distribution strategy (single GPU / MirroredStrategy / ...).
    STRATEGY = return_strategy()

    # Enable graph-level auto mixed precision if either model requests it.
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision)

    args.use_norm = bool(args.use_norm)

    # Configure logging; the format is shared by all verbosity levels, so it is
    # hoisted out of the originally triplicated basicConfig calls.
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, stream=sys.stdout, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, stream=sys.stdout, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, stream=sys.stdout, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # Ensure the output directory exists (exist_ok avoids the race in the
    # original exists()-then-makedirs pattern).
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # FIX: the original message referred to a nonexistent --valid-dir flag.
        raise ValueError("Please specify --dev-dir")

    # Load config, merge in the CLI arguments, and save the effective config.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; fine
    # for a trusted local config file, but do not point this at untrusted input.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # Minimum mel length required so a training example can supply a full
    # batch_max_steps window (plus any auxiliary context).
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["multiband_melgan_generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        audio_query = "*-wave.npy"
        mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy are supported.")

    # define train/valid dataset
    train_dataset = AudioMelDataset(
        root_dir=args.train_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(config["batch_max_steps"], dtype=tf.int32),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        # Global batch covers all replicas and gradient-accumulation steps.
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = AudioMelDataset(
        root_dir=args.dev_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(
                config["batch_max_steps_valid"], dtype=tf.int32
            ),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = MultiBandMelganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=STRATEGY,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    with STRATEGY.scope():
        # define generator and discriminator
        generator = TFMelGANGenerator(
            MultiBandMelGANGeneratorConfig(
                **config["multiband_melgan_generator_params"]
            ),
            name="multi_band_melgan_generator",
        )

        discriminator = TFMelGANMultiScaleDiscriminator(
            MultiBandMelGANDiscriminatorConfig(
                **config["multiband_melgan_discriminator_params"]
            ),
            name="multi_band_melgan_discriminator",
        )

        pqmf = TFPQMF(
            MultiBandMelGANGeneratorConfig(
                **config["multiband_melgan_generator_params"]
            ),
            dtype=tf.float32,
            name="pqmf",
        )

        # Run a dummy forward pass so Keras builds the model variables.
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_mb_hat = generator(fake_mels)
        y_hat = pqmf.synthesis(y_mb_hat)
        discriminator(y_hat)

        # NOTE(review): `> 1` skips one-character paths; presumably just a
        # non-empty check — confirm before changing.
        if len(args.pretrained) > 1:
            generator.load_weights(args.pretrained)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        generator.summary()
        discriminator.summary()

        # Build LR schedules and optimizers from the config.
        generator_lr_fn = getattr(
            tf.keras.optimizers.schedules, config["generator_optimizer_params"]["lr_fn"]
        )(**config["generator_optimizer_params"]["lr_params"])
        discriminator_lr_fn = getattr(
            tf.keras.optimizers.schedules,
            config["discriminator_optimizer_params"]["lr_fn"],
        )(**config["discriminator_optimizer_params"]["lr_params"])

        gen_optimizer = tf.keras.optimizers.Adam(
            learning_rate=generator_lr_fn,
            amsgrad=config["generator_optimizer_params"]["amsgrad"],
        )
        dis_optimizer = tf.keras.optimizers.Adam(
            learning_rate=discriminator_lr_fn,
            amsgrad=config["discriminator_optimizer_params"]["amsgrad"],
        )

    trainer.compile(
        gen_model=generator,
        dis_model=discriminator,
        gen_optimizer=gen_optimizer,
        dis_optimizer=dis_optimizer,
        pqmf=pqmf,
    )

    # Start training; save a checkpoint if interrupted so progress isn't lost.
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
522
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()