# Path: blob/master/examples/tacotron2/extract_duration.py
# 1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extract durations based-on tacotron-2 alignments for FastSpeech."""

import argparse
import logging
import os
import sys

# Must run before the project-local imports below so that they resolve when
# the script is launched from the repository root.
sys.path.append(".")

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import yaml
from numba import jit
from tqdm import tqdm

from examples.tacotron2.tacotron_dataset import CharactorMelDataset
from tensorflow_tts.configs import Tacotron2Config
from tensorflow_tts.models import TFTacotron2

_LOG_FORMAT = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"


@jit(nopython=True)
def get_duration_from_alignment(alignment):
    """Count, per encoder step, how many decoder steps attend to it most.

    Args:
        alignment: 2-D array indexed as ``[encoder_step, decoder_step]``.

    Returns:
        1-D integer array of length ``alignment.shape[0]``; entry ``k`` is the
        number of decoder steps whose maximum lies on encoder step ``k``.
        Ties are resolved in favour of the lowest encoder index.
    """
    num_chars, num_frames = alignment.shape
    durations = np.zeros(num_chars, dtype=np.int64)

    for frame in range(num_frames):
        # argmax returns the first maximum, matching list.index(max(...)).
        durations[np.argmax(alignment[:, frame])] += 1

    return durations


def _configure_logging(verbose):
    """Set up root logging; a higher ``verbose`` value means more output."""
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(level=level, format=_LOG_FORMAT)
    if verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")


def _trim_durations(d, real_mel_length):
    """Shave surplus frames off the last/first duration so sum(d) matches.

    ``d`` is modified in place and also returned. Nothing happens when
    ``sum(d)`` does not exceed ``real_mel_length``.
    """
    if np.sum(d) > real_mel_length:
        rest = np.sum(d) - real_mel_length
        if d[-1] > rest:
            d[-1] -= rest
        elif d[0] > rest:
            d[0] -= rest
        else:
            # Neither end can absorb the surplus alone: split it between both.
            d[-1] -= rest // 2
            d[0] -= rest - rest // 2

        assert d[-1] >= 0 and d[0] >= 0, f"{d}, {np.sum(d)}, {real_mel_length}"
    return d


def _save_alignment_figure(alignment, saved_name, outdir):
    """Write the alignment matrix as ``<saved_name>_alignment.png`` in outdir."""
    figname = os.path.join(outdir, f"{saved_name}_alignment.png")
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)
    ax.set_title(f"Alignment of {saved_name}")
    im = ax.imshow(alignment, aspect="auto", origin="lower", interpolation="none")
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Decoder timestep")
    ax.set_ylabel("Encoder timestep")
    fig.tight_layout()
    fig.savefig(figname)
    plt.close(fig)


def main():
    """Running extract tacotron-2 durations."""
    parser = argparse.ArgumentParser(
        description="Extract durations from charactor with trained Tacotron-2 "
        "(See detail in tensorflow_tts/example/tacotron-2/extract_duration.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="usr norm-mels for train or raw."
    )
    parser.add_argument("--batch-size", default=8, type=int, help="batch size.")
    parser.add_argument("--win-front", default=2, type=int, help="win-front.")
    parser.add_argument("--win-back", default=2, type=int, help="win-back.")
    parser.add_argument(
        "--use-window-mask", default=1, type=int, help="toggle window masking."
    )
    parser.add_argument("--save-alignment", default=0, type=int, help="save-alignment.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger
    _configure_logging(args.verbose)

    # make sure the output directory exists (no-op when it already does)
    os.makedirs(args.outdir, exist_ok=True)

    # load config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects. Only
    # point --config at trusted files, or switch to yaml.safe_load if the
    # project's configs never rely on Python-specific tags.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        # --use-norm is an int flag, so test truthiness; an identity check
        # against False never matches an int and would always pick norm feats.
        mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
        char_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    reduction_factor = config["tacotron2_params"]["reduction_factor"]

    # define data-loader
    dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.rootdir,
        charactor_query=char_query,
        mel_query=mel_query,
        charactor_load_fn=char_load_fn,
        mel_load_fn=mel_load_fn,
        reduction_factor=reduction_factor,
        use_fixed_shapes=True,
    )
    dataset = dataset.create(
        allow_cache=True, batch_size=args.batch_size, drop_remainder=False
    )

    # define model and load checkpoint
    tacotron2 = TFTacotron2(
        config=Tacotron2Config(**config["tacotron2_params"]),
        name="tacotron2",
    )
    tacotron2._build()  # build model to be able load_weights.
    tacotron2.load_weights(args.checkpoint)

    # apply tf.function for tacotron2.
    tacotron2 = tf.function(tacotron2, experimental_relax_shapes=True)

    for data in tqdm(dataset, desc="[Extract Duration]"):
        utt_ids = data["utt_ids"].numpy()
        input_lengths = data["input_lengths"].numpy()
        # real_mel_lengths is removed from the batch before it is expanded
        # into the model call's keyword arguments below.
        real_mel_lengths = data.pop("real_mel_lengths").numpy()

        # tacotron2 inference; only the alignment history is needed here.
        _, _, _, alignment_historys = tacotron2(
            **data,
            use_window_mask=args.use_window_mask,
            win_front=args.win_front,
            win_back=args.win_back,
            training=True,
        )

        # convert to numpy
        alignment_historys = alignment_historys.numpy()

        for i, alignment in enumerate(alignment_historys):
            real_char_length = input_lengths[i]
            real_mel_length = real_mel_lengths[i]
            alignment_mel_length = int(np.ceil(real_mel_length / reduction_factor))
            # drop the padded encoder/decoder steps
            alignment = alignment[:real_char_length, :alignment_mel_length]
            d = get_duration_from_alignment(alignment)  # [max_char_len]

            # each decoder step covers `reduction_factor` mel frames
            d = d * reduction_factor
            assert (
                np.sum(d) >= real_mel_length
            ), f"{d}, {np.sum(d)}, {alignment_mel_length}, {real_mel_length}"
            d = _trim_durations(d, real_mel_length)

            saved_name = utt_ids[i].decode("utf-8")

            # check a length compatible
            assert (
                len(d) == real_char_length
            ), f"different between len_char and len_durations, {len(d)} and {real_char_length}"

            assert (
                np.sum(d) == real_mel_length
            ), f"different between sum_durations and len_mel, {np.sum(d)} and {real_mel_length}"

            # save D to folder.
            np.save(
                os.path.join(args.outdir, f"{saved_name}-durations.npy"),
                d.astype(np.int32),
                allow_pickle=False,
            )

            # save alignment to debug.
            if args.save_alignment == 1:
                _save_alignment_figure(alignment, saved_name, args.outdir)


if __name__ == "__main__":
    main()