# Path: blob/master/examples/parallel_wavegan/decode_parallel_wavegan.py
# 1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decode trained Parallel WaveGAN from folder."""

import argparse
import logging
import os

import numpy as np
import soundfile as sf
import yaml
from tqdm import tqdm

from tensorflow_tts.configs import ParallelWaveGANGeneratorConfig
from tensorflow_tts.datasets import MelDataset
from tensorflow_tts.models import TFParallelWaveGANGenerator

# Shared log line format for every verbosity level.
_LOG_FORMAT = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"


def _build_parser():
    """Return the command-line parser for this decoding script."""
    parser = argparse.ArgumentParser(
        description="Generate Audio from melspectrogram with trained parallel_wavegan "
        "(See detail in examples/parallel_wavegan/decode_parallel_wavegan.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", type=int, default=1, help="Use norm or raw melspectrogram."
    )
    parser.add_argument("--batch-size", type=int, default=8, help="batch_size.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        # The flag is required: no fallback search for a config file exists.
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    return parser


def _configure_logging(verbose):
    """Configure root logging from the ``--verbose`` value.

    Args:
        verbose: ``> 1`` selects DEBUG, ``1`` selects INFO, anything lower
            selects WARN and emits a notice that DEBUG/INFO output is hidden.
    """
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(level=level, format=_LOG_FORMAT)
    if level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")


def _select_mel_query(rootdir, use_norm):
    """Return the glob pattern used to find mel feature files in ``rootdir``.

    Args:
        rootdir: Input feature directory. If its path contains the substring
            "fastspeech", FastSpeech output features are used regardless of
            ``use_norm``.
        use_norm: ``1`` selects normalized features; any other value selects
            raw features.

    Returns:
        A filename glob pattern (str).
    """
    # Heuristic kept from the original script: it is based on the directory
    # name, not on the file contents.
    if "fastspeech" in rootdir:
        return "*-fs-after-feats.npy"
    if use_norm == 1:
        return "*-norm-feats.npy"
    return "*-raw-feats.npy"


def main():
    """Run parallel_wavegan decoding from folder.

    Reads mel features from ``--rootdir``, runs the Parallel WaveGAN
    generator loaded from ``--checkpoint``, and writes one 16-bit PCM wav
    file per utterance id into ``--outdir``.

    Raises:
        ValueError: If the config's ``format`` is not ``"npy"``.
    """
    args = _build_parser().parse_args()
    _configure_logging(args.verbose)

    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(args.outdir, exist_ok=True)

    # Load config; command-line arguments override keys with the same name.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects from
    # tagged YAML. Only load config files from trusted sources; consider
    # yaml.SafeLoader if configs are known to contain plain data only.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] != "npy":
        raise ValueError("Only npy is supported.")
    mel_query = _select_mel_query(args.rootdir, args.use_norm)
    mel_load_fn = np.load

    # define data-loader
    dataset = MelDataset(
        root_dir=args.rootdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint
    parallel_wavegan = TFParallelWaveGANGenerator(
        config=ParallelWaveGANGeneratorConfig(**config["parallel_wavegan_generator_params"]),
        name="parallel_wavegan_generator",
    )
    # Build the variables before restoring weights into them.
    parallel_wavegan._build()
    parallel_wavegan.load_weights(args.checkpoint)

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids, mels, mel_lengths = data["utt_ids"], data["mels"], data["mel_lengths"]

        # pwgan inference.
        generated_audios = parallel_wavegan.inference(mels)

        # convert to numpy.
        # NOTE(review): the batch shape is presumably [B, T] or [B, T, 1];
        # sf.write accepts either 1-D or (frames, channels) data. TODO confirm.
        generated_audios = generated_audios.numpy()

        # save to outdir
        for i, audio in enumerate(generated_audios):
            utt_id = utt_ids[i].numpy().decode("utf-8")
            # Trim batch padding: keep hop_size samples per valid mel frame.
            sf.write(
                os.path.join(args.outdir, f"{utt_id}.wav"),
                audio[: mel_lengths[i].numpy() * config["hop_size"]],
                config["sampling_rate"],
                "PCM_16",
            )


if __name__ == "__main__":
    main()