# Source: examples/fastspeech/train_fastspeech.py
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train FastSpeech."""

import tensorflow as tf

# Enable memory growth on every visible GPU before any other TF op runs,
# so TensorFlow does not reserve the whole GPU memory up front.
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

import argparse
import logging
import os
import sys

sys.path.append(".")

import numpy as np
import yaml

import tensorflow_tts
import tensorflow_tts.configs.fastspeech as FASTSPEECH_CONFIG
from examples.fastspeech.fastspeech_dataset import CharactorDurationMelDataset
from tensorflow_tts.models import TFFastSpeech
from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
from tensorflow_tts.trainers import Seq2SeqBasedTrainer
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy


class FastSpeechTrainer(Seq2SeqBasedTrainer):
    """FastSpeech Trainer class based on Seq2SeqBasedTrainer."""

    def __init__(
        self, config, strategy, steps=0, epochs=0, is_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute.Strategy): Distribution strategy for training.
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_mixed_precision (bool): Use mixed precision or not.

        """
        super(FastSpeechTrainer, self).__init__(
            steps=steps,
            epochs=epochs,
            config=config,
            strategy=strategy,
            is_mixed_precision=is_mixed_precision,
        )
        # define metrics to aggregate data and use tf.summary to log them
        self.list_metrics_name = ["duration_loss", "mel_loss_before", "mel_loss_after"]
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

        self.config = config

    def compile(self, model, optimizer):
        """Attach model and optimizer, and build the loss functions.

        Reduction.NONE keeps per-example losses so they can be reduced
        correctly across replicas by the distribution strategy.
        """
        super().compile(model, optimizer)
        self.mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )

    def compute_per_example_losses(self, batch, outputs):
        """Compute per example losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        mel_before, mel_after, duration_outputs = outputs

        # Train against log(1 + duration) for numerical stability of the
        # duration predictor (durations are non-negative integers).
        log_duration = tf.math.log(
            tf.cast(tf.math.add(batch["duration_gts"], 1), tf.float32)
        )
        # FIX: use calculate_2d_loss (previously imported but unused) so the
        # ground-truth and predicted duration sequences are length-aligned
        # before the loss is taken, mirroring the 3D mel losses below.
        duration_loss = calculate_2d_loss(log_duration, duration_outputs, self.mse)
        mel_loss_before = calculate_3d_loss(batch["mel_gts"], mel_before, self.mae)
        mel_loss_after = calculate_3d_loss(batch["mel_gts"], mel_after, self.mae)

        per_example_losses = duration_loss + mel_loss_before + mel_loss_after

        dict_metrics_losses = {
            "duration_loss": duration_loss,
            "mel_loss_before": mel_loss_before,
            "mel_loss_after": mel_loss_after,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        import matplotlib.pyplot as plt

        # predict with tf.function.
        outputs = self.one_step_predict(batch)

        mels_before, mels_after, *_ = outputs
        mel_gts = batch["mel_gts"]
        utt_ids = batch["utt_ids"]

        # convert to tensor.
        # here we just take a sample at first replica.
        try:
            # Multi-replica run: outputs are PerReplica values.
            mels_before = mels_before.values[0].numpy()
            mels_after = mels_after.values[0].numpy()
            mel_gts = mel_gts.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except AttributeError:
            # Single-replica run: plain tensors have no `.values`.
            mels_before = mels_before.numpy()
            mels_after = mels_after.numpy()
            mel_gts = mel_gts.numpy()
            utt_ids = utt_ids.numpy()

        # check directory
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        os.makedirs(dirname, exist_ok=True)

        for idx, (mel_gt, mel_before, mel_after) in enumerate(
            zip(mel_gts, mels_before, mels_after), 0
        ):
            mel_gt = tf.reshape(mel_gt, (-1, 80)).numpy()  # [length, 80]
            mel_before = tf.reshape(mel_before, (-1, 80)).numpy()  # [length, 80]
            mel_after = tf.reshape(mel_after, (-1, 80)).numpy()  # [length, 80]

            # plot figure and save it
            utt_id = utt_ids[idx].decode("utf-8")
            figname = os.path.join(dirname, f"{utt_id}.png")
            fig = plt.figure(figsize=(10, 8))
            ax1 = fig.add_subplot(311)
            ax2 = fig.add_subplot(312)
            ax3 = fig.add_subplot(313)
            im = ax1.imshow(np.rot90(mel_gt), aspect="auto", interpolation="none")
            ax1.set_title("Target Mel-Spectrogram")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
            ax2.set_title("Predicted Mel-before-Spectrogram")
            im = ax2.imshow(np.rot90(mel_before), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
            ax3.set_title("Predicted Mel-after-Spectrogram")
            im = ax3.imshow(np.rot90(mel_after), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax3)
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()


def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train FastSpeech (See detail in tensorflow_tts/bin/train-fastspeech.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="usr norm-mels for train or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="pretrained checkpoint file to load weights from. Auto-skips non-matching layers",
    )
    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.mixed_precision = bool(args.mixed_precision)
    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # FIX: message previously said "--valid-dir", which is not a flag
        # this parser defines.
        raise ValueError("Please specify --dev-dir")

    # load and save config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; the
    # config file is assumed to be trusted (use SafeLoader otherwise).
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["mel_length_threshold"]
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        charactor_query = "*-ids.npy"
        mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
        duration_query = "*-durations.npy"
        charactor_load_fn = np.load
        mel_load_fn = np.load
        duration_load_fn = np.load
    else:
        raise ValueError("Only npy are supported.")

    # define train/valid dataset
    train_dataset = CharactorDurationMelDataset(
        root_dir=args.train_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        charactor_load_fn=charactor_load_fn,
        mel_load_fn=mel_load_fn,
        duration_load_fn=duration_load_fn,
        mel_length_threshold=mel_length_threshold,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        # Global batch: per-replica batch size scaled by replica count and
        # gradient accumulation steps.
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = CharactorDurationMelDataset(
        root_dir=args.dev_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        charactor_load_fn=charactor_load_fn,
        mel_load_fn=mel_load_fn,
        duration_load_fn=duration_load_fn,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = FastSpeechTrainer(
        config=config,
        strategy=STRATEGY,
        steps=0,
        epochs=0,
        is_mixed_precision=args.mixed_precision,
    )

    with STRATEGY.scope():
        # define model
        fastspeech = TFFastSpeech(
            config=FASTSPEECH_CONFIG.FastSpeechConfig(**config["fastspeech_params"])
        )
        fastspeech._build()
        fastspeech.summary()

        # FIX: `--pretrained` with nargs="?" yields None when passed bare,
        # so the old `len(args.pretrained) > 1` check would raise TypeError.
        if args.pretrained:
            fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        # AdamW for fastspeech
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_steps=config["optimizer_params"]["decay_steps"],
            end_learning_rate=config["optimizer_params"]["end_learning_rate"],
        )

        learning_rate_fn = WarmUp(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_schedule_fn=learning_rate_fn,
            warmup_steps=int(
                config["train_max_steps"]
                * config["optimizer_params"]["warmup_proportion"]
            ),
        )

        optimizer = AdamWeightDecay(
            learning_rate=learning_rate_fn,
            weight_decay_rate=config["optimizer_params"]["weight_decay"],
            beta_1=0.9,
            beta_2=0.98,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )

        # Force creation of the optimizer's iterations variable inside the
        # strategy scope.
        _ = optimizer.iterations

    # compile trainer
    trainer.compile(model=fastspeech, optimizer=optimizer)

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        # Save progress on Ctrl-C so training can be resumed with --resume.
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()