Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/tacotron2/train_tacotron2.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Train Tacotron2."""
16
import tensorflow as tf

# Enable memory growth on every visible GPU so TensorFlow allocates VRAM
# on demand instead of grabbing all device memory at startup. Must run
# before any op touches the GPUs.
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

import sys

# Make the repository root importable (examples.* and tensorflow_tts.*).
sys.path.append(".")
import argparse
27
import logging
28
import os
29
30
import numpy as np
31
import yaml
32
from tqdm import tqdm
33
34
import tensorflow_tts
35
from examples.tacotron2.tacotron_dataset import CharactorMelDataset
36
from tensorflow_tts.configs.tacotron2 import Tacotron2Config
37
from tensorflow_tts.models import TFTacotron2
38
from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
39
from tensorflow_tts.trainers import Seq2SeqBasedTrainer
40
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
41
42
43
class Tacotron2Trainer(Seq2SeqBasedTrainer):
    """Tacotron2 Trainer class based on Seq2SeqBasedTrainer."""

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy: tf.distribute strategy used to run train/eval steps.
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_mixed_precision (bool): Use mixed precision or not.

        """
        super(Tacotron2Trainer, self).__init__(
            steps=steps,
            epochs=epochs,
            config=config,
            strategy=strategy,
            is_mixed_precision=is_mixed_precision,
        )
        # define metrics to aggregates data and use tf.summary logs them;
        # keys here must match the keys returned by compute_per_example_losses.
        self.list_metrics_name = [
            "stop_token_loss",
            "mel_loss_before",
            "mel_loss_after",
            "guided_attention_loss",
        ]
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

        self.config = config

    def compile(self, model, optimizer):
        """Attach model/optimizer and build the loss objects used per step.

        All losses use Reduction.NONE so per-example values are preserved;
        reduction is presumably handled inside the calculate_2d/3d_loss
        helpers — confirm against tensorflow_tts.utils.
        """
        super().compile(model, optimizer)
        # Stop-token head emits logits, hence from_logits=True.
        self.binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        self.mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )

    def _train_step(self, batch):
        """Here we re-define _train_step because apply input_signature make
        the training progress slower on my experiment. Note that input_signature
        is apply on based_trainer by default.
        """
        # Lazily wrap the per-step callables in tf.function on the first call;
        # experimental_relax_shapes avoids retracing on every new input shape.
        if self._already_apply_input_signature is False:
            self.one_step_forward = tf.function(
                self._one_step_forward, experimental_relax_shapes=True
            )
            self.one_step_evaluate = tf.function(
                self._one_step_evaluate, experimental_relax_shapes=True
            )
            self.one_step_predict = tf.function(
                self._one_step_predict, experimental_relax_shapes=True
            )
            self._already_apply_input_signature = True

        # run one_step_forward
        self.one_step_forward(batch)

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _one_step_evaluate_per_replica(self, batch):
        """One step evaluate per GPU

        Tacotron-2 used teacher-forcing when training and evaluation.
        So we need pass `training=True` for inference step.

        """
        outputs = self._model(**batch, training=True)
        _, dict_metrics_losses = self.compute_per_example_losses(batch, outputs)

        self.update_eval_metrics(dict_metrics_losses)

    def _one_step_predict_per_replica(self, batch):
        """One step predict per GPU

        Tacotron-2 used teacher-forcing when training and evaluation.
        So we need pass `training=True` for inference step.

        """
        outputs = self._model(**batch, training=True)
        return outputs

    def compute_per_example_losses(self, batch, outputs):
        """Compute per example losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        (
            decoder_output,
            post_mel_outputs,
            stop_token_predictions,
            alignment_historys,
        ) = outputs

        # L1 loss on the decoder output (pre-postnet) and the postnet output.
        mel_loss_before = calculate_3d_loss(
            batch["mel_gts"], decoder_output, loss_fn=self.mae
        )
        mel_loss_after = calculate_3d_loss(
            batch["mel_gts"], post_mel_outputs, loss_fn=self.mae
        )

        # calculate stop_loss
        # With fixed shapes the padded length is a constant from config
        # (wrapped in a list; tf.reduce_max below unwraps either form).
        max_mel_length = (
            tf.reduce_max(batch["mel_lengths"])
            if self.config["use_fixed_shapes"] is False
            else [self.config["max_mel_length"]]
        )
        stop_gts = tf.expand_dims(
            tf.range(tf.reduce_max(max_mel_length), dtype=tf.int32), 0
        )  # [1, max_len]
        stop_gts = tf.tile(
            stop_gts, [tf.shape(batch["mel_lengths"])[0], 1]
        )  # [B, max_len]
        # Ground-truth stop targets: 1.0 at frame positions >= the real mel
        # length (i.e. padding), 0.0 for real frames.
        stop_gts = tf.cast(
            tf.math.greater_equal(stop_gts, tf.expand_dims(batch["mel_lengths"], 1)),
            tf.float32,
        )

        stop_token_loss = calculate_2d_loss(
            stop_gts, stop_token_predictions, loss_fn=self.binary_crossentropy
        )

        # calculate guided attention loss.
        # g_attentions uses -1.0 as a padding marker; mask it out, then take
        # the mean of |alignment * guided-attention weight| over valid cells.
        attention_masks = tf.cast(
            tf.math.not_equal(batch["g_attentions"], -1.0), tf.float32
        )
        loss_att = tf.reduce_sum(
            tf.abs(alignment_historys * batch["g_attentions"]) * attention_masks,
            axis=[1, 2],
        )
        loss_att /= tf.reduce_sum(attention_masks, axis=[1, 2])

        # Total loss is an unweighted sum of the four components.
        per_example_losses = (
            stop_token_loss + mel_loss_before + mel_loss_after + loss_att
        )

        dict_metrics_losses = {
            "stop_token_loss": stop_token_loss,
            "mel_loss_before": mel_loss_before,
            "mel_loss_after": mel_loss_after,
            "guided_attention_loss": loss_att,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        # Imported lazily so headless training without matplotlib installed
        # only fails when plotting is actually requested.
        import matplotlib.pyplot as plt

        # predict with tf.function for faster.
        outputs = self.one_step_predict(batch)
        (
            decoder_output,
            mel_outputs,
            stop_token_predictions,
            alignment_historys,
        ) = outputs
        mel_gts = batch["mel_gts"]
        utt_ids = batch["utt_ids"]

        # convert to tensor.
        # here we just take a sample at first replica.
        # NOTE(review): under a multi-device strategy the outputs are
        # presumably PerReplica objects exposing `.values`; the except branch
        # handles the single-device plain-tensor case.
        try:
            mels_before = decoder_output.values[0].numpy()
            mels_after = mel_outputs.values[0].numpy()
            mel_gts = mel_gts.values[0].numpy()
            alignment_historys = alignment_historys.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            mels_before = decoder_output.numpy()
            mels_after = mel_outputs.numpy()
            mel_gts = mel_gts.numpy()
            alignment_historys = alignment_historys.numpy()
            utt_ids = utt_ids.numpy()

        # check directory
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (mel_gt, mel_before, mel_after, alignment_history) in enumerate(
            zip(mel_gts, mels_before, mels_after, alignment_historys), 0
        ):
            # 80 mel bins hard-coded here — assumed to match the feature
            # config; TODO confirm against num_mels in the yaml config.
            mel_gt = tf.reshape(mel_gt, (-1, 80)).numpy()  # [length, 80]
            mel_before = tf.reshape(mel_before, (-1, 80)).numpy()  # [length, 80]
            mel_after = tf.reshape(mel_after, (-1, 80)).numpy()  # [length, 80]

            # plot figure and save it
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            fig = plt.figure(figsize=(10, 8))
            ax1 = fig.add_subplot(311)
            ax2 = fig.add_subplot(312)
            ax3 = fig.add_subplot(313)
            # rot90 puts time on the x-axis and frequency on the y-axis.
            im = ax1.imshow(np.rot90(mel_gt), aspect="auto", interpolation="none")
            ax1.set_title("Target Mel-Spectrogram")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
            ax2.set_title(f"Predicted Mel-before-Spectrogram @ {self.steps} steps")
            im = ax2.imshow(np.rot90(mel_before), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
            ax3.set_title(f"Predicted Mel-after-Spectrogram @ {self.steps} steps")
            im = ax3.imshow(np.rot90(mel_after), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax3)
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # plot alignment
            figname = os.path.join(dirname, f"{idx}_alignment.png")
            fig = plt.figure(figsize=(8, 6))
            ax = fig.add_subplot(111)
            ax.set_title(f"Alignment @ {self.steps} steps")
            im = ax.imshow(
                alignment_history, aspect="auto", origin="lower", interpolation="none"
            )
            fig.colorbar(im, ax=ax)
            xlabel = "Decoder timestep"
            plt.xlabel(xlabel)
            plt.ylabel("Encoder timestep")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()
def main():
    """Run the Tacotron2 training process.

    Parses command-line arguments, prepares train/valid datasets, builds the
    model, optimizer and trainer under the distribution strategy, then runs
    training (saving a checkpoint on KeyboardInterrupt).

    Raises:
        ValueError: If --train-dir or --dev-dir is missing, or the dataset
            format is not "npy".
    """
    # Fixed copy-paste description that referred to FastSpeech.
    parser = argparse.ArgumentParser(
        description="Train Tacotron2 (See detail in examples/tacotron2/train_tacotron2.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm-mels for train or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers",
    )
    parser.add_argument(
        "--use-fal",
        default=0,
        type=int,
        help="Use forced alignment guided attention loss or regular",
    )
    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.mixed_precision = bool(args.mixed_precision)
    args.use_norm = bool(args.use_norm)
    args.use_fal = bool(args.use_fal)

    # set logger: higher --verbose means more logging.
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        stream=sys.stdout,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if log_level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # Fixed: the flag is --dev-dir, not --valid-dir.
        raise ValueError("Please specify --dev-dir")

    # load and save config
    with open(args.config) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # fine for trusted local configs, but use yaml.SafeLoader if configs
        # could ever come from untrusted sources.
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["mel_length_threshold"]
    else:
        mel_length_threshold = 0

    if config["format"] == "npy":
        charactor_query = "*-ids.npy"
        mel_query = "*-raw-feats.npy" if not args.use_norm else "*-norm-feats.npy"
        align_query = "*-alignment.npy" if args.use_fal else ""
        charactor_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy are supported.")

    train_dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.train_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        charactor_load_fn=charactor_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
        reduction_factor=config["tacotron2_params"]["reduction_factor"],
        use_fixed_shapes=config["use_fixed_shapes"],
        align_query=align_query,
    )

    # update max_mel_length and max_char_length to config
    config.update({"max_mel_length": int(train_dataset.max_mel_length)})
    config.update({"max_char_length": int(train_dataset.max_char_length)})

    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # Global batch = per-replica batch * replicas * accumulation steps.
    train_dataset = train_dataset.create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.dev_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        charactor_load_fn=charactor_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
        reduction_factor=config["tacotron2_params"]["reduction_factor"],
        use_fixed_shapes=False,  # don't need apply fixed shape for evaluation.
        align_query=align_query,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = Tacotron2Trainer(
        config=config,
        strategy=STRATEGY,
        steps=0,
        epochs=0,
        is_mixed_precision=args.mixed_precision,
    )

    with STRATEGY.scope():
        # define model.
        tacotron_config = Tacotron2Config(**config["tacotron2_params"])
        tacotron2 = TFTacotron2(config=tacotron_config, name="tacotron2")
        tacotron2._build()
        tacotron2.summary()

        # Fixed: original `len(args.pretrained) > 1` crashed with TypeError
        # when --pretrained was passed without a value (nargs="?" -> None)
        # and silently skipped one-character paths.
        if args.pretrained:
            tacotron2.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        # AdamW for tacotron2: polynomial decay wrapped in a warmup schedule.
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_steps=config["optimizer_params"]["decay_steps"],
            end_learning_rate=config["optimizer_params"]["end_learning_rate"],
        )

        learning_rate_fn = WarmUp(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_schedule_fn=learning_rate_fn,
            warmup_steps=int(
                config["train_max_steps"]
                * config["optimizer_params"]["warmup_proportion"]
            ),
        )

        optimizer = AdamWeightDecay(
            learning_rate=learning_rate_fn,
            weight_decay_rate=config["optimizer_params"]["weight_decay"],
            beta_1=0.9,
            beta_2=0.98,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )

        # Touch the iterations counter so the optimizer variables exist
        # before checkpoint restore.
        _ = optimizer.iterations

    # compile trainer
    trainer.compile(model=tacotron2, optimizer=optimizer)

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
# Script entry point.
if __name__ == "__main__":
    main()