# Path: blob/master/examples/melgan/decode_melgan.py
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decode trained MelGAN from folder."""

import argparse
import logging
import os
import sys

sys.path.append(".")

import numpy as np
import soundfile as sf
import yaml
from tqdm import tqdm

from tensorflow_tts.configs import MelGANGeneratorConfig
from tensorflow_tts.datasets import MelDataset
from tensorflow_tts.models import TFMelGANGenerator


def main():
    """Run MelGAN decoding from folder.

    Reads mel feature files (``*-norm-feats.npy`` or ``*-raw-feats.npy``,
    depending on ``--use-norm``) from ``--rootdir``, runs them through the
    MelGAN generator restored from ``--checkpoint`` and writes one 16-bit PCM
    WAV file per utterance into ``--outdir``.

    Raises:
        ValueError: if the YAML config's ``format`` is not ``"npy"``.
    """
    parser = argparse.ArgumentParser(
        description="Generate Audio from melspectrogram with trained melgan "
        "(See detail in examples/melgan/decode_melgan.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including mel feature (.npy) files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", type=int, default=1, help="Use norm or raw melspectrogram."
    )
    parser.add_argument("--batch-size", type=int, default=8, help="batch_size.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # Set logger: >1 -> DEBUG, 1 -> INFO, <=0 -> WARNING.
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARNING
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if log_level == logging.WARNING:
        logging.warning("Skip DEBUG/INFO messages")

    # Create the output directory; exist_ok avoids the exists()/makedirs() race.
    os.makedirs(args.outdir, exist_ok=True)

    # Load config. NOTE(review): yaml.Loader can construct arbitrary Python
    # objects from tagged YAML; only pass trusted config files here.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    # Command-line values override same-named keys from the YAML file.
    config.update(vars(args))

    if config["format"] == "npy":
        mel_query = "*-norm-feats.npy" if args.use_norm == 1 else "*-raw-feats.npy"
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # Define data-loader.
    dataset = MelDataset(
        root_dir=args.rootdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # Define model and load checkpoint.
    melgan = TFMelGANGenerator(
        config=MelGANGeneratorConfig(**config["melgan_generator_params"]),
        name="melgan_generator",
    )
    melgan._build()
    melgan.load_weights(args.checkpoint)

    hop_size = config["hop_size"]
    sampling_rate = config["sampling_rate"]

    for data in tqdm(dataset, desc="[Decoding]"):
        # Convert per-batch metadata to numpy once instead of per utterance.
        utt_ids = data["utt_ids"].numpy()
        mel_lengths = data["mel_lengths"].numpy()

        # MelGAN inference.
        # NOTE(review): the generator output may carry a trailing channel axis
        # (e.g. [B, T, 1]); soundfile accepts both [T] and [T, 1] per item.
        generated_audios = melgan(data["mels"]).numpy()

        # Save to outdir.
        for utt_id, audio, mel_length in zip(utt_ids, generated_audios, mel_lengths):
            # Keep only mel_length * hop_size samples, presumably dropping the
            # audio generated from batch padding frames.
            sf.write(
                os.path.join(args.outdir, f"{utt_id.decode('utf-8')}.wav"),
                audio[: mel_length * hop_size],
                sampling_rate,
                "PCM_16",
            )


if __name__ == "__main__":
    main()