# Path: blob/master/examples/multiband_melgan/train_multiband_melgan.py
# (1558 views)
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train Multi-Band MelGAN."""

import tensorflow as tf

# Memory growth must be configured before any op touches the GPUs,
# hence this runs ahead of the remaining imports.
physical_devices = tf.config.list_physical_devices("GPU")
for i in range(len(physical_devices)):
    tf.config.experimental.set_memory_growth(physical_devices[i], True)

import sys

sys.path.append(".")

import argparse
import logging
import os

import numpy as np
import soundfile as sf
import yaml
from tensorflow.keras.mixed_precision import experimental as mixed_precision

import tensorflow_tts
from examples.melgan.audio_mel_dataset import AudioMelDataset
from examples.melgan.train_melgan import MelganTrainer, collater
from tensorflow_tts.configs import (
    MultiBandMelGANDiscriminatorConfig,
    MultiBandMelGANGeneratorConfig,
)
from tensorflow_tts.losses import TFMultiResolutionSTFT
from tensorflow_tts.models import (
    TFPQMF,
    TFMelGANGenerator,
    TFMelGANMultiScaleDiscriminator,
)
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy


class MultiBandMelganTrainer(MelganTrainer):
    """Multi-Band MelGAN Trainer class based on MelganTrainer."""

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute.Strategy): Distribution strategy to train under.
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_generator_mixed_precision (bool): Use mixed precision for generator or not.
            is_discriminator_mixed_precision (bool): Use mixed precision for discriminator or not.

        """
        super(MultiBandMelganTrainer, self).__init__(
            config=config,
            steps=steps,
            epochs=epochs,
            strategy=strategy,
            is_generator_mixed_precision=is_generator_mixed_precision,
            is_discriminator_mixed_precision=is_discriminator_mixed_precision,
        )

        # Metric names aggregated per step and logged through tf.summary.
        self.list_metrics_name = [
            "adversarial_loss",
            "subband_spectral_convergence_loss",
            "subband_log_magnitude_loss",
            "fullband_spectral_convergence_loss",
            "fullband_log_magnitude_loss",
            "gen_loss",
            "real_loss",
            "fake_loss",
            "dis_loss",
        ]

        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer, pqmf):
        """Attach models/optimizers plus the STFT losses and PQMF filter bank.

        Args:
            gen_model: Multi-band MelGAN generator (outputs sub-band signals).
            dis_model: Multi-scale MelGAN discriminator.
            gen_optimizer: Optimizer for the generator.
            dis_optimizer: Optimizer for the discriminator.
            pqmf: TFPQMF module used to split/merge sub-band signals.
        """
        super().compile(gen_model, dis_model, gen_optimizer, dis_optimizer)
        # Multi-resolution STFT losses, one on the sub-band signals and one
        # on the synthesized full-band waveform.
        self.sub_band_stft_loss = TFMultiResolutionSTFT(
            **self.config["subband_stft_loss_params"]
        )
        self.full_band_stft_loss = TFMultiResolutionSTFT(
            **self.config["stft_loss_params"]
        )

        # define pqmf module
        self.pqmf = pqmf

    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per example generator losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        dict_metrics_losses = {}
        per_example_losses = 0.0

        audios = batch["audios"]
        y_mb_hat = outputs
        # Merge predicted sub-bands into a full-band waveform.
        y_hat = self.pqmf.synthesis(y_mb_hat)

        # Split ground-truth audio into sub-bands so both sides are comparable.
        y_mb = self.pqmf.analysis(tf.expand_dims(audios, -1))
        y_mb = tf.transpose(y_mb, (0, 2, 1))  # [B, subbands, T//subbands]
        y_mb = tf.reshape(y_mb, (-1, tf.shape(y_mb)[-1]))  # [B * subbands, T']

        y_mb_hat = tf.transpose(y_mb_hat, (0, 2, 1))  # [B, subbands, T//subbands]
        y_mb_hat = tf.reshape(
            y_mb_hat, (-1, tf.shape(y_mb_hat)[-1])
        )  # [B * subbands, T']

        # calculate sub/full band spectral_convergence and log mag loss.
        sub_sc_loss, sub_mag_loss = calculate_2d_loss(
            y_mb, y_mb_hat, self.sub_band_stft_loss
        )
        # Average the sub-band losses back to a per-example shape [B].
        sub_sc_loss = tf.reduce_mean(
            tf.reshape(sub_sc_loss, [-1, self.pqmf.subbands]), -1
        )
        sub_mag_loss = tf.reduce_mean(
            tf.reshape(sub_mag_loss, [-1, self.pqmf.subbands]), -1
        )
        full_sc_loss, full_mag_loss = calculate_2d_loss(
            audios, tf.squeeze(y_hat, -1), self.full_band_stft_loss
        )

        # Generator loss: equal-weight mix of sub-band and full-band STFT losses.
        gen_loss = 0.5 * (sub_sc_loss + sub_mag_loss) + 0.5 * (
            full_sc_loss + full_mag_loss
        )

        # Adversarial term only kicks in after the discriminator warm-up period.
        if self.steps >= self.config["discriminator_train_start_steps"]:
            p_hat = self._discriminator(y_hat)
            p = self._discriminator(tf.expand_dims(audios, 2))
            adv_loss = 0.0
            for i in range(len(p_hat)):
                adv_loss += calculate_3d_loss(
                    tf.ones_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
                )
            # Average over discriminator scales. (Was `adv_loss /= i + 1`,
            # which relied on the loop variable leaking out of the for-loop.)
            adv_loss /= len(p_hat)
            gen_loss += self.config["lambda_adv"] * adv_loss

            dict_metrics_losses.update({"adversarial_loss": adv_loss})

        dict_metrics_losses.update({"gen_loss": gen_loss})
        dict_metrics_losses.update({"subband_spectral_convergence_loss": sub_sc_loss})
        dict_metrics_losses.update({"subband_log_magnitude_loss": sub_mag_loss})
        dict_metrics_losses.update({"fullband_spectral_convergence_loss": full_sc_loss})
        dict_metrics_losses.update({"fullband_log_magnitude_loss": full_mag_loss})

        per_example_losses = gen_loss
        return per_example_losses, dict_metrics_losses

    def compute_per_example_discriminator_losses(self, batch, gen_outputs):
        """Compute per example discriminator losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            gen_outputs: sub-band outputs of the generator

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        y_mb_hat = gen_outputs
        # The parent discriminator loss operates on full-band audio, so merge
        # the sub-bands first and delegate.
        y_hat = self.pqmf.synthesis(y_mb_hat)
        (
            per_example_losses,
            dict_metrics_losses,
        ) = super().compute_per_example_discriminator_losses(batch, y_hat)
        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        import matplotlib.pyplot as plt

        y_mb_batch_ = self.one_step_predict(batch)  # [B, T // subbands, subbands]
        y_batch = batch["audios"]
        utt_ids = batch["utt_ids"]

        # Convert to numpy. Under a multi-replica strategy the results are
        # PerReplica objects; take the sample from the first replica. Plain
        # tensors have no `.values`, which raises AttributeError.
        try:
            y_mb_batch_ = y_mb_batch_.values[0].numpy()
            y_batch = y_batch.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except AttributeError:
            y_mb_batch_ = y_mb_batch_.numpy()
            y_batch = y_batch.numpy()
            utt_ids = utt_ids.numpy()

        y_batch_ = self.pqmf.synthesis(y_mb_batch_).numpy()  # [B, T, 1]

        # check directory
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        os.makedirs(dirname, exist_ok=True)

        for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 0):
            # convert to ndarray
            y, y_ = tf.reshape(y, [-1]).numpy(), tf.reshape(y_, [-1]).numpy()

            # plot figure and save it
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            plt.subplot(2, 1, 1)
            plt.plot(y)
            plt.title("groundtruth speech")
            plt.subplot(2, 1, 2)
            plt.plot(y_)
            plt.title(f"generated speech @ {self.steps} steps")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # save as wavefile, clipped to the valid PCM range
            y = np.clip(y, -1, 1)
            y_ = np.clip(y_, -1, 1)
            sf.write(
                figname.replace(".png", "_ref.wav"),
                y,
                self.config["sampling_rate"],
                "PCM_16",
            )
            sf.write(
                figname.replace(".png", "_gen.wav"),
                y_,
                self.config["sampling_rate"],
                "PCM_16",
            )


def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train MultiBand MelGAN (See detail in examples/multiband_melgan/train_multiband_melgan.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm mels for training or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for discriminator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="path of .h5 mb-melgan generator to load weights from",
    )
    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision)

    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # NOTE: message previously referenced a non-existent --valid-dir flag.
        raise ValueError("Please specify --dev-dir")

    # load and save config
    # NOTE(review): yaml.Loader can instantiate arbitrary Python objects;
    # acceptable here only because the config file is a trusted local file.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # Minimum usable mel length plus the generator's auxiliary context.
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"
        ] + 2 * config["multiband_melgan_generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        audio_query = "*-wave.npy"
        mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy are supported.")

    # define train/valid dataset
    train_dataset = AudioMelDataset(
        root_dir=args.train_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(config["batch_max_steps"], dtype=tf.int32),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        # Global batch covers all replicas and the accumulation steps.
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = AudioMelDataset(
        root_dir=args.dev_dir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        map_fn=lambda items: collater(
            items,
            batch_max_steps=tf.constant(
                config["batch_max_steps_valid"], dtype=tf.int32
            ),
            hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
        ),
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = MultiBandMelganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=STRATEGY,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    with STRATEGY.scope():
        # define generator and discriminator
        generator = TFMelGANGenerator(
            MultiBandMelGANGeneratorConfig(
                **config["multiband_melgan_generator_params"]
            ),
            name="multi_band_melgan_generator",
        )

        discriminator = TFMelGANMultiScaleDiscriminator(
            MultiBandMelGANDiscriminatorConfig(
                **config["multiband_melgan_discriminator_params"]
            ),
            name="multi_band_melgan_discriminator",
        )

        pqmf = TFPQMF(
            MultiBandMelGANGeneratorConfig(
                **config["multiband_melgan_generator_params"]
            ),
            dtype=tf.float32,
            name="pqmf",
        )

        # dummy input to build model.
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_mb_hat = generator(fake_mels)
        y_hat = pqmf.synthesis(y_mb_hat)
        discriminator(y_hat)

        # NOTE: `--pretrained` with no value yields None (nargs="?"); guard
        # with truthiness instead of len() to avoid a TypeError.
        if args.pretrained:
            generator.load_weights(args.pretrained)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        generator.summary()
        discriminator.summary()

        # define optimizer
        generator_lr_fn = getattr(
            tf.keras.optimizers.schedules, config["generator_optimizer_params"]["lr_fn"]
        )(**config["generator_optimizer_params"]["lr_params"])
        discriminator_lr_fn = getattr(
            tf.keras.optimizers.schedules,
            config["discriminator_optimizer_params"]["lr_fn"],
        )(**config["discriminator_optimizer_params"]["lr_params"])

        gen_optimizer = tf.keras.optimizers.Adam(
            learning_rate=generator_lr_fn,
            amsgrad=config["generator_optimizer_params"]["amsgrad"],
        )
        dis_optimizer = tf.keras.optimizers.Adam(
            learning_rate=discriminator_lr_fn,
            amsgrad=config["discriminator_optimizer_params"]["amsgrad"],
        )

    trainer.compile(
        gen_model=generator,
        dis_model=discriminator,
        gen_optimizer=gen_optimizer,
        dis_optimizer=dis_optimizer,
        pqmf=pqmf,
    )

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        # Persist progress on Ctrl-C so training can be resumed.
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()