Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/fastspeech2_libritts/train_fastspeech2.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 TensorFlowTTS Team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Train FastSpeech2."""
16
17
import tensorflow as tf

# Enable on-demand GPU memory allocation instead of reserving all device
# memory at startup, so the trainer can coexist with other GPU processes.
# Iterate the device objects directly (no index needed).
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

import sys

# Make repository-root packages (examples.*, tensorflow_tts.*) importable
# when this script is launched from the repo root.
sys.path.append(".")
26
27
import argparse
28
import logging
29
import os
30
31
import numpy as np
32
import yaml
33
import json
34
35
import tensorflow_tts
36
from examples.fastspeech2_libritts.fastspeech2_dataset import (
37
CharactorDurationF0EnergyMelDataset,
38
)
39
from tensorflow_tts.configs import FastSpeech2Config
40
from tensorflow_tts.models import TFFastSpeech2
41
from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
42
from tensorflow_tts.trainers import Seq2SeqBasedTrainer
43
from tensorflow_tts.utils import (
44
calculate_2d_loss,
45
calculate_3d_loss,
46
return_strategy,
47
TFGriffinLim,
48
)
49
50
51
class FastSpeech2Trainer(Seq2SeqBasedTrainer):
    """FastSpeech2 Trainer class based on FastSpeechTrainer."""

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_mixed_precision=False,
        stats_path: str = "",
        dataset_config: str = "",
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from yaml format configuration file.
            is_mixed_precision (bool): Use mixed precision or not.
            stats_path (str): Path to dataset stats; only read when
                ``config["use_griffin"]`` is truthy (to build the Griffin-Lim vocoder).
            dataset_config (str): Path to the preprocessing yaml; only read for
                Griffin-Lim as well.
        """
        super(FastSpeech2Trainer, self).__init__(
            steps=steps,
            epochs=epochs,
            config=config,
            strategy=strategy,
            is_mixed_precision=is_mixed_precision,
        )
        # define metrics to aggregates data and use tf.summary logs them
        self.list_metrics_name = [
            "duration_loss",
            "f0_loss",
            "energy_loss",
            "mel_loss_before",
            "mel_loss_after",
        ]
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()
        # Optional Griffin-Lim vocoder: when enabled, intermediate predictions
        # are additionally rendered to wav files (see
        # generate_and_save_intermediate_result).
        self.use_griffin = config.get("use_griffin", False)
        self.griffin_lim_tf = None
        if self.use_griffin:
            logging.info(
                f"Load griff stats from {stats_path} and config from {dataset_config}"
            )
            self.griff_conf = yaml.load(open(dataset_config), Loader=yaml.Loader)
            self.prepare_grim(stats_path, self.griff_conf)

    def prepare_grim(self, stats_path, config):
        # Build the TF Griffin-Lim vocoder from dataset mel statistics.
        # Raises KeyError when no stats path was supplied.
        if not stats_path:
            raise KeyError("stats path need to exist")
        self.griffin_lim_tf = TFGriffinLim(stats_path, config)

    def compile(self, model, optimizer):
        # Attach model/optimizer via the base trainer, then create the loss
        # objects. Reduction.NONE keeps losses per-example so they can be
        # reduced correctly across replicas by the distribution strategy.
        super().compile(model, optimizer)
        self.mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )

    def compute_per_example_losses(self, batch, outputs):
        """Compute per example losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        mel_before, mel_after, duration_outputs, f0_outputs, energy_outputs = outputs

        # Durations are regressed in log domain; the +1 guards against log(0)
        # for zero-length ground-truth durations.
        log_duration = tf.math.log(
            tf.cast(tf.math.add(batch["duration_gts"], 1), tf.float32)
        )
        duration_loss = calculate_2d_loss(log_duration, duration_outputs, self.mse)
        f0_loss = calculate_2d_loss(batch["f0_gts"], f0_outputs, self.mse)
        energy_loss = calculate_2d_loss(batch["energy_gts"], energy_outputs, self.mse)
        # L1 loss on mels both before and after the postnet.
        mel_loss_before = calculate_3d_loss(batch["mel_gts"], mel_before, self.mae)
        mel_loss_after = calculate_3d_loss(batch["mel_gts"], mel_after, self.mae)

        # Total per-example loss is the unweighted sum of all five terms.
        per_example_losses = (
            duration_loss + f0_loss + energy_loss + mel_loss_before + mel_loss_after
        )

        dict_metrics_losses = {
            "duration_loss": duration_loss,
            "f0_loss": f0_loss,
            "energy_loss": energy_loss,
            "mel_loss_before": mel_loss_before,
            "mel_loss_after": mel_loss_after,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        # Imported lazily so matplotlib is only needed when plotting runs.
        import matplotlib.pyplot as plt

        # predict with tf.function.
        outputs = self.one_step_predict(batch)

        mels_before, mels_after, *_ = outputs
        mel_gts = batch["mel_gts"]
        utt_ids = batch["utt_ids"]

        # convert to tensor.
        # here we just take a sample at first replica.
        try:
            # Multi-replica strategy: outputs are PerReplica objects exposing
            # .values; take the first replica's batch.
            mels_before = mels_before.values[0].numpy()
            mels_after = mels_after.values[0].numpy()
            mel_gts = mel_gts.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            # Single-device case: plain tensors without a .values attribute.
            mels_before = mels_before.numpy()
            mels_after = mels_after.numpy()
            mel_gts = mel_gts.numpy()
            utt_ids = utt_ids.numpy()

        # check directory
        if self.use_griffin:
            griff_dir_name = os.path.join(
                self.config["outdir"], f"predictions/{self.steps}_wav"
            )
            if not os.path.exists(griff_dir_name):
                os.makedirs(griff_dir_name)

        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (mel_gt, mel_before, mel_after) in enumerate(
            zip(mel_gts, mels_before, mels_after), 0
        ):

            if self.use_griffin:
                # Vocode ground-truth / pre-postnet / post-postnet mels to
                # waveforms with Griffin-Lim and save each as a wav.
                utt_id = utt_ids[idx]
                grif_before = self.griffin_lim_tf(
                    tf.reshape(mel_before, [-1, 80])[tf.newaxis, :], n_iter=32
                )
                grif_after = self.griffin_lim_tf(
                    tf.reshape(mel_after, [-1, 80])[tf.newaxis, :], n_iter=32
                )
                grif_gt = self.griffin_lim_tf(
                    tf.reshape(mel_gt, [-1, 80])[tf.newaxis, :], n_iter=32
                )
                self.griffin_lim_tf.save_wav(
                    grif_before, griff_dir_name, f"{utt_id}_before"
                )
                self.griffin_lim_tf.save_wav(
                    grif_after, griff_dir_name, f"{utt_id}_after"
                )
                self.griffin_lim_tf.save_wav(grif_gt, griff_dir_name, f"{utt_id}_gt")

            utt_id = utt_ids[idx]
            mel_gt = tf.reshape(mel_gt, (-1, 80)).numpy()  # [length, 80]
            mel_before = tf.reshape(mel_before, (-1, 80)).numpy()  # [length, 80]
            mel_after = tf.reshape(mel_after, (-1, 80)).numpy()  # [length, 80]

            # plot figure and save it
            figname = os.path.join(dirname, f"{utt_id}.png")
            fig = plt.figure(figsize=(10, 8))
            ax1 = fig.add_subplot(311)
            ax2 = fig.add_subplot(312)
            ax3 = fig.add_subplot(313)
            # np.rot90 orients the spectrogram with time on the x-axis.
            im = ax1.imshow(np.rot90(mel_gt), aspect="auto", interpolation="none")
            ax1.set_title("Target Mel-Spectrogram")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
            ax2.set_title("Predicted Mel-before-Spectrogram")
            im = ax2.imshow(np.rot90(mel_before), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
            ax3.set_title("Predicted Mel-after-Spectrogram")
            im = ax3.imshow(np.rot90(mel_after), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax3)
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()
232
233
234
def main():
    """Run training process.

    Parses CLI arguments, loads yaml config, builds train/valid datasets,
    constructs the FastSpeech2 model + AdamW optimizer under the selected
    distribution strategy, and runs the trainer until completion or Ctrl-C
    (which saves a checkpoint before exiting).
    """
    parser = argparse.ArgumentParser(
        description="Train FastSpeech (See detail in tensorflow_tts/bin/train-fastspeech.py)"
    )
    parser.add_argument(
        "--train-dir",
        default="dump/train",
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default="dump/valid",
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        # Fixed typo in help text: "usr" -> "use".
        "--use-norm", default=1, type=int, help="use norm-mels for train or raw."
    )
    parser.add_argument(
        "--f0-stat", default="./dump/stats_f0.npy", type=str, help="f0-stat path.",
    )
    parser.add_argument(
        "--energy-stat",
        default="./dump/stats_energy.npy",
        type=str,
        help="energy-stat path.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--mixed_precision",
        default=1,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--dataset_config", default="preprocess/libritts_preprocess.yaml", type=str,
    )
    parser.add_argument(
        "--dataset_stats", default="dump/stats.npy", type=str,
    )
    parser.add_argument(
        "--dataset_mapping", default="dump/libritts_mapper.npy", type=str,
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers",
    )
    args = parser.parse_args()

    # Select single-GPU / MirroredStrategy depending on available devices.
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.mixed_precision = bool(args.mixed_precision)
    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # Bug fix: error message previously named a non-existent --valid-dir flag.
        raise ValueError("Please specify --dev-dir")

    # load config, merge CLI args into it, and persist the merged config
    # next to the checkpoints for reproducibility.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["mel_length_threshold"]
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        charactor_query = "*-ids.npy"
        # Use raw features only when normalization is explicitly disabled.
        mel_query = "*-raw-feats.npy" if not args.use_norm else "*-norm-feats.npy"
        duration_query = "*-durations.npy"
        f0_query = "*-raw-f0.npy"
        energy_query = "*-raw-energy.npy"
    else:
        raise ValueError("Only npy are supported.")

    # load speakers map from dataset map
    with open(args.dataset_mapping) as f:
        dataset_mapping = json.load(f)
        speakers_map = dataset_mapping["speakers_map"]

    # Check n_speakers matches number of speakers in speakers_map
    n_speakers = config["fastspeech2_params"]["n_speakers"]
    assert n_speakers == len(
        speakers_map
    ), "Number of speakers in dataset does not match n_speakers in config"

    # define train/valid dataset
    train_dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.train_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        f0_query=f0_query,
        energy_query=energy_query,
        f0_stat=args.f0_stat,
        energy_stat=args.energy_stat,
        mel_length_threshold=mel_length_threshold,
        speakers_map=speakers_map,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        # Global batch: per-replica batch * replicas * accumulation steps.
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.dev_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        f0_query=f0_query,
        energy_query=energy_query,
        f0_stat=args.f0_stat,
        energy_stat=args.energy_stat,
        mel_length_threshold=mel_length_threshold,
        speakers_map=speakers_map,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = FastSpeech2Trainer(
        config=config,
        strategy=STRATEGY,
        steps=0,
        epochs=0,
        is_mixed_precision=args.mixed_precision,
        stats_path=args.dataset_stats,
        dataset_config=args.dataset_config,
    )

    with STRATEGY.scope():
        # define model
        fastspeech = TFFastSpeech2(
            config=FastSpeech2Config(**config["fastspeech2_params"])
        )
        fastspeech._build()
        fastspeech.summary()

        # len > 1 distinguishes a real path from the empty/None-ish default.
        if len(args.pretrained) > 1:
            fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        # AdamW for fastspeech: polynomial decay wrapped in linear warmup.
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_steps=config["optimizer_params"]["decay_steps"],
            end_learning_rate=config["optimizer_params"]["end_learning_rate"],
        )

        learning_rate_fn = WarmUp(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_schedule_fn=learning_rate_fn,
            warmup_steps=int(
                config["train_max_steps"]
                * config["optimizer_params"]["warmup_proportion"]
            ),
        )

        optimizer = AdamWeightDecay(
            learning_rate=learning_rate_fn,
            weight_decay_rate=config["optimizer_params"]["weight_decay"],
            beta_1=0.9,
            beta_2=0.98,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )

        # Touch .iterations so the optimizer's step counter is created
        # inside the strategy scope.
        _ = optimizer.iterations

    # compile trainer
    trainer.compile(model=fastspeech, optimizer=optimizer)

    # start training; Ctrl-C saves a checkpoint instead of losing progress.
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
486
487
488
# Standard script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
490
491