Path: blob/master/examples/fastspeech2/decode_fastspeech2.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decode trained FastSpeech2 from folders."""

import argparse
import logging
import os
import sys

# Make the repository root importable so `examples.*` resolves when the
# script is launched from the repo root.
sys.path.append(".")

import numpy as np
import tensorflow as tf
import yaml
from tqdm import tqdm

from examples.fastspeech.fastspeech_dataset import CharactorDataset
from tensorflow_tts.configs import FastSpeech2Config
from tensorflow_tts.models import TFFastSpeech2

# Shared log-line format for every verbosity level.
_LOG_FORMAT = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Decode soft-mel features from charactor with trained FastSpeech "
        "(See detail in examples/fastspeech2/decode_fastspeech2.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        # The argument is mandatory; the old help text wrongly claimed it
        # would be searched for in the checkpoint directory.
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--batch-size",
        default=8,
        type=int,
        required=False,
        help="Batch size for inference.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    return parser.parse_args()


def _setup_logging(verbose):
    """Configure the root logger from the ``--verbose`` value.

    Args:
        verbose (int): ``> 1`` enables DEBUG, ``> 0`` enables INFO, anything
            else keeps only WARN and above.
    """
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(level=level, format=_LOG_FORMAT)
    if level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")


def _save_mels(outdir, utt_ids, mels_before, mels_after, durations):
    """Write one batch of predicted mels to ``outdir`` as ``.npy`` files.

    Each utterance's mels are trimmed to the sum of its predicted durations,
    which drops the batch padding.

    Args:
        outdir (str): Output directory (must already exist).
        utt_ids: Per-utterance ids as UTF-8 encoded bytes.
        mels_before: Per-utterance mel arrays before the post-net.
        mels_after: Per-utterance mel arrays after the post-net.
        durations: Per-utterance predicted duration arrays.
    """
    for utt_id, mel_before, mel_after, duration in zip(
        utt_ids, mels_before, mels_after, durations
    ):
        # real len of mel predicted
        real_length = int(duration.sum())
        utt_id = utt_id.decode("utf-8")
        np.save(
            os.path.join(outdir, f"{utt_id}-fs-before-feats.npy"),
            mel_before[:real_length, :].astype(np.float32),
            allow_pickle=False,
        )
        np.save(
            os.path.join(outdir, f"{utt_id}-fs-after-feats.npy"),
            mel_after[:real_length, :].astype(np.float32),
            allow_pickle=False,
        )


def main():
    """Run fastspeech2 decoding from folder.

    Reads ``*-ids.npy`` files under ``--rootdir``, runs FastSpeech2 inference
    with neutral speaker/speed/f0/energy controls, and saves the before- and
    after-post-net mels for each utterance under ``--outdir``.

    Raises:
        ValueError: If the config's ``format`` is not ``"npy"``.
    """
    args = _parse_args()
    _setup_logging(args.verbose)

    # Create the output directory; exist_ok avoids a check-then-create race.
    os.makedirs(args.outdir, exist_ok=True)

    # load config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # pass trusted config files (consider yaml.SafeLoader if configs allow).
    with open(args.config, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        char_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = CharactorDataset(
        root_dir=args.rootdir,
        charactor_query=char_query,
        charactor_load_fn=char_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint
    fastspeech2 = TFFastSpeech2(
        config=FastSpeech2Config(**config["fastspeech2_params"]), name="fastspeech2"
    )
    fastspeech2._build()
    fastspeech2.load_weights(args.checkpoint)

    for data in tqdm(dataset, desc="Decoding"):
        utt_ids = data["utt_ids"]
        char_ids = data["input_ids"]
        # Number of utterances in this batch (the last batch may be smaller).
        n_utts = tf.shape(char_ids)[0]

        # fastspeech inference with neutral controls (speaker 0, ratio 1.0).
        (
            masked_mel_before,
            masked_mel_after,
            duration_outputs,
            _,
            _,
        ) = fastspeech2.inference(
            char_ids,
            speaker_ids=tf.zeros(shape=[n_utts], dtype=tf.int32),
            speed_ratios=tf.ones(shape=[n_utts], dtype=tf.float32),
            f0_ratios=tf.ones(shape=[n_utts], dtype=tf.float32),
            energy_ratios=tf.ones(shape=[n_utts], dtype=tf.float32),
        )

        # Convert each batch tensor to numpy once rather than per utterance.
        _save_mels(
            args.outdir,
            utt_ids.numpy(),
            masked_mel_before.numpy(),
            masked_mel_after.numpy(),
            duration_outputs.numpy(),
        )


if __name__ == "__main__":
    main()