Path: blob/master/examples/fastspeech2_libritts/train_fastspeech2.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train FastSpeech2."""

import tensorflow as tf

# Enable memory growth on every visible GPU before any op is created,
# so TF does not grab all device memory up front.
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

import sys

sys.path.append(".")

import argparse
import json
import logging
import os

import numpy as np
import yaml

import tensorflow_tts
from examples.fastspeech2_libritts.fastspeech2_dataset import (
    CharactorDurationF0EnergyMelDataset,
)
from tensorflow_tts.configs import FastSpeech2Config
from tensorflow_tts.models import TFFastSpeech2
from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
from tensorflow_tts.trainers import Seq2SeqBasedTrainer
from tensorflow_tts.utils import (
    calculate_2d_loss,
    calculate_3d_loss,
    return_strategy,
    TFGriffinLim,
)


class FastSpeech2Trainer(Seq2SeqBasedTrainer):
    """FastSpeech2 Trainer class based on FastSpeechTrainer."""

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_mixed_precision=False,
        stats_path: str = "",
        dataset_config: str = "",
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy: tf.distribute strategy used for training.
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_mixed_precision (bool): Use mixed precision or not.
            stats_path (str): Path to mel stats file, required when
                ``config["use_griffin"]`` is true (audio previews).
            dataset_config (str): Path to the preprocessing yaml config,
                read only when Griffin-Lim previews are enabled.
        """
        super(FastSpeech2Trainer, self).__init__(
            steps=steps,
            epochs=epochs,
            config=config,
            strategy=strategy,
            is_mixed_precision=is_mixed_precision,
        )
        # define metrics to aggregate data; tf.summary logs them.
        self.list_metrics_name = [
            "duration_loss",
            "f0_loss",
            "energy_loss",
            "mel_loss_before",
            "mel_loss_after",
        ]
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()
        self.use_griffin = config.get("use_griffin", False)
        self.griffin_lim_tf = None
        if self.use_griffin:
            logging.info(
                f"Load griff stats from {stats_path} and config from {dataset_config}"
            )
            # FIX: open the yaml config with a context manager instead of
            # leaking the file handle (was `yaml.load(open(...))`).
            with open(dataset_config) as f:
                self.griff_conf = yaml.load(f, Loader=yaml.Loader)
            self.prepare_grim(stats_path, self.griff_conf)

    def prepare_grim(self, stats_path, config):
        """Build the TFGriffinLim vocoder used for intermediate audio previews.

        Raises:
            KeyError: If ``stats_path`` is empty.
        """
        if not stats_path:
            raise KeyError("stats path need to exist")
        self.griffin_lim_tf = TFGriffinLim(stats_path, config)

    def compile(self, model, optimizer):
        """Attach model/optimizer and create per-example loss objects.

        Reduction is NONE so losses keep a [batch] shape for per-replica
        aggregation by the distribution strategy.
        """
        super().compile(model, optimizer)
        self.mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )

    def compute_per_example_losses(self, batch, outputs):
        """Compute per example losses and return dict_metrics_losses
        Note that all element of the loss MUST has a shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input return from dataloader
            outputs: outputs of the model

        Returns:
            per_example_losses: per example losses for each GPU, shape [B]
            dict_metrics_losses: dictionary loss.
        """
        mel_before, mel_after, duration_outputs, f0_outputs, energy_outputs = outputs

        # Durations are regressed in log domain; +1 guards against log(0).
        log_duration = tf.math.log(
            tf.cast(tf.math.add(batch["duration_gts"], 1), tf.float32)
        )
        duration_loss = calculate_2d_loss(log_duration, duration_outputs, self.mse)
        f0_loss = calculate_2d_loss(batch["f0_gts"], f0_outputs, self.mse)
        energy_loss = calculate_2d_loss(batch["energy_gts"], energy_outputs, self.mse)
        mel_loss_before = calculate_3d_loss(batch["mel_gts"], mel_before, self.mae)
        mel_loss_after = calculate_3d_loss(batch["mel_gts"], mel_after, self.mae)

        # Unweighted sum of all sub-losses, shape [B].
        per_example_losses = (
            duration_loss + f0_loss + energy_loss + mel_loss_before + mel_loss_after
        )

        dict_metrics_losses = {
            "duration_loss": duration_loss,
            "f0_loss": f0_loss,
            "energy_loss": energy_loss,
            "mel_loss_before": mel_loss_before,
            "mel_loss_after": mel_loss_after,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        import matplotlib.pyplot as plt

        # predict with tf.function.
        outputs = self.one_step_predict(batch)

        mels_before, mels_after, *_ = outputs
        mel_gts = batch["mel_gts"]
        utt_ids = batch["utt_ids"]

        # convert to tensor.
        # here we just take a sample at first replica
        # (PerReplica values under a multi-device strategy, plain tensors otherwise).
        try:
            mels_before = mels_before.values[0].numpy()
            mels_after = mels_after.values[0].numpy()
            mel_gts = mel_gts.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            mels_before = mels_before.numpy()
            mels_after = mels_after.numpy()
            mel_gts = mel_gts.numpy()
            utt_ids = utt_ids.numpy()

        # check directory
        if self.use_griffin:
            griff_dir_name = os.path.join(
                self.config["outdir"], f"predictions/{self.steps}_wav"
            )
            os.makedirs(griff_dir_name, exist_ok=True)

        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        os.makedirs(dirname, exist_ok=True)

        for idx, (mel_gt, mel_before, mel_after) in enumerate(
            zip(mel_gts, mels_before, mels_after), 0
        ):
            utt_id = utt_ids[idx]

            if self.use_griffin:
                # Synthesize audio previews with Griffin-Lim for gt/before/after.
                grif_before = self.griffin_lim_tf(
                    tf.reshape(mel_before, [-1, 80])[tf.newaxis, :], n_iter=32
                )
                grif_after = self.griffin_lim_tf(
                    tf.reshape(mel_after, [-1, 80])[tf.newaxis, :], n_iter=32
                )
                grif_gt = self.griffin_lim_tf(
                    tf.reshape(mel_gt, [-1, 80])[tf.newaxis, :], n_iter=32
                )
                self.griffin_lim_tf.save_wav(
                    grif_before, griff_dir_name, f"{utt_id}_before"
                )
                self.griffin_lim_tf.save_wav(
                    grif_after, griff_dir_name, f"{utt_id}_after"
                )
                self.griffin_lim_tf.save_wav(grif_gt, griff_dir_name, f"{utt_id}_gt")

            mel_gt = tf.reshape(mel_gt, (-1, 80)).numpy()  # [length, 80]
            mel_before = tf.reshape(mel_before, (-1, 80)).numpy()  # [length, 80]
            mel_after = tf.reshape(mel_after, (-1, 80)).numpy()  # [length, 80]

            # plot figure and save it
            figname = os.path.join(dirname, f"{utt_id}.png")
            fig = plt.figure(figsize=(10, 8))
            ax1 = fig.add_subplot(311)
            ax2 = fig.add_subplot(312)
            ax3 = fig.add_subplot(313)
            im = ax1.imshow(np.rot90(mel_gt), aspect="auto", interpolation="none")
            ax1.set_title("Target Mel-Spectrogram")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
            ax2.set_title("Predicted Mel-before-Spectrogram")
            im = ax2.imshow(np.rot90(mel_before), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
            ax3.set_title("Predicted Mel-after-Spectrogram")
            im = ax3.imshow(np.rot90(mel_after), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax3)
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()


def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train FastSpeech (See detail in tensorflow_tts/bin/train-fastspeech.py)"
    )
    parser.add_argument(
        "--train-dir",
        default="dump/train",
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default="dump/valid",
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="usr norm-mels for train or raw."
    )
    parser.add_argument(
        "--f0-stat", default="./dump/stats_f0.npy", type=str, help="f0-stat path.",
    )
    parser.add_argument(
        "--energy-stat",
        default="./dump/stats_energy.npy",
        type=str,
        help="energy-stat path.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--mixed_precision",
        default=1,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--dataset_config", default="preprocess/libritts_preprocess.yaml", type=str,
    )
    parser.add_argument(
        "--dataset_stats", default="dump/stats.npy", type=str,
    )
    parser.add_argument(
        "--dataset_mapping", default="dump/libritts_mapper.npy", type=str,
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers",
    )
    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.mixed_precision = bool(args.mixed_precision)
    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # FIX: the message previously named a nonexistent flag "--valid-dir".
        raise ValueError("Please specify --dev-dir")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["mel_length_threshold"]
    else:
        mel_length_threshold = None

    if config["format"] == "npy":
        charactor_query = "*-ids.npy"
        mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
        duration_query = "*-durations.npy"
        f0_query = "*-raw-f0.npy"
        energy_query = "*-raw-energy.npy"
    else:
        raise ValueError("Only npy are supported.")

    # load speakers map from dataset map
    with open(args.dataset_mapping) as f:
        dataset_mapping = json.load(f)
        speakers_map = dataset_mapping["speakers_map"]

    # Check n_speakers matches number of speakers in speakers_map
    n_speakers = config["fastspeech2_params"]["n_speakers"]
    # FIX: the assert message was an f-string with no placeholders;
    # include the actual counts to make the failure actionable.
    assert n_speakers == len(speakers_map), (
        f"Number of speakers in dataset ({len(speakers_map)}) does not match "
        f"n_speakers in config ({n_speakers})"
    )

    # define train/valid dataset
    train_dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.train_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        f0_query=f0_query,
        energy_query=energy_query,
        f0_stat=args.f0_stat,
        energy_stat=args.energy_stat,
        mel_length_threshold=mel_length_threshold,
        speakers_map=speakers_map,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        # Global batch covers all replicas times gradient accumulation.
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.dev_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        duration_query=duration_query,
        f0_query=f0_query,
        energy_query=energy_query,
        f0_stat=args.f0_stat,
        energy_stat=args.energy_stat,
        mel_length_threshold=mel_length_threshold,
        speakers_map=speakers_map,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = FastSpeech2Trainer(
        config=config,
        strategy=STRATEGY,
        steps=0,
        epochs=0,
        is_mixed_precision=args.mixed_precision,
        stats_path=args.dataset_stats,
        dataset_config=args.dataset_config,
    )

    with STRATEGY.scope():
        # define model
        fastspeech = TFFastSpeech2(
            config=FastSpeech2Config(**config["fastspeech2_params"])
        )
        fastspeech._build()
        fastspeech.summary()

        # FIX: was `len(args.pretrained) > 1`, which raised TypeError when
        # `--pretrained` was passed with no value (nargs="?" yields None) and
        # skipped valid single-character paths.
        if args.pretrained:
            fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        # AdamW for fastspeech
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_steps=config["optimizer_params"]["decay_steps"],
            end_learning_rate=config["optimizer_params"]["end_learning_rate"],
        )

        learning_rate_fn = WarmUp(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_schedule_fn=learning_rate_fn,
            warmup_steps=int(
                config["train_max_steps"]
                * config["optimizer_params"]["warmup_proportion"]
            ),
        )

        optimizer = AdamWeightDecay(
            learning_rate=learning_rate_fn,
            weight_decay_rate=config["optimizer_params"]["weight_decay"],
            beta_1=0.9,
            beta_2=0.98,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )

        # Force creation of the iterations variable inside the strategy scope.
        _ = optimizer.iterations

    # compile trainer
    trainer.compile(model=fastspeech, optimizer=optimizer)

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        # Save a checkpoint on manual interruption so training can resume.
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()