# Path: examples/fastspeech2/train_fastspeech2.py
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train FastSpeech2."""

import tensorflow as tf

# Enable memory growth before any other TF work so the process does not
# grab all GPU memory up front.
physical_devices = tf.config.list_physical_devices("GPU")
for i in range(len(physical_devices)):
    tf.config.experimental.set_memory_growth(physical_devices[i], True)

import sys

sys.path.append(".")

import argparse
import logging
import os

import numpy as np
import yaml
from tqdm import tqdm

import tensorflow_tts
from examples.fastspeech2.fastspeech2_dataset import CharactorDurationF0EnergyMelDataset
from examples.fastspeech.train_fastspeech import FastSpeechTrainer
from tensorflow_tts.configs import FastSpeech2Config
from tensorflow_tts.models import TFFastSpeech2
from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
from tensorflow_tts.trainers import Seq2SeqBasedTrainer
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy


class FastSpeech2Trainer(Seq2SeqBasedTrainer):
    """FastSpeech2 Trainer class based on FastSpeechTrainer."""

    def __init__(
        self, config, strategy, steps=0, epochs=0, is_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute.Strategy): Distribution strategy to train under.
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_mixed_precision (bool): Use mixed precision or not.
        """
        super(FastSpeech2Trainer, self).__init__(
            steps=steps,
            epochs=epochs,
            config=config,
            strategy=strategy,
            is_mixed_precision=is_mixed_precision,
        )
        # define metrics to aggregate data; tf.summary logs them.
        self.list_metrics_name = [
            "duration_loss",
            "f0_loss",
            "energy_loss",
            "mel_loss_before",
            "mel_loss_after",
        ]
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

    def compile(self, model, optimizer):
        """Attach model/optimizer and build per-example loss objects."""
        super().compile(model, optimizer)
        # Reduction.NONE keeps one loss value per example so that the
        # distribution strategy can aggregate across replicas correctly.
        self.mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )

    def compute_per_example_losses(self, batch, outputs):
        """Compute per example losses and return dict_metrics_losses.

        Note that all elements of the loss MUST have a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        mel_before, mel_after, duration_outputs, f0_outputs, energy_outputs = outputs

        # Durations are regressed in log space; +1 avoids log(0) for
        # zero-length ground-truth durations.
        log_duration = tf.math.log(
            tf.cast(tf.math.add(batch["duration_gts"], 1), tf.float32)
        )
        duration_loss = calculate_2d_loss(log_duration, duration_outputs, self.mse)
        f0_loss = calculate_2d_loss(batch["f0_gts"], f0_outputs, self.mse)
        energy_loss = calculate_2d_loss(batch["energy_gts"], energy_outputs, self.mse)
        mel_loss_before = calculate_3d_loss(batch["mel_gts"], mel_before, self.mae)
        mel_loss_after = calculate_3d_loss(batch["mel_gts"], mel_after, self.mae)

        per_example_losses = (
            duration_loss + f0_loss + energy_loss + mel_loss_before + mel_loss_after
        )

        dict_metrics_losses = {
            "duration_loss": duration_loss,
            "f0_loss": f0_loss,
            "energy_loss": energy_loss,
            "mel_loss_before": mel_loss_before,
            "mel_loss_after": mel_loss_after,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result (mel-spectrogram figures)."""
        import matplotlib.pyplot as plt

        # predict with tf.function.
        outputs = self.one_step_predict(batch)

        mels_before, mels_after, *_ = outputs
        mel_gts = batch["mel_gts"]
        utt_ids = batch["utt_ids"]

        # convert to tensor.
        # here we just take a sample at first replica.
        try:
            mels_before = mels_before.values[0].numpy()
            mels_after = mels_after.values[0].numpy()
            mel_gts = mel_gts.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            # not PerReplica values -> single-replica training.
            mels_before = mels_before.numpy()
            mels_after = mels_after.numpy()
            mel_gts = mel_gts.numpy()
            utt_ids = utt_ids.numpy()

        # check directory
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        for idx, (mel_gt, mel_before, mel_after) in enumerate(
            zip(mel_gts, mels_before, mels_after), 0
        ):
            # NOTE(review): assumes 80 mel bins — keep in sync with config.
            mel_gt = tf.reshape(mel_gt, (-1, 80)).numpy()  # [length, 80]
            mel_before = tf.reshape(mel_before, (-1, 80)).numpy()  # [length, 80]
            mel_after = tf.reshape(mel_after, (-1, 80)).numpy()  # [length, 80]

            # plot figure and save it
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            fig = plt.figure(figsize=(10, 8))
            ax1 = fig.add_subplot(311)
            ax2 = fig.add_subplot(312)
            ax3 = fig.add_subplot(313)
            im = ax1.imshow(np.rot90(mel_gt), aspect="auto", interpolation="none")
            ax1.set_title("Target Mel-Spectrogram")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
            ax2.set_title("Predicted Mel-before-Spectrogram")
            im = ax2.imshow(np.rot90(mel_before), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
            ax3.set_title("Predicted Mel-after-Spectrogram")
            im = ax3.imshow(np.rot90(mel_after), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax3)
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()


def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train FastSpeech (See detail in tensorflow_tts/bin/train-fastspeech.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="usr norm-mels for train or raw."
    )
    parser.add_argument(
        "--f0-stat",
        default="./dump/stats_f0.npy",
        type=str,
        required=True,
        help="f0-stat path.",
    )
    parser.add_argument(
        "--energy-stat",
        default="./dump/stats_energy.npy",
        type=str,
        required=True,
        help="energy-stat path.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers",
    )

    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.mixed_precision = bool(args.mixed_precision)
    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # BUGFIX: the message used to say --valid-dir, which is not a real flag.
        raise ValueError("Please specify --dev-dir")

    # load and save config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; the
    # config file is assumed to be trusted local input. Consider SafeLoader.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["mel_length_threshold"]
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        charactor_query = "*-ids.npy"
        # idiom fix: `args.use_norm is False` identity test replaced with
        # plain truthiness (use_norm was converted to bool above).
        mel_query = "*-raw-feats.npy" if not args.use_norm else "*-norm-feats.npy"
        duration_query = "*-durations.npy"
        f0_query = "*-raw-f0.npy"
        energy_query = "*-raw-energy.npy"
    else:
        raise ValueError("Only npy are supported.")

    # define train/valid dataset
    train_dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.train_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        f0_query=f0_query,
        energy_query=energy_query,
        f0_stat=args.f0_stat,
        energy_stat=args.energy_stat,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        # global batch covers all replicas and accumulation steps.
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.dev_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        f0_query=f0_query,
        energy_query=energy_query,
        f0_stat=args.f0_stat,
        energy_stat=args.energy_stat,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = FastSpeech2Trainer(
        config=config,
        strategy=STRATEGY,
        steps=0,
        epochs=0,
        is_mixed_precision=args.mixed_precision,
    )

    with STRATEGY.scope():
        # define model
        fastspeech = TFFastSpeech2(
            config=FastSpeech2Config(**config["fastspeech2_params"])
        )
        fastspeech._build()
        fastspeech.summary()
        # BUGFIX: `--pretrained` uses nargs="?" with no const, so a bare
        # `--pretrained` yields None and the old `len(args.pretrained) > 1`
        # raised TypeError (and skipped 1-char paths). Truthiness handles
        # None and "" correctly.
        if args.pretrained:
            fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        # AdamW for fastspeech
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_steps=config["optimizer_params"]["decay_steps"],
            end_learning_rate=config["optimizer_params"]["end_learning_rate"],
        )

        learning_rate_fn = WarmUp(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_schedule_fn=learning_rate_fn,
            warmup_steps=int(
                config["train_max_steps"]
                * config["optimizer_params"]["warmup_proportion"]
            ),
        )

        optimizer = AdamWeightDecay(
            learning_rate=learning_rate_fn,
            weight_decay_rate=config["optimizer_params"]["weight_decay"],
            beta_1=0.9,
            beta_2=0.98,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )

        # touch iterations so the slot variable is created inside the scope.
        _ = optimizer.iterations

    # compile trainer
    trainer.compile(model=fastspeech, optimizer=optimizer)

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()