Path: blob/master/examples/melgan_stft/train_melgan_stft.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train MelGAN Multi Resolution STFT Loss."""

import tensorflow as tf

# Enable memory growth on every visible GPU. This is done before the remaining
# imports because TensorFlow only accepts this setting before the GPUs have
# been initialized.
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

import sys

# Make the current working directory importable so that ``examples.*``
# resolves (assumes the script is launched from the repository root).
sys.path.append(".")

import argparse
import logging
import os

import numpy as np
import yaml

import tensorflow_tts
import tensorflow_tts.configs.melgan as MELGAN_CONFIG
from examples.melgan.audio_mel_dataset import AudioMelDataset
from examples.melgan.train_melgan import MelganTrainer, collater
from tensorflow_tts.losses import TFMultiResolutionSTFT
from tensorflow_tts.models import TFMelGANGenerator, TFMelGANMultiScaleDiscriminator
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy

# STFT loss values (spectral convergence / log magnitude) at or above this
# threshold are replaced by 0.0 to keep the generator loss from exploding.
_STFT_LOSS_CLIP_THRESHOLD = 15.0


class MultiSTFTMelganTrainer(MelganTrainer):
    """MelGAN trainer that adds a multi-resolution STFT loss to the generator.

    Extends ``MelganTrainer`` and additionally tracks the spectral-convergence
    and log-magnitude STFT losses as train/eval metrics.
    """

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy: Distribution strategy forwarded to ``MelganTrainer``
                (``main`` passes the value returned by ``return_strategy()``).
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_generator_mixed_precision (bool): Use mixed precision for generator or not.
            is_discriminator_mixed_precision (bool): Use mixed precision for discriminator or not.
        """
        super().__init__(
            config=config,
            steps=steps,
            epochs=epochs,
            strategy=strategy,
            is_generator_mixed_precision=is_generator_mixed_precision,
            is_discriminator_mixed_precision=is_discriminator_mixed_precision,
        )

        # Every key returned in dict_metrics_losses by
        # compute_per_example_generator_losses must appear in this list.
        self.list_metrics_name = [
            "adversarial_loss",
            "fm_loss",
            "gen_loss",
            "real_loss",
            "fake_loss",
            "dis_loss",
            "spectral_convergence_loss",
            "log_magnitude_loss",
        ]

        # (Re)build and reset the metric trackers for the list above.
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer):
        """Attach models/optimizers and build the multi-resolution STFT loss.

        Args:
            gen_model: Generator model.
            dis_model: Discriminator model.
            gen_optimizer: Optimizer for the generator.
            dis_optimizer: Optimizer for the discriminator.
        """
        super().compile(gen_model, dis_model, gen_optimizer, dis_optimizer)
        # STFT loss hyper-parameters come from the ``stft_loss_params`` section.
        self.stft_loss = TFMultiResolutionSTFT(**self.config["stft_loss_params"])

    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per-example generator losses and a dict of metric losses.

        Note that every returned loss MUST have shape [batch_size] and every key
        of ``dict_metrics_losses`` MUST be listed in ``self.list_metrics_name``.

        The generator loss is ``0.5 * (sc_loss + mag_loss)``. Once
        ``self.steps`` reaches ``config["discriminator_train_start_steps"]``,
        ``lambda_adv * (adv_loss + lambda_feat_match * fm_loss)`` is added.

        Args:
            batch: Dictionary batch returned by the dataloader; must contain
                ``"audios"``.
            outputs: Generator output; its last axis is squeezed away before
                the STFT loss, and the real audio is expanded at axis 2 before
                being passed to the discriminator.

        Returns:
            per_example_losses: Per-example generator losses for each GPU, shape [B].
            dict_metrics_losses: Dictionary of named losses for metric tracking.
        """
        audios = batch["audios"]
        y_hat = outputs

        # Multi-resolution STFT loss between real and generated audio.
        sc_loss, mag_loss = calculate_2d_loss(
            audios, tf.squeeze(y_hat, -1), self.stft_loss
        )

        # Trick to prevent the loss from exploding: zero out values at or
        # above the clip threshold.
        sc_loss = tf.where(sc_loss >= _STFT_LOSS_CLIP_THRESHOLD, 0.0, sc_loss)
        mag_loss = tf.where(mag_loss >= _STFT_LOSS_CLIP_THRESHOLD, 0.0, mag_loss)

        gen_loss = 0.5 * (sc_loss + mag_loss)
        dict_metrics_losses = {}

        if self.steps >= self.config["discriminator_train_start_steps"]:
            # One entry per discriminator scale. Within each entry, the last
            # element is scored against ones (adversarial loss) and the
            # preceding elements are compared real-vs-fake (feature matching).
            p_hat = self._discriminator(y_hat)
            p = self._discriminator(tf.expand_dims(audios, 2))
            num_scales = len(p_hat)

            # Adversarial loss: MSE between D(y_hat) and ones, averaged over scales.
            adv_loss = 0.0
            for fake_outputs in p_hat:
                adv_loss += calculate_3d_loss(
                    tf.ones_like(fake_outputs[-1]),
                    fake_outputs[-1],
                    loss_fn=self.mse_loss,
                )
            adv_loss /= num_scales

            # Feature-matching loss (mae_loss) over all intermediate outputs.
            # Normalization uses the last scale's layer count, so every scale
            # is assumed to return the same number of outputs.
            num_feature_layers = len(p_hat[-1]) - 1
            fm_loss = 0.0
            for real_outputs, fake_outputs in zip(p, p_hat):
                for real_feat, fake_feat in zip(real_outputs[:-1], fake_outputs[:-1]):
                    fm_loss += calculate_3d_loss(
                        real_feat, fake_feat, loss_fn=self.mae_loss
                    )
            fm_loss /= num_scales * num_feature_layers

            adv_loss += self.config["lambda_feat_match"] * fm_loss
            gen_loss += self.config["lambda_adv"] * adv_loss

            dict_metrics_losses.update({"adversarial_loss": adv_loss})
            dict_metrics_losses.update({"fm_loss": fm_loss})

        dict_metrics_losses.update({"gen_loss": gen_loss})
        dict_metrics_losses.update({"spectral_convergence_loss": sc_loss})
        dict_metrics_losses.update({"log_magnitude_loss": mag_loss})

        return gen_loss, dict_metrics_losses


def _parse_args():
    """Parse and return the command-line arguments of the training script."""
    parser = argparse.ArgumentParser(
        description="Train MelGAN (See detail in tensorflow_tts/bin/train-melgan.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm mels for training or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for discriminator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="path of .h5 melgan generator to load weights from",
    )
    return parser.parse_args()


def _configure_logging(verbose):
    """Configure root logging to stdout.

    Level is DEBUG when ``verbose > 1``, INFO when ``verbose > 0`` and WARNING
    otherwise (in which case a warning is emitted that DEBUG/INFO is skipped).
    """
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARNING
    logging.basicConfig(
        level=level,
        stream=sys.stdout,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if level == logging.WARNING:
        logging.warning("Skip DEBUG/INFO messages")


def _create_dataset(
    root_dir,
    audio_query,
    mel_query,
    audio_load_fn,
    mel_load_fn,
    mel_length_threshold,
    batch_max_steps,
    hop_size,
    is_shuffle,
    allow_cache,
    batch_size,
):
    """Build an ``AudioMelDataset`` pipeline whose items are batched by ``collater``.

    Returns:
        Whatever ``AudioMelDataset(...).create(...)`` returns.
    """
    return AudioMelDataset(
        root_dir=root_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=is_shuffle,
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(batch_max_steps, dtype=tf.int32),
            hop_size=tf.constant(hop_size, dtype=tf.int32),
        ),
        allow_cache=allow_cache,
        batch_size=batch_size,
    )


def main():
    """Run training process."""
    args = _parse_args()

    # return strategy
    strategy = return_strategy()

    # Enable TF's automatic mixed-precision graph rewrite if either model asks for it.
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision)
    args.use_norm = bool(args.use_norm)

    _configure_logging(args.verbose)

    # check arguments (before touching the filesystem)
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        raise ValueError("Please specify --dev-dir")

    os.makedirs(args.outdir, exist_ok=True)

    # Load the YAML config, overlay the CLI arguments and save a copy to outdir.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # load trusted config files (consider yaml.SafeLoader if no custom tags).
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # Minimum mel length kept in the dataset: batch_max_steps // hop_size
    # frames plus the generator's aux context window on both sides.
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["melgan_generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None

    if config["format"] != "npy":
        raise ValueError("Only npy are supported.")
    audio_query = "*-wave.npy"
    mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"

    # Settings shared by the train and valid pipelines.
    dataset_kwargs = dict(
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=np.load,
        mel_load_fn=np.load,
        mel_length_threshold=mel_length_threshold,
        hop_size=config["hop_size"],
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
    )

    # define train/valid dataset
    # Train batch size is scaled by replica count and gradient-accumulation steps.
    train_dataset = _create_dataset(
        root_dir=args.train_dir,
        batch_max_steps=config["batch_max_steps"],
        batch_size=config["batch_size"]
        * strategy.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
        **dataset_kwargs,
    )
    valid_dataset = _create_dataset(
        root_dir=args.dev_dir,
        batch_max_steps=config["batch_max_steps_valid"],
        batch_size=config["batch_size"] * strategy.num_replicas_in_sync,
        **dataset_kwargs,
    )

    # define trainer
    trainer = MultiSTFTMelganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=strategy,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    with strategy.scope():
        # define generator and discriminator
        generator = TFMelGANGenerator(
            MELGAN_CONFIG.MelGANGeneratorConfig(**config["melgan_generator_params"]),
            name="melgan_generator",
        )

        discriminator = TFMelGANMultiScaleDiscriminator(
            MELGAN_CONFIG.MelGANDiscriminatorConfig(
                **config["melgan_discriminator_params"]
            ),
            name="melgan_discriminator",
        )

        # Dummy forward pass to build both models' weights.
        # NOTE(review): the hard-coded last dim (80) presumably must match the
        # number of mel channels in the config — confirm.
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_hat = generator(fake_mels)
        discriminator(y_hat)

        # ``--pretrained`` uses nargs="?", so it may be None when the flag is
        # given without a value; a truthiness check covers None and "".
        if args.pretrained:
            generator.load_weights(args.pretrained)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        generator.summary()
        discriminator.summary()

        # define optimizer: learning-rate schedules are looked up by class name
        # in tf.keras.optimizers.schedules.
        generator_lr_fn = getattr(
            tf.keras.optimizers.schedules, config["generator_optimizer_params"]["lr_fn"]
        )(**config["generator_optimizer_params"]["lr_params"])
        discriminator_lr_fn = getattr(
            tf.keras.optimizers.schedules,
            config["discriminator_optimizer_params"]["lr_fn"],
        )(**config["discriminator_optimizer_params"]["lr_params"])

        gen_optimizer = tf.keras.optimizers.Adam(
            learning_rate=generator_lr_fn, amsgrad=False
        )
        dis_optimizer = tf.keras.optimizers.Adam(
            learning_rate=discriminator_lr_fn, amsgrad=False
        )

    trainer.compile(
        gen_model=generator,
        dis_model=discriminator,
        gen_optimizer=gen_optimizer,
        dis_optimizer=dis_optimizer,
    )

    # start training; Ctrl-C saves a checkpoint before exiting.
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()