Path: blob/master/examples/fastspeech2/extractfs_postnets.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decode trained FastSpeech from folders."""

import argparse
import logging
import os
import sys

sys.path.append(".")

import numpy as np
import tensorflow as tf
import yaml
from tqdm import tqdm

from examples.fastspeech2.fastspeech2_dataset import CharactorDurationF0EnergyMelDataset
from tensorflow_tts.configs import FastSpeech2Config
from tensorflow_tts.models import TFFastSpeech2


def main():
    """Extract postnet mel outputs from a trained FastSpeech2 checkpoint.

    Runs the model over every utterance found under ``--rootdir`` (using
    teacher-forced inputs, i.e. ``training=True``) and saves each post-net
    mel spectrogram as ``<outdir>/postnets/<utt_id>-postnet.npy``.
    """
    parser = argparse.ArgumentParser(
        description="Decode soft-mel features from charactor with trained FastSpeech "
        "(See detail in examples/fastspeech2/decode_fastspeech2.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)",
    )
    # NOTE: the dataset below is created with batch_size=1 regardless of this
    # flag (see comment there); the flag is kept for CLI backward compatibility.
    parser.add_argument(
        "--batch-size",
        default=8,
        type=int,
        required=False,
        help="Batch size for inference.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger: >1 -> DEBUG, 1 -> INFO, otherwise WARN only.
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence (exist_ok avoids a check-then-create race)
    os.makedirs(args.outdir, exist_ok=True)
    outdpost = os.path.join(args.outdir, "postnets")
    os.makedirs(outdpost, exist_ok=True)

    # load config and merge CLI arguments on top of it
    with open(args.config) as f:
        # NOTE(review): yaml.Loader executes arbitrary YAML tags; only use
        # trusted config files (yaml.safe_load would be safer if configs
        # contain no custom tags — confirm before changing).
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        char_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = CharactorDurationF0EnergyMelDataset(
        root_dir=args.rootdir,
        charactor_query=char_query,
        charactor_load_fn=char_load_fn,
    )
    dataset = dataset.create(
        batch_size=1
    )  # force batch size to 1 otherwise it may miss certain files

    # define model and load checkpoint
    fastspeech2 = TFFastSpeech2(
        config=FastSpeech2Config(**config["fastspeech2_params"]), name="fastspeech2"
    )
    fastspeech2._build()
    fastspeech2.load_weights(args.checkpoint)
    fastspeech2 = tf.function(fastspeech2, experimental_relax_shapes=True)

    for data in tqdm(dataset, desc="Decoding"):
        utt_ids = data["utt_ids"]
        mel_lens = data["mel_lengths"]

        # fastspeech inference (training=True -> teacher-forced durations/f0/energy,
        # so the output is frame-aligned with the ground-truth mel length).
        # Only the post-net output is saved; the pre-net mel is discarded.
        _, masked_mel_after, _, _, _ = fastspeech2(**data, training=True)
        masked_mel_afters = masked_mel_after.numpy()

        for utt_id, mel_after, mel_len in zip(utt_ids, masked_mel_afters, mel_lens):
            utt_id = utt_id.numpy().decode("utf-8")
            # trim padding frames to the true mel length before saving
            np.save(
                os.path.join(outdpost, f"{utt_id}-postnet.npy"),
                mel_after[:mel_len, :].astype(np.float32),
                allow_pickle=False,
            )


if __name__ == "__main__":
    main()