GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/melgan_stft/train_melgan_stft.py

# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train MelGAN with Multi-Resolution STFT Loss."""
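
# Example invocation (paths are illustrative; the flags match the argparse
# options defined in main() below):
#
#   CUDA_VISIBLE_DEVICES=0 python examples/melgan_stft/train_melgan_stft.py \
#       --train-dir ./dump/train/ \
#       --dev-dir ./dump/valid/ \
#       --outdir ./exp/train.melgan_stft.v1/ \
#       --config ./examples/melgan_stft/conf/melgan_stft.v1.yaml \
#       --use-norm 1 \
#       --resume ""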

import tensorflow as tf

physical_devices = tf.config.list_physical_devices("GPU")
for i in range(len(physical_devices)):
    tf.config.experimental.set_memory_growth(physical_devices[i], True)
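# (memory growth makes TensorFlow allocate GPU memory on demand instead of
# reserving the whole device up front, which helps on shared GPUs)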

import sys

sys.path.append(".")

import argparse
import logging
import os

import numpy as np
import yaml

import tensorflow_tts
import tensorflow_tts.configs.melgan as MELGAN_CONFIG
from examples.melgan.audio_mel_dataset import AudioMelDataset
from examples.melgan.train_melgan import MelganTrainer, collater
from tensorflow_tts.losses import TFMultiResolutionSTFT
from tensorflow_tts.models import TFMelGANGenerator, TFMelGANMultiScaleDiscriminator
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy


class MultiSTFTMelganTrainer(MelganTrainer):
    """Multi STFT MelGAN Trainer class based on MelganTrainer."""

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from a yaml format configuration file.
            strategy (tf.distribute.Strategy): Distribution strategy for training.
            is_generator_mixed_precision (bool): Whether to use mixed precision for the generator.
            is_discriminator_mixed_precision (bool): Whether to use mixed precision for the discriminator.

        """
        super(MultiSTFTMelganTrainer, self).__init__(
            config=config,
            steps=steps,
            epochs=epochs,
            strategy=strategy,
            is_generator_mixed_precision=is_generator_mixed_precision,
            is_discriminator_mixed_precision=is_discriminator_mixed_precision,
        )

        self.list_metrics_name = [
            "adversarial_loss",
            "fm_loss",
            "gen_loss",
            "real_loss",
            "fake_loss",
            "dis_loss",
            "spectral_convergence_loss",
            "log_magnitude_loss",
        ]

        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer):
        super().compile(gen_model, dis_model, gen_optimizer, dis_optimizer)
        # define the multi-resolution STFT loss
        self.stft_loss = TFMultiResolutionSTFT(**self.config["stft_loss_params"])

    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per-example generator losses and return dict_metrics_losses.

        Note that every element of the loss MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per-example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary of losses.
        """
        dict_metrics_losses = {}
        per_example_losses = 0.0

        audios = batch["audios"]
        y_hat = outputs

        # calculate multi-resolution stft loss
        sc_loss, mag_loss = calculate_2d_loss(
            audios, tf.squeeze(y_hat, -1), self.stft_loss
        )
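        # For reference (Yamamoto et al., "Parallel WaveGAN", 2020), both terms
        # are averaged over the M resolutions given in stft_loss_params:
        #   spectral convergence: L_sc  = || |STFT(y)| - |STFT(y_hat)| ||_F / || |STFT(y)| ||_F
        #   log STFT magnitude:   L_mag = (1/N) || log|STFT(y)| - log|STFT(y_hat)| ||_1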

        # trick to prevent the loss from exploding: zero out any per-example
        # loss that spikes to 15.0 or above
        sc_loss = tf.where(sc_loss >= 15.0, 0.0, sc_loss)
        mag_loss = tf.where(mag_loss >= 15.0, 0.0, mag_loss)

        # compute generator loss
        gen_loss = 0.5 * (sc_loss + mag_loss)

        if self.steps >= self.config["discriminator_train_start_steps"]:
            p_hat = self._discriminator(y_hat)
            p = self._discriminator(tf.expand_dims(audios, 2))
            adv_loss = 0.0
            for i in range(len(p_hat)):
                adv_loss += calculate_3d_loss(
                    tf.ones_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
                )
            adv_loss /= i + 1
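
            # (least-squares GAN objective: each scale's final output
            # p_hat[i][-1] is regressed toward 1.0 with MSE, then averaged
            # over the discriminator scales)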

            # define feature-matching loss: L1 distance between the
            # discriminator's intermediate feature maps on real (p) and
            # generated (p_hat) audio, averaged over scales (i) and layers (j)
            fm_loss = 0.0
            for i in range(len(p_hat)):
                for j in range(len(p_hat[i]) - 1):
                    fm_loss += calculate_3d_loss(
                        p[i][j], p_hat[i][j], loss_fn=self.mae_loss
                    )
            fm_loss /= (i + 1) * (j + 1)
            adv_loss += self.config["lambda_feat_match"] * fm_loss
            gen_loss += self.config["lambda_adv"] * adv_loss

            dict_metrics_losses.update({"adversarial_loss": adv_loss})
            dict_metrics_losses.update({"fm_loss": fm_loss})

        dict_metrics_losses.update({"gen_loss": gen_loss})
        dict_metrics_losses.update({"spectral_convergence_loss": sc_loss})
        dict_metrics_losses.update({"log_magnitude_loss": mag_loss})

        per_example_losses = gen_loss
        return per_example_losses, dict_metrics_losses


def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train MelGAN (see details in tensorflow_tts/bin/train-melgan.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data.",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data.",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use normalized mels (1) or raw mels (0) for training."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="use mixed precision for the generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="use mixed precision for the discriminator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="path of .h5 melgan generator to load weights from",
    )
    args = parser.parse_args()

    # select a tf.distribute strategy
    STRATEGY = return_strategy()
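    # (return_strategy is imported from tensorflow_tts.utils above; it is
    # expected to pick a tf.distribute strategy for the visible devices --
    # note that the batch sizes below scale with STRATEGY.num_replicas_in_sync)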

    # set mixed precision config
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision)

    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        raise ValueError("Please specify --dev-dir")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["melgan_generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
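    # (batch_max_steps // hop_size is the number of mel frames in one training
    # window, and 2 * aux_context_window adds the generator's extra context
    # frames, so utterances too short to fill a window are filtered out)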

    if config["format"] == "npy":
        audio_query = "*-wave.npy"
        mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define train/valid dataset
    train_dataset = AudioMelDataset(
        root_dir=args.train_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(config["batch_max_steps"], dtype=tf.int32),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )
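    # Note: the training batch size scales with both the replica count and
    # gradient_accumulation_steps; the validation batch below scales only
    # with the replica count.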

    valid_dataset = AudioMelDataset(
        root_dir=args.dev_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(
                config["batch_max_steps_valid"], dtype=tf.int32
            ),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = MultiSTFTMelganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=STRATEGY,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    with STRATEGY.scope():
        # define generator and discriminator
        generator = TFMelGANGenerator(
            MELGAN_CONFIG.MelGANGeneratorConfig(**config["melgan_generator_params"]),
            name="melgan_generator",
        )

        discriminator = TFMelGANMultiScaleDiscriminator(
            MELGAN_CONFIG.MelGANDiscriminatorConfig(
                **config["melgan_discriminator_params"]
            ),
            name="melgan_discriminator",
        )

        # dummy input to build model.
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_hat = generator(fake_mels)
        discriminator(y_hat)
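        # (subclassed Keras models create their weights lazily on first call,
        # so this forward pass is what lets load_weights() and summary()
        # below work)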

        if len(args.pretrained) > 1:
            generator.load_weights(args.pretrained)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        generator.summary()
        discriminator.summary()

        # define optimizer
        generator_lr_fn = getattr(
            tf.keras.optimizers.schedules, config["generator_optimizer_params"]["lr_fn"]
        )(**config["generator_optimizer_params"]["lr_params"])
        discriminator_lr_fn = getattr(
            tf.keras.optimizers.schedules,
            config["discriminator_optimizer_params"]["lr_fn"],
        )(**config["discriminator_optimizer_params"]["lr_params"])
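        # (lr_fn names a schedule class in tf.keras.optimizers.schedules,
        # e.g. "ExponentialDecay" or "PiecewiseConstantDecay", and lr_params
        # holds its constructor kwargs; both come from the yaml config)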

        gen_optimizer = tf.keras.optimizers.Adam(
            learning_rate=generator_lr_fn, amsgrad=False
        )
        dis_optimizer = tf.keras.optimizers.Adam(
            learning_rate=discriminator_lr_fn, amsgrad=False
        )

        trainer.compile(
            gen_model=generator,
            dis_model=discriminator,
            gen_optimizer=gen_optimizer,
            dis_optimizer=dis_optimizer,
        )

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps} steps.")


if __name__ == "__main__":
    main()