# Path: blob/master/examples/melgan/train_melgan.py
# (scraped page metadata: 1558 views)
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train MelGAN."""

import tensorflow as tf

# Enable memory growth on every visible GPU before anything else touches the
# devices, so TensorFlow does not grab all GPU memory up front.
physical_devices = tf.config.list_physical_devices("GPU")
for i in range(len(physical_devices)):
    tf.config.experimental.set_memory_growth(physical_devices[i], True)

import sys

# Allow running this script from the repository root (project-local imports
# below such as `examples.melgan.audio_mel_dataset` rely on it).
sys.path.append(".")

import argparse
import logging
import os

import numpy as np
import soundfile as sf
import yaml
from tqdm import tqdm

import tensorflow_tts
import tensorflow_tts.configs.melgan as MELGAN_CONFIG
from examples.melgan.audio_mel_dataset import AudioMelDataset
from tensorflow_tts.losses import TFMelSpectrogram
from tensorflow_tts.models import TFMelGANGenerator, TFMelGANMultiScaleDiscriminator
from tensorflow_tts.trainers import GanBasedTrainer
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy


class MelganTrainer(GanBasedTrainer):
    """Melgan Trainer class based on GanBasedTrainer."""

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute.Strategy): Distribution strategy used for training.
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_generator_mixed_precision (bool): Use mixed precision for generator or not.
            is_discriminator_mixed_precision (bool): Use mixed precision for
                discriminator or not.

        """
        super(MelganTrainer, self).__init__(
            steps,
            epochs,
            config,
            strategy,
            is_generator_mixed_precision,
            is_discriminator_mixed_precision,
        )
        # define metrics to aggregates data and use tf.summary logs them
        self.list_metrics_name = [
            "adversarial_loss",
            "fm_loss",
            "gen_loss",
            "real_loss",
            "fake_loss",
            "dis_loss",
            "mels_spectrogram_loss",
        ]
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

        self.config = config

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer):
        """Attach models/optimizers to the trainer and build the loss objects.

        Losses use Reduction.NONE so per-example losses keep a [batch_size]
        shape, which the GAN trainer needs for multi-replica averaging.
        """
        super().compile(gen_model, dis_model, gen_optimizer, dis_optimizer)
        # define loss
        self.mse_loss = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae_loss = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mels_loss = TFMelSpectrogram()

    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per example generator losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        audios = batch["audios"]
        y_hat = outputs

        p_hat = self._discriminator(y_hat)
        p = self._discriminator(tf.expand_dims(audios, 2))
        adv_loss = 0.0
        for i in range(len(p_hat)):
            # last element of each scale's outputs is the final discriminator score
            adv_loss += calculate_3d_loss(
                tf.ones_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
            )
        # average over the number of discriminator scales (i is the last index)
        adv_loss /= i + 1

        # define feature-matching loss over intermediate discriminator features
        fm_loss = 0.0
        for i in range(len(p_hat)):
            for j in range(len(p_hat[i]) - 1):
                fm_loss += calculate_3d_loss(
                    p[i][j], p_hat[i][j], loss_fn=self.mae_loss
                )
        fm_loss /= (i + 1) * (j + 1)
        adv_loss += self.config["lambda_feat_match"] * fm_loss

        per_example_losses = adv_loss

        dict_metrics_losses = {
            "adversarial_loss": adv_loss,
            "fm_loss": fm_loss,
            "gen_loss": adv_loss,
            # mel-spectrogram distance is logged as a metric only; it is not
            # part of per_example_losses here.
            "mels_spectrogram_loss": calculate_2d_loss(
                audios, tf.squeeze(y_hat, -1), loss_fn=self.mels_loss
            ),
        }

        return per_example_losses, dict_metrics_losses

    def compute_per_example_discriminator_losses(self, batch, gen_outputs):
        """Compute per example discriminator losses (LSGAN real/fake targets).

        Args:
            batch: dictionary batch input return from dataloader
            gen_outputs: generator outputs (generated waveforms)

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        audios = batch["audios"]
        y_hat = gen_outputs

        y = tf.expand_dims(audios, 2)
        p = self._discriminator(y)
        p_hat = self._discriminator(y_hat)

        real_loss = 0.0
        fake_loss = 0.0
        for i in range(len(p)):
            real_loss += calculate_3d_loss(
                tf.ones_like(p[i][-1]), p[i][-1], loss_fn=self.mse_loss
            )
            fake_loss += calculate_3d_loss(
                tf.zeros_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
            )
        # average over the number of discriminator scales (i is the last index)
        real_loss /= i + 1
        fake_loss /= i + 1
        dis_loss = real_loss + fake_loss

        # calculate per_example_losses and dict_metrics_losses
        per_example_losses = dis_loss

        dict_metrics_losses = {
            "real_loss": real_loss,
            "fake_loss": fake_loss,
            "dis_loss": dis_loss,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        import matplotlib.pyplot as plt

        # generate
        y_batch_ = self.one_step_predict(batch)
        y_batch = batch["audios"]
        utt_ids = batch["utt_ids"]

        # convert to tensor.
        # here we just take a sample at first replica
        # (PerReplica values under a multi-device strategy expose .values;
        # single-device tensors do not, hence the fallback).
        try:
            y_batch_ = y_batch_.values[0].numpy()
            y_batch = y_batch.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            y_batch_ = y_batch_.numpy()
            y_batch = y_batch.numpy()
            utt_ids = utt_ids.numpy()

        # check directory
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 0):
            # convert to ndarray
            y, y_ = tf.reshape(y, [-1]).numpy(), tf.reshape(y_, [-1]).numpy()

            # plot figure and save it
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            plt.subplot(2, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(2, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # save as wavefile
            y = np.clip(y, -1, 1)
            y_ = np.clip(y_, -1, 1)
            sf.write(
                figname.replace(".png", "_ref.wav"),
                y,
                self.config["sampling_rate"],
                "PCM_16",
            )
            sf.write(
                figname.replace(".png", "_gen.wav"),
                y_,
                self.config["sampling_rate"],
                "PCM_16",
            )


def collater(
    items,
    batch_max_steps=tf.constant(8192, dtype=tf.int32),
    hop_size=tf.constant(256, dtype=tf.int32),
):
    """Initialize collater (mapping function) for Tensorflow Audio-Mel Dataset.

    Pads audio to mel-length * hop_size when needed, then either crops a
    random batch_max_steps-long window (keeping audio and mel aligned via
    hop_size) or pads both up to the fixed batch length.

    Args:
        batch_max_steps (int): The maximum length of input signal in batch.
        hop_size (int): Hop size of auxiliary features.

    """
    audio, mel = items["audios"], items["mels"]

    if batch_max_steps is None:
        batch_max_steps = (tf.shape(audio)[0] // hop_size) * hop_size

    batch_max_frames = batch_max_steps // hop_size
    if len(audio) < len(mel) * hop_size:
        audio = tf.pad(audio, [[0, len(mel) * hop_size - len(audio)]])

    if len(mel) > batch_max_frames:
        # randomly pickup with the batch_max_steps length of the part
        interval_start = 0
        interval_end = len(mel) - batch_max_frames
        start_frame = tf.random.uniform(
            shape=[], minval=interval_start, maxval=interval_end, dtype=tf.int32
        )
        start_step = start_frame * hop_size
        audio = audio[start_step : start_step + batch_max_steps]
        mel = mel[start_frame : start_frame + batch_max_frames, :]
    else:
        audio = tf.pad(audio, [[0, batch_max_steps - len(audio)]])
        mel = tf.pad(mel, [[0, batch_max_frames - len(mel)], [0, 0]])

    items = {
        "utt_ids": items["utt_ids"],
        "audios": audio,
        "mels": mel,
        "mel_lengths": len(mel),
        "audio_lengths": len(audio),
    }

    return items


def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train MelGAN (See detail in tensorflow_tts/bin/train-melgan.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm mels for training or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for discriminator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="path of .h5 melgan generator to load weights from",
    )
    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision)

    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # BUGFIX: the message previously named a non-existent --valid-dir flag.
        raise ValueError("Please specify --dev-dir")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["melgan_generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        audio_query = "*-wave.npy"
        mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy are supported.")

    # define train/valid dataset
    train_dataset = AudioMelDataset(
        root_dir=args.train_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(config["batch_max_steps"], dtype=tf.int32),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        # global batch = per-replica batch * replicas * accumulation steps
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = AudioMelDataset(
        root_dir=args.dev_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(
                config["batch_max_steps_valid"], dtype=tf.int32
            ),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = MelganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=STRATEGY,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    # define generator and discriminator
    with STRATEGY.scope():
        generator = TFMelGANGenerator(
            MELGAN_CONFIG.MelGANGeneratorConfig(**config["melgan_generator_params"]),
            name="melgan_generator",
        )

        discriminator = TFMelGANMultiScaleDiscriminator(
            MELGAN_CONFIG.MelGANDiscriminatorConfig(
                **config["melgan_discriminator_params"]
            ),
            name="melgan_discriminator",
        )

        # dummy input to build model.
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_hat = generator(fake_mels)
        discriminator(y_hat)

        if len(args.pretrained) > 1:
            generator.load_weights(args.pretrained)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        generator.summary()
        discriminator.summary()

        gen_optimizer = tf.keras.optimizers.Adam(**config["generator_optimizer_params"])
        dis_optimizer = tf.keras.optimizers.Adam(
            **config["discriminator_optimizer_params"]
        )

    trainer.compile(
        gen_model=generator,
        dis_model=discriminator,
        gen_optimizer=gen_optimizer,
        dis_optimizer=dis_optimizer,
    )

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        # allow Ctrl-C to stop training without losing progress
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()