Path: blob/master/examples/fastspeech/decode_fastspeech.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decode trained FastSpeech from folders."""

import argparse
import logging
import os
import sys

# Put the current working directory on the import path so the
# ``examples`` package imported below can be resolved.
sys.path.append(".")

import numpy as np
import tensorflow as tf
import yaml
from tqdm import tqdm

from examples.fastspeech.fastspeech_dataset import CharactorDataset
from tensorflow_tts.configs import FastSpeechConfig
from tensorflow_tts.models import TFFastSpeech


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Decode soft-mel features from charactor with trained FastSpeech "
        "(See detail in examples/fastspeech/decode_fastspeech.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        # The option is required and this script never looks for a config
        # next to the checkpoint, so the help text must not promise that.
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--batch-size",
        default=8,
        type=int,
        required=False,
        help="Batch size for inference.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    return parser.parse_args()


def _setup_logging(verbose):
    """Configure the root logger: >1 is DEBUG, 1 is INFO, otherwise WARNING."""
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARNING
    logging.basicConfig(level=level, format=log_format)
    if level == logging.WARNING:
        logging.warning("Skip DEBUG/INFO messages")


def main():
    """Run fastspeech decoding from folder.

    Reads character-id ``.npy`` files from ``--rootdir``, runs them through a
    FastSpeech model restored from ``--checkpoint`` and writes, for every
    utterance, two ``float32`` arrays into ``--outdir``:
    ``<utt_id>-fs-before-feats.npy`` and ``<utt_id>-fs-after-feats.npy``.

    Raises:
        ValueError: If the ``format`` entry of the config is not ``"npy"``.
    """
    args = _parse_args()
    _setup_logging(args.verbose)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.outdir, exist_ok=True)

    # load config
    with open(args.config, encoding="utf-8") as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Only point --config at trusted files, or confirm the configs are
        # plain YAML and switch to yaml.safe_load.
        config = yaml.load(f, Loader=yaml.Loader)
    # Command-line values are merged into (and override) the file's entries.
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        char_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = CharactorDataset(
        root_dir=args.rootdir,
        charactor_query=char_query,
        charactor_load_fn=char_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint
    fastspeech = TFFastSpeech(
        config=FastSpeechConfig(**config["fastspeech_params"]), name="fastspeech"
    )
    fastspeech._build()
    fastspeech.load_weights(args.checkpoint)

    for data in tqdm(dataset, desc="Decoding"):
        utt_ids = data["utt_ids"]
        char_ids = data["input_ids"]

        # Number of utterances in this batch; the last batch may be smaller
        # than --batch-size, so it is read from the tensor itself.
        num_utts = tf.shape(char_ids)[0]

        # fastspeech inference with speaker id 0 and speed ratio 1.0 for
        # every utterance in the batch.
        masked_mel_before, masked_mel_after, duration_outputs = fastspeech.inference(
            char_ids,
            speaker_ids=tf.zeros(shape=[num_utts], dtype=tf.int32),
            speed_ratios=tf.ones(shape=[num_utts], dtype=tf.float32),
        )

        # convert to numpy
        masked_mel_befores = masked_mel_before.numpy()
        masked_mel_afters = masked_mel_after.numpy()

        for (utt_id, mel_before, mel_after, durations) in zip(
            utt_ids, masked_mel_befores, masked_mel_afters, duration_outputs
        ):
            # real len of mel predicted: the sum of the predicted durations.
            # Frames past that index are dropped (presumably batch padding).
            real_length = durations.numpy().sum()
            utt_id = utt_id.numpy().decode("utf-8")

            # save to folder.
            np.save(
                os.path.join(args.outdir, f"{utt_id}-fs-before-feats.npy"),
                mel_before[:real_length, :].astype(np.float32),
                allow_pickle=False,
            )
            np.save(
                os.path.join(args.outdir, f"{utt_id}-fs-after-feats.npy"),
                mel_after[:real_length, :].astype(np.float32),
                allow_pickle=False,
            )


if __name__ == "__main__":
    main()