GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/parallel_wavegan/train_parallel_wavegan.py
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train Parallel WaveGAN."""

import tensorflow as tf

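# Enable memory growth so TensorFlow allocates GPU memory on demand instead of
# reserving every visible device up front; this must run before the GPUs are
# initialized by any other op.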
physical_devices = tf.config.list_physical_devices("GPU")
for i in range(len(physical_devices)):
    tf.config.experimental.set_memory_growth(physical_devices[i], True)

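# Make the repository root importable so that the `examples.*` modules used
# below resolve when the script is launched from the project root.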
import sys

sys.path.append(".")

import argparse
import logging
import os
import soundfile as sf

import numpy as np
import yaml

import tensorflow_tts

from examples.melgan.audio_mel_dataset import AudioMelDataset
from examples.melgan.train_melgan import collater

from tensorflow_tts.configs import (
    ParallelWaveGANGeneratorConfig,
    ParallelWaveGANDiscriminatorConfig,
)
from tensorflow_tts.models import (
    TFParallelWaveGANGenerator,
    TFParallelWaveGANDiscriminator,
)

from tensorflow_tts.trainers import GanBasedTrainer
from tensorflow_tts.losses import TFMultiResolutionSTFT
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy

from tensorflow_addons.optimizers import RectifiedAdam


class ParallelWaveganTrainer(GanBasedTrainer):
    """ParallelWaveGAN Trainer class based on GanBasedTrainer."""

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from yaml format configuration file.
            is_generator_mixed_precision (bool): Use mixed precision for generator or not.
            is_discriminator_mixed_precision (bool): Use mixed precision for discriminator or not.

        """
        super(ParallelWaveganTrainer, self).__init__(
            config=config,
            steps=steps,
            epochs=epochs,
            strategy=strategy,
            is_generator_mixed_precision=is_generator_mixed_precision,
            is_discriminator_mixed_precision=is_discriminator_mixed_precision,
        )

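        # Names of the scalar metrics tracked during training and evaluation;
        # the keys returned by the loss functions below must match this list.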
        self.list_metrics_name = [
            "adversarial_loss",
            "gen_loss",
            "real_loss",
            "fake_loss",
            "dis_loss",
            "spectral_convergence_loss",
            "log_magnitude_loss",
        ]

        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer):
        super().compile(gen_model, dis_model, gen_optimizer, dis_optimizer)
        # define loss
        self.stft_loss = TFMultiResolutionSTFT(**self.config["stft_loss_params"])
        self.mse_loss = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae_loss = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )

    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per-example generator losses and return dict_metrics_losses.

        Note that every element of the loss MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per-example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary of metric losses.
        """
        dict_metrics_losses = {}
        per_example_losses = 0.0

        audios = batch["audios"]
        y_hat = outputs

        # calculate multi-resolution STFT loss (spectral convergence + log magnitude)
        sc_loss, mag_loss = calculate_2d_loss(
            audios, tf.squeeze(y_hat, -1), self.stft_loss
        )
        gen_loss = 0.5 * (sc_loss + mag_loss)

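        # The adversarial term is only added once the discriminator starts
        # training (after `discriminator_train_start_steps`); before that the
        # generator is optimized with the multi-resolution STFT loss alone.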
        if self.steps >= self.config["discriminator_train_start_steps"]:
            p_hat = self._discriminator(y_hat)
            p = self._discriminator(tf.expand_dims(audios, 2))
            adv_loss = 0.0
            adv_loss += calculate_3d_loss(
                tf.ones_like(p_hat), p_hat, loss_fn=self.mse_loss
            )
            gen_loss += self.config["lambda_adv"] * adv_loss

            # update dict_metrics_losses
            dict_metrics_losses.update({"adversarial_loss": adv_loss})

        dict_metrics_losses.update({"gen_loss": gen_loss})
        dict_metrics_losses.update({"spectral_convergence_loss": sc_loss})
        dict_metrics_losses.update({"log_magnitude_loss": mag_loss})

        per_example_losses = gen_loss
        return per_example_losses, dict_metrics_losses

    def compute_per_example_discriminator_losses(self, batch, gen_outputs):
        audios = batch["audios"]
        y_hat = gen_outputs

        y = tf.expand_dims(audios, 2)
        p = self._discriminator(y)
        p_hat = self._discriminator(y_hat)

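        # Least-squares GAN objective: discriminator outputs for real audio are
        # regressed toward 1 and outputs for generated audio toward 0.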
        real_loss = 0.0
        fake_loss = 0.0

        real_loss += calculate_3d_loss(tf.ones_like(p), p, loss_fn=self.mse_loss)
        fake_loss += calculate_3d_loss(
            tf.zeros_like(p_hat), p_hat, loss_fn=self.mse_loss
        )

        dis_loss = real_loss + fake_loss

        # calculate per_example_losses and dict_metrics_losses
        per_example_losses = dis_loss

        dict_metrics_losses = {
            "real_loss": real_loss,
            "fake_loss": fake_loss,
            "dis_loss": dis_loss,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        import matplotlib.pyplot as plt

        # generate
        y_batch_ = self.one_step_predict(batch)
        y_batch = batch["audios"]
        utt_ids = batch["utt_ids"]

        # convert to tensor.
        # here we just take a sample at first replica.
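        # Under a multi-device strategy the predictions come back as PerReplica
        # values exposing `.values`; the except branch handles the single-device
        # case where plain tensors are returned.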
        try:
            y_batch_ = y_batch_.values[0].numpy()
            y_batch = y_batch.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            y_batch_ = y_batch_.numpy()
            y_batch = y_batch.numpy()
            utt_ids = utt_ids.numpy()

        # check directory
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 0):
            # convert to ndarray
            y, y_ = tf.reshape(y, [-1]).numpy(), tf.reshape(y_, [-1]).numpy()

            # plot figure and save it
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            plt.subplot(2, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(2, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # save as wav file
            y = np.clip(y, -1, 1)
            y_ = np.clip(y_, -1, 1)
            sf.write(
                figname.replace(".png", "_ref.wav"),
                y,
                self.config["sampling_rate"],
                "PCM_16",
            )
            sf.write(
                figname.replace(".png", "_gen.wav"),
                y_,
                self.config["sampling_rate"],
                "PCM_16",
            )

def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train Parallel WaveGAN (see details in tensorflow_tts/examples/parallel_wavegan/train_parallel_wavegan.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data.",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data.",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm mels for training or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="use mixed precision for generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="use mixed precision for discriminator or not.",
    )
    args = parser.parse_args()

    # return strategy
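    # return_strategy() (from tensorflow_tts.utils) picks a tf.distribute
    # strategy based on the visible GPUs; everything created under its scope
    # below is replicated across devices.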
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision)

    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        raise ValueError("Please specify --dev-dir")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
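    # Skip utterances whose mel spectrogram is shorter than one training window:
    # batch_max_steps audio samples correspond to batch_max_steps // hop_size mel
    # frames, plus the generator's aux_context_window frames on both sides.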
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["parallel_wavegan_generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        audio_query = "*-wave.npy"
        mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy format is supported.")

    # define train/valid dataset
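    # The dataset is batched with the global batch size: the per-device
    # batch_size scaled by the number of replicas and by
    # gradient_accumulation_steps.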
    train_dataset = AudioMelDataset(
        root_dir=args.train_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(config["batch_max_steps"], dtype=tf.int32),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = AudioMelDataset(
        root_dir=args.dev_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(
                config["batch_max_steps_valid"], dtype=tf.int32
            ),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = ParallelWaveganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=STRATEGY,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    with STRATEGY.scope():
        # define generator and discriminator
        generator = TFParallelWaveGANGenerator(
            ParallelWaveGANGeneratorConfig(
                **config["parallel_wavegan_generator_params"]
            ),
            name="parallel_wavegan_generator",
        )

        discriminator = TFParallelWaveGANDiscriminator(
            ParallelWaveGANDiscriminatorConfig(
                **config["parallel_wavegan_discriminator_params"]
            ),
            name="parallel_wavegan_discriminator",
        )

        # dummy input to build model.
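        # A single forward pass on random mels builds the variables of both
        # networks so that summary() can report parameter counts. The shape
        # [1, 100, 80] assumes 80-band mel features (the default in the example
        # configs); adjust the last dimension if your features differ.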
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_hat = generator(fake_mels)
        discriminator(y_hat)

        generator.summary()
        discriminator.summary()

        # define optimizer
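        # `lr_fn` in the YAML config names a learning-rate schedule class from
        # tf.keras.optimizers.schedules (e.g. PiecewiseConstantDecay); it is
        # instantiated with `lr_params` and passed to RectifiedAdam from
        # TensorFlow Addons.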
        generator_lr_fn = getattr(
            tf.keras.optimizers.schedules, config["generator_optimizer_params"]["lr_fn"]
        )(**config["generator_optimizer_params"]["lr_params"])
        discriminator_lr_fn = getattr(
            tf.keras.optimizers.schedules,
            config["discriminator_optimizer_params"]["lr_fn"],
        )(**config["discriminator_optimizer_params"]["lr_params"])

        gen_optimizer = RectifiedAdam(learning_rate=generator_lr_fn, amsgrad=False)
        dis_optimizer = RectifiedAdam(learning_rate=discriminator_lr_fn, amsgrad=False)

        trainer.compile(
            gen_model=generator,
            dis_model=discriminator,
            gen_optimizer=gen_optimizer,
            dis_optimizer=dis_optimizer,
        )

    # start training
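    # KeyboardInterrupt (Ctrl+C) is caught so an interrupted run still writes a
    # resumable checkpoint before exiting.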
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()