Path: blob/master/examples/parallel_wavegan/train_parallel_wavegan.py
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train ParallelWavegan."""

import tensorflow as tf

physical_devices = tf.config.list_physical_devices("GPU")
for i in range(len(physical_devices)):
    tf.config.experimental.set_memory_growth(physical_devices[i], True)

import sys

sys.path.append(".")

import argparse
import logging
import os
import soundfile as sf

import numpy as np
import yaml

import tensorflow_tts

from examples.melgan.audio_mel_dataset import AudioMelDataset
from examples.melgan.train_melgan import collater

from tensorflow_tts.configs import (
    ParallelWaveGANGeneratorConfig,
    ParallelWaveGANDiscriminatorConfig,
)
from tensorflow_tts.models import (
    TFParallelWaveGANGenerator,
    TFParallelWaveGANDiscriminator,
)

from tensorflow_tts.trainers import GanBasedTrainer
from tensorflow_tts.losses import TFMultiResolutionSTFT
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy

from tensorflow_addons.optimizers import RectifiedAdam


class ParallelWaveganTrainer(GanBasedTrainer):
    """ParallelWaveGAN Trainer class based on GanBasedTrainer."""

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute.Strategy): Distribution strategy for training.
            is_generator_mixed_precision (bool): Whether to use mixed precision for the generator.
            is_discriminator_mixed_precision (bool): Whether to use mixed precision for the discriminator.

        """
        super(ParallelWaveganTrainer, self).__init__(
            config=config,
            steps=steps,
            epochs=epochs,
            strategy=strategy,
            is_generator_mixed_precision=is_generator_mixed_precision,
            is_discriminator_mixed_precision=is_discriminator_mixed_precision,
        )

        self.list_metrics_name = [
            "adversarial_loss",
            "gen_loss",
            "real_loss",
            "fake_loss",
            "dis_loss",
            "spectral_convergence_loss",
            "log_magnitude_loss",
        ]

        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer):
        super().compile(gen_model, dis_model, gen_optimizer, dis_optimizer)
        # define losses; per-example reduction is done manually, so use Reduction.NONE
        self.stft_loss = TFMultiResolutionSTFT(**self.config["stft_loss_params"])
        self.mse_loss = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae_loss = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )
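
    # Loss schedule: the generator is trained with the multi-resolution STFT
    # loss alone (the average of spectral convergence and log-magnitude terms)
    # until `discriminator_train_start_steps`; from then on an MSE adversarial
    # loss, weighted by `lambda_adv`, is added on top. See
    # `compute_per_example_generator_losses` below.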
    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per-example generator losses and return dict_metrics_losses.

        Note that every element of the losses MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader.
            outputs: outputs of the model.

        Returns:
            per_example_losses: per-example losses for each GPU, shape [B].
            dict_metrics_losses: dictionary of metric losses.
        """
        dict_metrics_losses = {}
        per_example_losses = 0.0

        audios = batch["audios"]
        y_hat = outputs

        # calculate multi-resolution stft loss
        sc_loss, mag_loss = calculate_2d_loss(
            audios, tf.squeeze(y_hat, -1), self.stft_loss
        )
        gen_loss = 0.5 * (sc_loss + mag_loss)

        if self.steps >= self.config["discriminator_train_start_steps"]:
            p_hat = self._discriminator(y_hat)
            p = self._discriminator(tf.expand_dims(audios, 2))
            adv_loss = 0.0
            adv_loss += calculate_3d_loss(
                tf.ones_like(p_hat), p_hat, loss_fn=self.mse_loss
            )
            gen_loss += self.config["lambda_adv"] * adv_loss

            # update dict_metrics_losses
            dict_metrics_losses.update({"adversarial_loss": adv_loss})

        dict_metrics_losses.update({"gen_loss": gen_loss})
        dict_metrics_losses.update({"spectral_convergence_loss": sc_loss})
        dict_metrics_losses.update({"log_magnitude_loss": mag_loss})

        per_example_losses = gen_loss
        return per_example_losses, dict_metrics_losses

    def compute_per_example_discriminator_losses(self, batch, gen_outputs):
        audios = batch["audios"]
        y_hat = gen_outputs

        y = tf.expand_dims(audios, 2)
        p = self._discriminator(y)
        p_hat = self._discriminator(y_hat)

        real_loss = 0.0
        fake_loss = 0.0

        # least-squares GAN objective: real -> 1, fake -> 0
        real_loss += calculate_3d_loss(tf.ones_like(p), p, loss_fn=self.mse_loss)
        fake_loss += calculate_3d_loss(
            tf.zeros_like(p_hat), p_hat, loss_fn=self.mse_loss
        )

        dis_loss = real_loss + fake_loss

        # calculate per_example_losses and dict_metrics_losses
        per_example_losses = dis_loss

        dict_metrics_losses = {
            "real_loss": real_loss,
            "fake_loss": fake_loss,
            "dis_loss": dis_loss,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        import matplotlib.pyplot as plt

        # generate
        y_batch_ = self.one_step_predict(batch)
        y_batch = batch["audios"]
        utt_ids = batch["utt_ids"]

        # convert to numpy.
        # under a distributed strategy we just take the sample from the first replica.
        try:
            y_batch_ = y_batch_.values[0].numpy()
            y_batch = y_batch.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            y_batch_ = y_batch_.numpy()
            y_batch = y_batch.numpy()
            utt_ids = utt_ids.numpy()

        # check directory
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 0):
            # convert to ndarray
            y, y_ = tf.reshape(y, [-1]).numpy(), tf.reshape(y_, [-1]).numpy()

            # plot figure and save it
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            plt.subplot(2, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(2, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # save as wavfile
            y = np.clip(y, -1, 1)
            y_ = np.clip(y_, -1, 1)
            sf.write(
                figname.replace(".png", "_ref.wav"),
                y,
                self.config["sampling_rate"],
                "PCM_16",
            )
            sf.write(
                figname.replace(".png", "_gen.wav"),
                y_,
                self.config["sampling_rate"],
                "PCM_16",
            )
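
# Example invocation (the paths and config file name below are illustrative,
# not shipped defaults):
#
#   python examples/parallel_wavegan/train_parallel_wavegan.py \
#       --train-dir ./dump/train/ \
#       --dev-dir ./dump/valid/ \
#       --outdir ./exp/parallel_wavegan/ \
#       --config ./examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml \
#       --use-norm 1
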
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train ParallelWaveGan (See detail in tensorflow_tts/examples/parallel_wavegan/train_parallel_wavegan.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data.",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data.",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm mels for training or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for discriminator or not.",
    )
    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision)

    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        raise ValueError("Please specify --dev-dir")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["parallel_wavegan_generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        audio_query = "*-wave.npy"
        mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy format is supported.")
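    # Note: the training batch size below is the global batch size, i.e. the
    # per-replica `batch_size` from the config multiplied by the number of
    # replicas in the strategy and by `gradient_accumulation_steps`. The
    # validation dataset omits the accumulation factor.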
    # define train/valid dataset
    train_dataset = AudioMelDataset(
        root_dir=args.train_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(config["batch_max_steps"], dtype=tf.int32),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = AudioMelDataset(
        root_dir=args.dev_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(
                config["batch_max_steps_valid"], dtype=tf.int32
            ),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = ParallelWaveganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=STRATEGY,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    with STRATEGY.scope():
        # define generator and discriminator
        generator = TFParallelWaveGANGenerator(
            ParallelWaveGANGeneratorConfig(
                **config["parallel_wavegan_generator_params"]
            ),
            name="parallel_wavegan_generator",
        )

        discriminator = TFParallelWaveGANDiscriminator(
            ParallelWaveGANDiscriminatorConfig(
                **config["parallel_wavegan_discriminator_params"]
            ),
            name="parallel_wavegan_discriminator",
        )

        # dummy input to build model.
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_hat = generator(fake_mels)
        discriminator(y_hat)

        generator.summary()
        discriminator.summary()

        # define optimizer
        generator_lr_fn = getattr(
            tf.keras.optimizers.schedules, config["generator_optimizer_params"]["lr_fn"]
        )(**config["generator_optimizer_params"]["lr_params"])
        discriminator_lr_fn = getattr(
            tf.keras.optimizers.schedules,
            config["discriminator_optimizer_params"]["lr_fn"],
        )(**config["discriminator_optimizer_params"]["lr_params"])

        gen_optimizer = RectifiedAdam(learning_rate=generator_lr_fn, amsgrad=False)
        dis_optimizer = RectifiedAdam(learning_rate=discriminator_lr_fn, amsgrad=False)

    trainer.compile(
        gen_model=generator,
        dis_model=discriminator,
        gen_optimizer=gen_optimizer,
        dis_optimizer=dis_optimizer,
    )

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()