Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/fastspeech2/train_fastspeech2.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Train FastSpeech2."""
16
17
import tensorflow as tf
18
19
# Turn on memory growth for every visible GPU so TensorFlow allocates
# device memory on demand instead of claiming it all at startup.
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

import sys

# Allow running this example from the repository root.
sys.path.append(".")
26
27
import argparse
28
import logging
29
import os
30
31
import numpy as np
32
import yaml
33
from tqdm import tqdm
34
35
import tensorflow_tts
36
from examples.fastspeech2.fastspeech2_dataset import CharactorDurationF0EnergyMelDataset
37
from examples.fastspeech.train_fastspeech import FastSpeechTrainer
38
from tensorflow_tts.configs import FastSpeech2Config
39
from tensorflow_tts.models import TFFastSpeech2
40
from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
41
from tensorflow_tts.trainers import Seq2SeqBasedTrainer
42
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
43
44
45
class FastSpeech2Trainer(Seq2SeqBasedTrainer):
    """FastSpeech2 Trainer class based on FastSpeechTrainer."""

    def __init__(
        self, config, strategy, steps=0, epochs=0, is_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy: tf.distribute strategy used for training.
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_mixed_precision (bool): Use mixed precision or not.

        """
        super().__init__(
            steps=steps,
            epochs=epochs,
            config=config,
            strategy=strategy,
            is_mixed_precision=is_mixed_precision,
        )
        # Names of the metrics aggregated across steps and written via tf.summary.
        self.list_metrics_name = [
            "duration_loss",
            "f0_loss",
            "energy_loss",
            "mel_loss_before",
            "mel_loss_after",
        ]
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

    def compile(self, model, optimizer):
        """Attach model/optimizer and build the per-example loss functions."""
        super().compile(model, optimizer)
        # Reduction.NONE keeps one loss value per example so the trainer can
        # reduce them itself under the distribution strategy.
        no_reduction = tf.keras.losses.Reduction.NONE
        self.mse = tf.keras.losses.MeanSquaredError(reduction=no_reduction)
        self.mae = tf.keras.losses.MeanAbsoluteError(reduction=no_reduction)

    def compute_per_example_losses(self, batch, outputs):
        """Compute per example losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        mel_before, mel_after, duration_preds, f0_preds, energy_preds = outputs

        # Durations are regressed in log space; the +1 guards against log(0).
        log_duration_gts = tf.math.log(
            tf.cast(batch["duration_gts"] + 1, tf.float32)
        )
        duration_loss = calculate_2d_loss(log_duration_gts, duration_preds, self.mse)
        f0_loss = calculate_2d_loss(batch["f0_gts"], f0_preds, self.mse)
        energy_loss = calculate_2d_loss(batch["energy_gts"], energy_preds, self.mse)
        mel_loss_before = calculate_3d_loss(batch["mel_gts"], mel_before, self.mae)
        mel_loss_after = calculate_3d_loss(batch["mel_gts"], mel_after, self.mae)

        # Total per-example loss is a plain unweighted sum of all terms.
        per_example_losses = (
            duration_loss + f0_loss + energy_loss + mel_loss_before + mel_loss_after
        )

        dict_metrics_losses = {
            "duration_loss": duration_loss,
            "f0_loss": f0_loss,
            "energy_loss": energy_loss,
            "mel_loss_before": mel_loss_before,
            "mel_loss_after": mel_loss_after,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        import matplotlib.pyplot as plt

        # predict with tf.function.
        outputs = self.one_step_predict(batch)

        mels_before, mels_after, *_ = outputs
        mel_gts = batch["mel_gts"]
        utt_ids = batch["utt_ids"]

        # Under a distribution strategy the outputs are PerReplica objects;
        # sample the first replica. Plain tensors lack `.values` and fall
        # through to the except branch.
        try:
            mels_before = mels_before.values[0].numpy()
            mels_after = mels_after.values[0].numpy()
            mel_gts = mel_gts.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            mels_before = mels_before.numpy()
            mels_after = mels_after.numpy()
            mel_gts = mel_gts.numpy()
            utt_ids = utt_ids.numpy()

        # Make sure the per-step prediction directory exists.
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        os.makedirs(dirname, exist_ok=True)

        for idx, (mel_gt, mel_before, mel_after) in enumerate(
            zip(mel_gts, mels_before, mels_after)
        ):
            # Flatten each spectrogram to [length, 80] mel bins.
            mel_gt = tf.reshape(mel_gt, (-1, 80)).numpy()
            mel_before = tf.reshape(mel_before, (-1, 80)).numpy()
            mel_after = tf.reshape(mel_after, (-1, 80)).numpy()

            # Plot ground truth and both predictions stacked, then save.
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            fig = plt.figure(figsize=(10, 8))
            axes = [fig.add_subplot(311), fig.add_subplot(312), fig.add_subplot(313)]
            titles = (
                "Target Mel-Spectrogram",
                "Predicted Mel-before-Spectrogram",
                "Predicted Mel-after-Spectrogram",
            )
            for ax, title, mel in zip(axes, titles, (mel_gt, mel_before, mel_after)):
                ax.set_title(title)
                im = ax.imshow(np.rot90(mel), aspect="auto", interpolation="none")
                fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax)
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()
def main():
    """Run training process.

    Parses CLI arguments, loads/saves the yaml config, builds the
    train/valid datasets, constructs the FastSpeech2 model and AdamW
    optimizer under the selected distribution strategy, then trains
    until completion (KeyboardInterrupt saves a checkpoint first).

    Raises:
        ValueError: if --train-dir/--dev-dir are missing or the data
            format is not "npy".
    """
    parser = argparse.ArgumentParser(
        description="Train FastSpeech2 (See detail in examples/fastspeech2/train_fastspeech2.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm-mels for train or raw."
    )
    parser.add_argument(
        "--f0-stat",
        default="./dump/stats_f0.npy",
        type=str,
        required=True,
        help="f0-stat path.",
    )
    parser.add_argument(
        "--energy-stat",
        default="./dump/stats_energy.npy",
        type=str,
        required=True,
        help="energy-stat path.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers",
    )

    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.mixed_precision = bool(args.mixed_precision)
    args.use_norm = bool(args.use_norm)

    # set logger: map verbosity to a level once, single basicConfig call
    # (replaces three duplicated branches differing only in `level`).
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        stream=sys.stdout,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if args.verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # fixed: the flag is --dev-dir; the old message said --valid-dir.
        raise ValueError("Please specify --dev-dir")

    # load and save config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; the
    # config file is operator-supplied so this is accepted here, but
    # yaml.SafeLoader is preferable if the config never uses custom tags.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset: optionally drop utterances shorter than the threshold.
    if config["remove_short_samples"]:
        mel_length_threshold = config["mel_length_threshold"]
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        charactor_query = "*-ids.npy"
        mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
        duration_query = "*-durations.npy"
        f0_query = "*-raw-f0.npy"
        energy_query = "*-raw-energy.npy"
    else:
        raise ValueError("Only npy are supported.")

    # define train/valid dataset
    train_dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.train_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        f0_query=f0_query,
        energy_query=energy_query,
        f0_stat=args.f0_stat,
        energy_stat=args.energy_stat,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        # global batch = per-replica batch * replicas * accumulation steps.
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.dev_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        f0_query=f0_query,
        energy_query=energy_query,
        f0_stat=args.f0_stat,
        energy_stat=args.energy_stat,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = FastSpeech2Trainer(
        config=config,
        strategy=STRATEGY,
        steps=0,
        epochs=0,
        is_mixed_precision=args.mixed_precision,
    )

    with STRATEGY.scope():
        # define model
        fastspeech = TFFastSpeech2(
            config=FastSpeech2Config(**config["fastspeech2_params"])
        )
        fastspeech._build()
        fastspeech.summary()
        # --pretrained defaults to ""; any real path is longer than 1 char.
        if len(args.pretrained) > 1:
            fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        # AdamW for fastspeech: polynomial decay wrapped in linear warmup.
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_steps=config["optimizer_params"]["decay_steps"],
            end_learning_rate=config["optimizer_params"]["end_learning_rate"],
        )

        learning_rate_fn = WarmUp(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_schedule_fn=learning_rate_fn,
            warmup_steps=int(
                config["train_max_steps"]
                * config["optimizer_params"]["warmup_proportion"]
            ),
        )

        optimizer = AdamWeightDecay(
            learning_rate=learning_rate_fn,
            weight_decay_rate=config["optimizer_params"]["weight_decay"],
            beta_1=0.9,
            beta_2=0.98,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )

        # Touch .iterations so the step variable is created inside the scope.
        _ = optimizer.iterations

        # compile trainer
        trainer.compile(model=fastspeech, optimizer=optimizer)

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()