Path: blob/master/examples/tacotron2/train_tacotron2.py
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train Tacotron2."""
import tensorflow as tf

# Enable memory growth before any other TF op so TF does not grab all GPU
# memory up front (must happen before the GPUs are initialized).
physical_devices = tf.config.list_physical_devices("GPU")
for i in range(len(physical_devices)):
    tf.config.experimental.set_memory_growth(physical_devices[i], True)

import sys

sys.path.append(".")

import argparse
import logging
import os

import numpy as np
import yaml
from tqdm import tqdm

import tensorflow_tts
from examples.tacotron2.tacotron_dataset import CharactorMelDataset
from tensorflow_tts.configs.tacotron2 import Tacotron2Config
from tensorflow_tts.models import TFTacotron2
from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
from tensorflow_tts.trainers import Seq2SeqBasedTrainer
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy


class Tacotron2Trainer(Seq2SeqBasedTrainer):
    """Tacotron2 Trainer class based on Seq2SeqBasedTrainer."""

    def __init__(
        self,
        config,
        strategy,
        steps=0,
        epochs=0,
        is_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            config (dict): Config dict loaded from yaml format configuration file.
            strategy: tf.distribute strategy the training runs under.
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            is_mixed_precision (bool): Use mixed precision or not.

        """
        super(Tacotron2Trainer, self).__init__(
            steps=steps,
            epochs=epochs,
            config=config,
            strategy=strategy,
            is_mixed_precision=is_mixed_precision,
        )
        # Define metrics to aggregate data; tf.summary logs them. The keys
        # returned by compute_per_example_losses MUST match this list.
        self.list_metrics_name = [
            "stop_token_loss",
            "mel_loss_before",
            "mel_loss_after",
            "guided_attention_loss",
        ]
        self.init_train_eval_metrics(self.list_metrics_name)
        self.reset_states_train()
        self.reset_states_eval()

        self.config = config

    def compile(self, model, optimizer):
        """Attach model/optimizer and build the per-example loss functions.

        All losses use Reduction.NONE so that per-example losses can be
        aggregated manually across replicas by the base trainer.
        """
        super().compile(model, optimizer)
        self.binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        self.mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )
        self.mae = tf.keras.losses.MeanAbsoluteError(
            reduction=tf.keras.losses.Reduction.NONE
        )

    def _train_step(self, batch):
        """Run one training step.

        We re-define _train_step here because applying an input_signature
        made training slower in the author's experiments (input_signature is
        applied by the base trainer by default). Instead, the step functions
        are wrapped lazily with experimental_relax_shapes=True to limit
        retracing on variable-shape batches.
        """
        if self._already_apply_input_signature is False:
            self.one_step_forward = tf.function(
                self._one_step_forward, experimental_relax_shapes=True
            )
            self.one_step_evaluate = tf.function(
                self._one_step_evaluate, experimental_relax_shapes=True
            )
            self.one_step_predict = tf.function(
                self._one_step_predict, experimental_relax_shapes=True
            )
            self._already_apply_input_signature = True

        # run one_step_forward
        self.one_step_forward(batch)

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _one_step_evaluate_per_replica(self, batch):
        """One evaluation step on a single replica (GPU).

        Tacotron-2 uses teacher forcing during training AND evaluation,
        so `training=True` is passed for the inference step too.
        """
        outputs = self._model(**batch, training=True)
        _, dict_metrics_losses = self.compute_per_example_losses(batch, outputs)

        self.update_eval_metrics(dict_metrics_losses)

    def _one_step_predict_per_replica(self, batch):
        """One prediction step on a single replica (GPU).

        Tacotron-2 uses teacher forcing during training AND evaluation,
        so `training=True` is passed for the inference step too.
        """
        outputs = self._model(**batch, training=True)
        return outputs

    def compute_per_example_losses(self, batch, outputs):
        """Compute per example losses and return dict_metrics_losses.

        Note that every element of the loss MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader.
            outputs: outputs of the model.

        Returns:
            per_example_losses: per example losses for each GPU, shape [B].
            dict_metrics_losses: dictionary of named losses.
        """
        (
            decoder_output,
            post_mel_outputs,
            stop_token_predictions,
            alignment_historys,
        ) = outputs

        mel_loss_before = calculate_3d_loss(
            batch["mel_gts"], decoder_output, loss_fn=self.mae
        )
        mel_loss_after = calculate_3d_loss(
            batch["mel_gts"], post_mel_outputs, loss_fn=self.mae
        )

        # Calculate stop loss: ground-truth stop labels are 0 before the
        # true mel length of each utterance and 1 from then on.
        max_mel_length = (
            tf.reduce_max(batch["mel_lengths"])
            if self.config["use_fixed_shapes"] is False
            else [self.config["max_mel_length"]]
        )
        stop_gts = tf.expand_dims(
            tf.range(tf.reduce_max(max_mel_length), dtype=tf.int32), 0
        )  # [1, max_len]
        stop_gts = tf.tile(
            stop_gts, [tf.shape(batch["mel_lengths"])[0], 1]
        )  # [B, max_len]
        stop_gts = tf.cast(
            tf.math.greater_equal(stop_gts, tf.expand_dims(batch["mel_lengths"], 1)),
            tf.float32,
        )

        stop_token_loss = calculate_2d_loss(
            stop_gts, stop_token_predictions, loss_fn=self.binary_crossentropy
        )

        # Calculate guided attention loss: positions where the guide matrix
        # is -1.0 are padding and masked out of both the sum and the mean.
        attention_masks = tf.cast(
            tf.math.not_equal(batch["g_attentions"], -1.0), tf.float32
        )
        loss_att = tf.reduce_sum(
            tf.abs(alignment_historys * batch["g_attentions"]) * attention_masks,
            axis=[1, 2],
        )
        loss_att /= tf.reduce_sum(attention_masks, axis=[1, 2])

        per_example_losses = (
            stop_token_loss + mel_loss_before + mel_loss_after + loss_att
        )

        dict_metrics_losses = {
            "stop_token_loss": stop_token_loss,
            "mel_loss_before": mel_loss_before,
            "mel_loss_after": mel_loss_after,
            "guided_attention_loss": loss_att,
        }

        return per_example_losses, dict_metrics_losses

    def generate_and_save_intermediate_result(self, batch):
        """Generate spectrogram/alignment figures for one batch and save them."""
        # Imported lazily so training does not require a display backend
        # unless figures are actually generated.
        import matplotlib.pyplot as plt

        # predict with tf.function for speed.
        outputs = self.one_step_predict(batch)
        (
            decoder_output,
            mel_outputs,
            stop_token_predictions,
            alignment_historys,
        ) = outputs
        mel_gts = batch["mel_gts"]
        utt_ids = batch["utt_ids"]

        # Convert to numpy, taking only the sample on the first replica.
        # Under a multi-replica strategy the outputs are PerReplica values
        # exposing `.values`; plain tensors raise and fall through to the
        # single-device path below.
        try:
            mels_before = decoder_output.values[0].numpy()
            mels_after = mel_outputs.values[0].numpy()
            mel_gts = mel_gts.values[0].numpy()
            alignment_historys = alignment_historys.values[0].numpy()
            utt_ids = utt_ids.values[0].numpy()
        except Exception:
            mels_before = decoder_output.numpy()
            mels_after = mel_outputs.numpy()
            mel_gts = mel_gts.numpy()
            alignment_historys = alignment_historys.numpy()
            utt_ids = utt_ids.numpy()

        # check directory (exist_ok avoids a race between check and create)
        dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps")
        os.makedirs(dirname, exist_ok=True)

        for idx, (mel_gt, mel_before, mel_after, alignment_history) in enumerate(
            zip(mel_gts, mels_before, mels_after, alignment_historys), 0
        ):
            # NOTE(review): 80 mel bins are hard-coded here — consider reading
            # the value from the model config instead.
            mel_gt = tf.reshape(mel_gt, (-1, 80)).numpy()  # [length, 80]
            mel_before = tf.reshape(mel_before, (-1, 80)).numpy()  # [length, 80]
            mel_after = tf.reshape(mel_after, (-1, 80)).numpy()  # [length, 80]

            # plot figure and save it
            utt_id = utt_ids[idx]
            figname = os.path.join(dirname, f"{utt_id}.png")
            fig = plt.figure(figsize=(10, 8))
            ax1 = fig.add_subplot(311)
            ax2 = fig.add_subplot(312)
            ax3 = fig.add_subplot(313)
            im = ax1.imshow(np.rot90(mel_gt), aspect="auto", interpolation="none")
            ax1.set_title("Target Mel-Spectrogram")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
            ax2.set_title(f"Predicted Mel-before-Spectrogram @ {self.steps} steps")
            im = ax2.imshow(np.rot90(mel_before), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
            ax3.set_title(f"Predicted Mel-after-Spectrogram @ {self.steps} steps")
            im = ax3.imshow(np.rot90(mel_after), aspect="auto", interpolation="none")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax3)
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()

            # plot alignment
            figname = os.path.join(dirname, f"{idx}_alignment.png")
            fig = plt.figure(figsize=(8, 6))
            ax = fig.add_subplot(111)
            ax.set_title(f"Alignment @ {self.steps} steps")
            im = ax.imshow(
                alignment_history, aspect="auto", origin="lower", interpolation="none"
            )
            fig.colorbar(im, ax=ax)
            xlabel = "Decoder timestep"
            plt.xlabel(xlabel)
            plt.ylabel("Encoder timestep")
            plt.tight_layout()
            plt.savefig(figname)
            plt.close()


def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        # Fixed copy-paste bug: the description previously referred to
        # FastSpeech, but this script trains Tacotron2.
        description="Train Tacotron2 (See detail in examples/tacotron2/train_tacotron2.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="usr norm-mels for train or raw."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers",
    )
    parser.add_argument(
        "--use-fal",
        default=0,
        type=int,
        help="Use forced alignment guided attention loss or regular",
    )
    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.mixed_precision = bool(args.mixed_precision)
    args.use_norm = bool(args.use_norm)
    args.use_fal = bool(args.use_fal)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    os.makedirs(args.outdir, exist_ok=True)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        # Fixed message: the flag is --dev-dir (there is no --valid-dir).
        raise ValueError("Please specify --dev-dir")

    # load and save config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects —
    # only load config files from trusted sources (yaml.safe_load would be
    # safer if the config never uses python-specific tags).
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["mel_length_threshold"]
    else:
        mel_length_threshold = 0

    if config["format"] == "npy":
        charactor_query = "*-ids.npy"
        # args.use_norm / args.use_fal are real bools here (converted above).
        mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
        align_query = "*-alignment.npy" if args.use_fal else ""
        charactor_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy are supported.")

    train_dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.train_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        charactor_load_fn=charactor_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
        reduction_factor=config["tacotron2_params"]["reduction_factor"],
        use_fixed_shapes=config["use_fixed_shapes"],
        align_query=align_query,
    )

    # update max_mel_length and max_char_length to config
    config.update({"max_mel_length": int(train_dataset.max_mel_length)})
    config.update({"max_char_length": int(train_dataset.max_char_length)})

    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # The global batch covers all replicas and accumulation steps.
    train_dataset = train_dataset.create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"]
        * STRATEGY.num_replicas_in_sync
        * config["gradient_accumulation_steps"],
    )

    valid_dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.dev_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        charactor_load_fn=charactor_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
        reduction_factor=config["tacotron2_params"]["reduction_factor"],
        use_fixed_shapes=False,  # don't need apply fixed shape for evaluation.
        align_query=align_query,
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = Tacotron2Trainer(
        config=config,
        strategy=STRATEGY,
        steps=0,
        epochs=0,
        is_mixed_precision=args.mixed_precision,
    )

    with STRATEGY.scope():
        # define model.
        tacotron_config = Tacotron2Config(**config["tacotron2_params"])
        tacotron2 = TFTacotron2(config=tacotron_config, name="tacotron2")
        tacotron2._build()
        tacotron2.summary()

        if len(args.pretrained) > 1:
            tacotron2.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
            logging.info(
                f"Successfully loaded pretrained weight from {args.pretrained}."
            )

        # AdamW for tacotron2: polynomial decay wrapped in a linear warmup.
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_steps=config["optimizer_params"]["decay_steps"],
            end_learning_rate=config["optimizer_params"]["end_learning_rate"],
        )

        learning_rate_fn = WarmUp(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_schedule_fn=learning_rate_fn,
            warmup_steps=int(
                config["train_max_steps"]
                * config["optimizer_params"]["warmup_proportion"]
            ),
        )

        optimizer = AdamWeightDecay(
            learning_rate=learning_rate_fn,
            weight_decay_rate=config["optimizer_params"]["weight_decay"],
            beta_1=0.9,
            beta_2=0.98,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )

        # Touch optimizer.iterations inside the strategy scope so the
        # variable is created on the right devices before training starts.
        _ = optimizer.iterations

    # compile trainer
    trainer.compile(model=tacotron2, optimizer=optimizer)

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        # Save progress on Ctrl-C so training can be resumed with --resume.
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()