Source: GitHub repository TensorSpeech/TensorFlowTTS, path: examples/hifigan/train_hifigan.py (viewed via CoCalc share server).
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 TensorFlowTTS Team
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Train Hifigan."""
16
17
import tensorflow as tf
18
19
physical_devices = tf.config.list_physical_devices("GPU")
20
for i in range(len(physical_devices)):
21
tf.config.experimental.set_memory_growth(physical_devices[i], True)
22
23
import sys
24
25
sys.path.append(".")
26
27
import argparse
28
import logging
29
import os
30
31
import numpy as np
32
import soundfile as sf
33
import yaml
34
from tqdm import tqdm
35
36
import tensorflow_tts
37
from examples.melgan.audio_mel_dataset import AudioMelDataset
38
from examples.melgan.train_melgan import collater
39
from examples.melgan_stft.train_melgan_stft import MultiSTFTMelganTrainer
40
from tensorflow_tts.configs import (
41
HifiGANDiscriminatorConfig,
42
HifiGANGeneratorConfig,
43
MelGANDiscriminatorConfig,
44
)
45
from tensorflow_tts.models import (
46
TFHifiGANGenerator,
47
TFHifiGANMultiPeriodDiscriminator,
48
TFMelGANMultiScaleDiscriminator,
49
)
50
from tensorflow_tts.utils import return_strategy
51
52
53
class TFHifiGANDiscriminator(tf.keras.Model):
    """HiFi-GAN discriminator: a multi-period and a multi-scale
    sub-discriminator run side by side on the same waveform."""

    def __init__(self, multiperiod_dis, multiscale_dis, **kwargs):
        """Store the two sub-discriminators.

        Args:
            multiperiod_dis: multi-period discriminator model.
            multiscale_dis: multi-scale discriminator model.
            **kwargs: forwarded to tf.keras.Model (e.g. name).
        """
        super().__init__(**kwargs)
        self.multiperiod_dis = multiperiod_dis
        self.multiscale_dis = multiscale_dis

    def call(self, x):
        """Apply both sub-discriminators to x.

        Returns the multi-period outputs followed by the multi-scale
        outputs as a single flat list.
        """
        return [*self.multiperiod_dis(x), *self.multiscale_dis(x)]
66
67
68
def main():
    """Run the HiFi-GAN training process.

    Parses command-line arguments, loads the YAML configuration, builds the
    train/valid datasets, constructs the generator and the combined
    (multi-period + multi-scale) discriminator, then trains them with
    MultiSTFTMelganTrainer. Checkpoints and the merged config are written
    under --outdir.
    """
    parser = argparse.ArgumentParser(
        description="Train Hifigan (See detail in examples/hifigan/train_hifigan.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm mels for training or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for discriminator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="path of .h5 melgan generator to load weights from",
    )
    args = parser.parse_args()

    # return strategy (single-GPU / multi-GPU mirrored, decided by the helper)
    STRATEGY = return_strategy()

    # set mixed precision config before any model is built
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    # Normalize the 0/1 integer flags to real booleans for downstream use.
    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision)

    args.use_norm = bool(args.use_norm)

    # set logger verbosity from --verbose (2+: DEBUG, 1: INFO, 0: WARN)
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check required data-directory arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # FIX: message used to say "--valid-dir", a flag this parser never defines.
        raise ValueError("Please specify --dev-dir")

    # load and save config. NOTE(security): yaml.Loader can construct
    # arbitrary Python objects — the config file must be trusted; prefer
    # yaml.SafeLoader if plain data is sufficient.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    # Merge CLI arguments into the config so one dict drives the trainer,
    # and record the library version for reproducibility.
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset: optionally drop utterances whose mel is too short for one
    # training window (batch_max_steps frames of audio plus context padding).
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["hifigan_generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        audio_query = "*-wave.npy"
        # use_norm selects normalized vs raw mel features on disk.
        mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy are supported.")

    # define train/valid dataset. Batch size is scaled by the number of
    # replicas (and, for training, gradient accumulation steps) so each
    # replica still sees config["batch_size"] samples per step.
    train_dataset = AudioMelDataset(
        root_dir=args.train_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(config["batch_max_steps"], dtype=tf.int32),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = AudioMelDataset(
        root_dir=args.dev_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(
                config["batch_max_steps_valid"], dtype=tf.int32
            ),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer (reuses the melgan-stft trainer; HiFi-GAN shares its
    # multi-resolution STFT training loop)
    trainer = MultiSTFTMelganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=STRATEGY,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    with STRATEGY.scope():
        # define generator and discriminator
        generator = TFHifiGANGenerator(
            HifiGANGeneratorConfig(**config["hifigan_generator_params"]),
            name="hifigan_generator",
        )

        multiperiod_discriminator = TFHifiGANMultiPeriodDiscriminator(
            HifiGANDiscriminatorConfig(**config["hifigan_discriminator_params"]),
            name="hifigan_multiperiod_discriminator",
        )
        multiscale_discriminator = TFMelGANMultiScaleDiscriminator(
            MelGANDiscriminatorConfig(
                **config["melgan_discriminator_params"],
                name="melgan_multiscale_discriminator",
            )
        )

        discriminator = TFHifiGANDiscriminator(
            multiperiod_discriminator,
            multiscale_discriminator,
            name="hifigan_discriminator",
        )

        # dummy input to build model weights before summary/weight-loading.
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_hat = generator(fake_mels)
        discriminator(y_hat)

        # --pretrained defaults to ""; any real path is longer than 1 char.
        if len(args.pretrained) > 1:
            generator.load_weights(args.pretrained)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        generator.summary()
        discriminator.summary()

        # define optimizers: learning-rate schedules are looked up by name in
        # tf.keras.optimizers.schedules and built from the config params.
        generator_lr_fn = getattr(
            tf.keras.optimizers.schedules, config["generator_optimizer_params"]["lr_fn"]
        )(**config["generator_optimizer_params"]["lr_params"])
        discriminator_lr_fn = getattr(
            tf.keras.optimizers.schedules,
            config["discriminator_optimizer_params"]["lr_fn"],
        )(**config["discriminator_optimizer_params"]["lr_params"])

        gen_optimizer = tf.keras.optimizers.Adam(
            learning_rate=generator_lr_fn,
            amsgrad=config["generator_optimizer_params"]["amsgrad"],
        )
        dis_optimizer = tf.keras.optimizers.Adam(
            learning_rate=discriminator_lr_fn,
            amsgrad=config["discriminator_optimizer_params"]["amsgrad"],
        )

    trainer.compile(
        gen_model=generator,
        dis_model=discriminator,
        gen_optimizer=gen_optimizer,
        dis_optimizer=dis_optimizer,
    )

    # start training; on Ctrl-C save a checkpoint so progress is not lost.
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
322
323
324
if __name__ == "__main__":
325
main()
326
327