Path: blob/master/examples/tacotron2/decode_tacotron2.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decode Tacotron-2."""

import argparse
import logging
import os
import sys

sys.path.append(".")

import numpy as np
import tensorflow as tf
import yaml
from tqdm import tqdm
import matplotlib.pyplot as plt

from examples.tacotron2.tacotron_dataset import CharactorMelDataset
from tensorflow_tts.configs import Tacotron2Config
from tensorflow_tts.models import TFTacotron2


def _parse_args():
    """Build the command-line parser for this script and return parsed args."""
    parser = argparse.ArgumentParser(
        description="Decode mel-spectrogram from folder ids with trained Tacotron-2 "
        "(See detail in examples/tacotron2/decode_tacotron2.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/mel-feats files.",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save generated mel-spectrograms.",
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm-mels for train or raw."
    )
    parser.add_argument("--batch-size", default=8, type=int, help="batch size.")
    parser.add_argument("--win-front", default=3, type=int, help="win-front.")
    parser.add_argument("--win-back", default=3, type=int, help="win-back.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    return parser.parse_args()


def _setup_logging(verbose):
    """Configure root logging: verbose > 1 -> DEBUG, > 0 -> INFO, else WARN."""
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(
        level=level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")


def main():
    """Running decode tacotron-2 mel-spectrogram.

    Loads the Tacotron-2 weights from ``--checkpoint`` (model hyper-parameters
    come from ``config["tacotron2_params"]`` in the ``--config`` yaml), runs
    ``inference`` over the dataset built from ``--rootdir`` and saves one
    ``<utt_id>-norm-feats.npy`` array per utterance into ``--outdir``.

    Raises:
        ValueError: if the config's ``format`` is not ``"npy"``.
    """
    args = _parse_args()
    _setup_logging(args.verbose)

    # create the output directory if it does not exist yet.
    os.makedirs(args.outdir, exist_ok=True)

    # load config.
    # NOTE: yaml.Loader can construct arbitrary Python objects; only pass
    # trusted config files (yaml.SafeLoader would be the stricter choice).
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] != "npy":
        raise ValueError("Only npy is supported.")
    char_query = "*-ids.npy"
    # --use-norm is parsed as an int, so test its truthiness; an identity
    # check against False never matches an int and always picked norm feats.
    mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
    char_load_fn = np.load
    mel_load_fn = np.load

    # define data-loader
    dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.rootdir,
        charactor_query=char_query,
        mel_query=mel_query,
        charactor_load_fn=char_load_fn,
        mel_load_fn=mel_load_fn,
        reduction_factor=config["tacotron2_params"]["reduction_factor"],
    )
    dataset = dataset.create(allow_cache=True, batch_size=args.batch_size)

    # define model and load checkpoint
    tacotron2 = TFTacotron2(
        config=Tacotron2Config(**config["tacotron2_params"]),
        name="tacotron2",
    )
    tacotron2._build()  # build model to be able load_weights.
    tacotron2.load_weights(args.checkpoint)

    # setup window
    tacotron2.setup_window(win_front=args.win_front, win_back=args.win_back)

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids = data["utt_ids"].numpy()

        # tacotron2 inference; only post-net mels and stop tokens are used here.
        _, post_mel_outputs, stop_outputs, _ = tacotron2.inference(
            input_ids=data["input_ids"],
            input_lengths=data["input_lengths"],
            speaker_ids=data["speaker_ids"],
        )

        # convert to numpy
        post_mel_outputs = post_mel_outputs.numpy()

        for i, post_mel_output in enumerate(post_mel_outputs):
            stop_token = tf.math.round(tf.nn.sigmoid(stop_outputs[i]))  # [T]
            # Keep as many frames as there are steps whose rounded stop
            # probability is 0. NOTE(review): this counts every such step rather
            # than locating the first stop step; the two agree only when the
            # stop predictions are monotonic -- confirm this is intended.
            real_length = int(
                tf.math.reduce_sum(
                    tf.cast(tf.math.equal(stop_token, 0.0), tf.int32), -1
                )
            )
            post_mel_output = post_mel_output[:real_length, :]

            saved_name = utt_ids[i].decode("utf-8")

            # save the trimmed mel-spectrogram to the output folder.
            np.save(
                os.path.join(args.outdir, f"{saved_name}-norm-feats.npy"),
                post_mel_output.astype(np.float32),
                allow_pickle=False,
            )


if __name__ == "__main__":
    main()