Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/melgan/train_melgan.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Train MelGAN."""
16
17
import tensorflow as tf
18
19
# Enable memory growth on every visible GPU so TensorFlow allocates VRAM
# on demand instead of grabbing all device memory at import time.
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    # Idiomatic iteration over the device list (was an index-based loop).
    tf.config.experimental.set_memory_growth(device, True)
22
23
import sys
24
25
sys.path.append(".")
26
27
import argparse
28
import logging
29
import os
30
31
import numpy as np
32
import soundfile as sf
33
import yaml
34
from tqdm import tqdm
35
36
import tensorflow_tts
37
import tensorflow_tts.configs.melgan as MELGAN_CONFIG
38
from examples.melgan.audio_mel_dataset import AudioMelDataset
39
from tensorflow_tts.losses import TFMelSpectrogram
40
from tensorflow_tts.models import TFMelGANGenerator, TFMelGANMultiScaleDiscriminator
41
from tensorflow_tts.trainers import GanBasedTrainer
42
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
43
44
45
class MelganTrainer(GanBasedTrainer):
    """Melgan Trainer class based on GanBasedTrainer.

    Adds MelGAN-specific generator/discriminator losses (adversarial,
    feature-matching, mel-spectrogram) on top of the generic GAN training
    loop provided by GanBasedTrainer.
    """

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from yaml format configuration file.
            strategy: tf.distribute strategy the training runs under.
            is_generator_mixed_precision (bool): Use mixed precision for generator or not.
            is_discriminator_mixed_precision (bool): Use mixed precision for discriminator or not.

        """
        super(MelganTrainer, self).__init__(
            steps,
            epochs,
            config,
            strategy,
            is_generator_mixed_precision,
            is_discriminator_mixed_precision,
        )
        # define metrics to aggregates data and use tf.summary logs them
        self.list_metrics_name = [
            "adversarial_loss",
            "fm_loss",
            "gen_loss",
            "real_loss",
            "fake_loss",
            "dis_loss",
            "mels_spectrogram_loss",
        ]
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

        self.config = config

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer):
        """Attach models/optimizers (via the base class) and build the losses.

        Losses use Reduction.NONE so per-example values keep a [batch] shape,
        as required by compute_per_example_*_losses under a distribute strategy.
        """
        super().compile(gen_model, dis_model, gen_optimizer, dis_optimizer)
        # define loss
        self.mse_loss = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae_loss = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mels_loss = TFMelSpectrogram()

    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per example generator losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        audios = batch["audios"]
        y_hat = outputs

        # Run both generated and real audio through the multi-scale
        # discriminator; each element of p/p_hat is a list of feature maps
        # per scale, with the final element being the scale's logits.
        p_hat = self._discriminator(y_hat)
        p = self._discriminator(tf.expand_dims(audios, 2))
        adv_loss = 0.0
        for i in range(len(p_hat)):
            # Least-squares adversarial loss against all-ones targets.
            adv_loss += calculate_3d_loss(
                tf.ones_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
            )
        # Average over the discriminator scales (i is the last loop index,
        # so i + 1 == len(p_hat)).
        adv_loss /= i + 1

        # define feature-matching loss
        fm_loss = 0.0
        for i in range(len(p_hat)):
            # Match all intermediate feature maps (every element but the logits).
            for j in range(len(p_hat[i]) - 1):
                fm_loss += calculate_3d_loss(
                    p[i][j], p_hat[i][j], loss_fn=self.mae_loss
                )
        # Normalize by scales * layers using the final loop indices.
        fm_loss /= (i + 1) * (j + 1)
        adv_loss += self.config["lambda_feat_match"] * fm_loss

        per_example_losses = adv_loss

        dict_metrics_losses = {
            "adversarial_loss": adv_loss,
            "fm_loss": fm_loss,
            "gen_loss": adv_loss,
            # Logged for monitoring only; not part of per_example_losses.
            "mels_spectrogram_loss": calculate_2d_loss(
                audios, tf.squeeze(y_hat, -1), loss_fn=self.mels_loss
            ),
        }

        return per_example_losses, dict_metrics_losses

    def compute_per_example_discriminator_losses(self, batch, gen_outputs):
        """Compute per-example discriminator losses.

        Least-squares GAN objective: real audio is pushed toward 1,
        generated audio toward 0, averaged over discriminator scales.

        Args:
            batch: dictionary batch input return from dataloader
            gen_outputs: generator outputs (generated waveform).

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        audios = batch["audios"]
        y_hat = gen_outputs

        y = tf.expand_dims(audios, 2)
        p = self._discriminator(y)
        p_hat = self._discriminator(y_hat)

        real_loss = 0.0
        fake_loss = 0.0
        for i in range(len(p)):
            real_loss += calculate_3d_loss(
                tf.ones_like(p[i][-1]), p[i][-1], loss_fn=self.mse_loss
            )
            fake_loss += calculate_3d_loss(
                tf.zeros_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
            )
        # Average each term over the number of scales (i + 1 == len(p)).
        real_loss /= i + 1
        fake_loss /= i + 1
        dis_loss = real_loss + fake_loss

        # calculate per_example_losses and dict_metrics_losses
        per_example_losses = dis_loss

        dict_metrics_losses = {
            "real_loss": real_loss,
            "fake_loss": fake_loss,
            "dis_loss": dis_loss,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result.

        For each utterance in the batch, writes a PNG comparing ground-truth
        vs generated waveform plus the two waveforms as 16-bit PCM wav files
        under <outdir>/predictions/<steps>steps/.
        """
        # Imported lazily so training does not require matplotlib unless
        # intermediate results are actually generated.
        import matplotlib.pyplot as plt

        # generate
        y_batch_ = self.one_step_predict(batch)
        y_batch = batch["audios"]
        utt_ids = batch["utt_ids"]

        # convert to tensor.
        # here we just take a sample at first replica.
        try:
            # Distributed run: tensors are PerReplica values.
            y_batch_ = y_batch_.values[0].numpy()
            y_batch = y_batch.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            # Single-device run: plain tensors with no .values attribute.
            y_batch_ = y_batch_.numpy()
            y_batch = y_batch.numpy()
            utt_ids = utt_ids.numpy()

        # check directory
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 0):
            # convert to ndarray
            y, y_ = tf.reshape(y, [-1]).numpy(), tf.reshape(y_, [-1]).numpy()

            # plot figure and save it
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            plt.subplot(2, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(2, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # save as wavefile; clip to the valid [-1, 1] PCM range first.
            y = np.clip(y, -1, 1)
            y_ = np.clip(y_, -1, 1)
            sf.write(
                figname.replace(".png", "_ref.wav"),
                y,
                self.config["sampling_rate"],
                "PCM_16",
            )
            sf.write(
                figname.replace(".png", "_gen.wav"),
                y_,
                self.config["sampling_rate"],
                "PCM_16",
            )
240
241
242
def collater(
    items,
    batch_max_steps=tf.constant(8192, dtype=tf.int32),
    hop_size=tf.constant(256, dtype=tf.int32),
):
    """Initialize collater (mapping function) for Tensorflow Audio-Mel Dataset.

    Pads the waveform to full-frame coverage, then either crops a random
    window of ``batch_max_steps`` samples (with the matching mel frames)
    or pads both up to the fixed batch length.

    Args:
        batch_max_steps (int): The maximum length of input signal in batch.
        hop_size (int): Hop size of auxiliary features.

    """
    audio, mel = items["audios"], items["mels"]

    # A None budget means "use the whole signal", rounded down to a
    # whole number of hops.
    if batch_max_steps is None:
        batch_max_steps = (tf.shape(audio)[0] // hop_size) * hop_size

    batch_max_frames = batch_max_steps // hop_size

    # Make sure the waveform covers every mel frame.
    target_len = len(mel) * hop_size
    if len(audio) < target_len:
        audio = tf.pad(audio, [[0, target_len - len(audio)]])

    if len(mel) > batch_max_frames:
        # randomly pickup with the batch_max_steps length of the part
        start_frame = tf.random.uniform(
            shape=[],
            minval=0,
            maxval=len(mel) - batch_max_frames,
            dtype=tf.int32,
        )
        start_step = start_frame * hop_size
        audio = audio[start_step : start_step + batch_max_steps]
        mel = mel[start_frame : start_frame + batch_max_frames, :]
    else:
        # Too short: pad both streams up to the fixed batch length.
        audio = tf.pad(audio, [[0, batch_max_steps - len(audio)]])
        mel = tf.pad(mel, [[0, batch_max_frames - len(mel)], [0, 0]])

    return {
        "utt_ids": items["utt_ids"],
        "audios": audio,
        "mels": mel,
        "mel_lengths": len(mel),
        "audio_lengths": len(audio),
    }
286
287
288
def main():
    """Run training process.

    Parses CLI arguments, configures logging and mixed precision, loads the
    yaml config, builds the train/valid datasets, instantiates the MelGAN
    generator and discriminator under the distribution strategy, and runs
    the GAN training loop (saving a checkpoint on KeyboardInterrupt).
    """
    parser = argparse.ArgumentParser(
        description="Train MelGAN (See detail in tensorflow_tts/bin/train-melgan.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm mels for training or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for discriminator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="path of .h5 melgan generator to load weights from",
    )
    args = parser.parse_args()

    # return strategy (single device or mirrored, based on visible hardware)
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    # Normalize int CLI flags to real booleans for downstream code.
    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision)

    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence (exist_ok avoids a race with parallel launches)
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # BUGFIX: message previously referred to a nonexistent --valid-dir flag.
        raise ValueError("Please specify --dev-dir")

    # load and save config
    with open(args.config) as f:
        # NOTE(security): yaml.Loader can construct arbitrary Python objects.
        # Acceptable for an operator-supplied config file, but never point
        # --config at untrusted input.
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    # Persist the effective config next to the checkpoints for reproducibility.
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # Drop utterances too short to yield one training window (plus the
        # generator's auxiliary context, if configured).
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["melgan_generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        audio_query = "*-wave.npy"
        mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy are supported.")

    # define train/valid dataset
    train_dataset = AudioMelDataset(
        root_dir=args.train_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(config["batch_max_steps"], dtype=tf.int32),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        # Global batch: per-replica batch * replicas * accumulation steps.
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = AudioMelDataset(
        root_dir=args.dev_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(
                config["batch_max_steps_valid"], dtype=tf.int32
            ),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = MelganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=STRATEGY,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    # define generator and discriminator
    with STRATEGY.scope():
        generator = TFMelGANGenerator(
            MELGAN_CONFIG.MelGANGeneratorConfig(**config["melgan_generator_params"]),
            name="melgan_generator",
        )

        discriminator = TFMelGANMultiScaleDiscriminator(
            MELGAN_CONFIG.MelGANDiscriminatorConfig(
                **config["melgan_discriminator_params"]
            ),
            name="melgan_discriminator",
        )

        # dummy input to build model (weights must exist before load/summary).
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_hat = generator(fake_mels)
        discriminator(y_hat)

        # NOTE(review): "> 1" skips one-character paths too, not just "";
        # kept as-is since callers may rely on it — confirm before changing.
        if len(args.pretrained) > 1:
            generator.load_weights(args.pretrained)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        generator.summary()
        discriminator.summary()

        gen_optimizer = tf.keras.optimizers.Adam(**config["generator_optimizer_params"])
        dis_optimizer = tf.keras.optimizers.Adam(
            **config["discriminator_optimizer_params"]
        )

    trainer.compile(
        gen_model=generator,
        dis_model=discriminator,
        gen_optimizer=gen_optimizer,
        dis_optimizer=dis_optimizer,
    )

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        # Best-effort save so an interrupted run can be resumed with --resume.
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
519
520
521
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()