Path: blob/master/examples/multiband_melgan_hf/decode_mb_melgan.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decode trained Mb-Melgan from folder."""

import argparse
import logging
import os

import numpy as np
import soundfile as sf
import yaml
from tqdm import tqdm

from tensorflow_tts.configs import MultiBandMelGANGeneratorConfig
from tensorflow_tts.datasets import MelDataset
from tensorflow_tts.models import TFPQMF, TFMelGANGenerator

_LOG_FORMAT = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Generate Audio from melspectrogram with trained melgan "
        "(See detail in examples/melgan/decode_melgan.py)."
    )
    parser.add_argument(
        "--rootdir",
        type=str,
        required=True,
        help="directory including mel feature (.npy) files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", type=int, default=1, help="Use norm or raw melspectrogram."
    )
    parser.add_argument("--batch-size", type=int, default=8, help="batch_size.")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    return parser.parse_args()


def _setup_logging(verbose):
    """Configure the root logger from the ``--verbose`` level.

    ``verbose > 1`` -> DEBUG, ``verbose > 0`` -> INFO, otherwise WARN.
    """
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(level=level, format=_LOG_FORMAT)
    if level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")


def _select_mel_query(rootdir, use_norm):
    """Return the glob pattern selecting which mel feature files to decode.

    If ``rootdir`` contains the substring "fastspeech" (presumably a folder of
    FastSpeech-generated features), ``*-fs-after-feats.npy`` is used and
    ``use_norm`` is ignored. Otherwise ``use_norm == 1`` selects normalized
    features and any other value selects raw features.
    """
    if "fastspeech" in rootdir:
        return "*-fs-after-feats.npy"
    if use_norm == 1:
        return "*-norm-feats.npy"
    return "*-raw-feats.npy"


def main():
    """Run melgan decoding from folder.

    Loads mel features from ``--rootdir``, runs the multi-band MelGAN generator
    followed by PQMF synthesis, and writes one 16-bit PCM wav per utterance to
    ``--outdir``.

    Raises:
        ValueError: if the configuration's ``format`` is not ``"npy"``.
    """
    args = _parse_args()
    _setup_logging(args.verbose)

    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(args.outdir, exist_ok=True)

    # load config; command-line values then override keys of the same name.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # pass configuration files from trusted sources.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] != "npy":
        raise ValueError("Only npy is supported.")
    mel_query = _select_mel_query(args.rootdir, args.use_norm)
    mel_load_fn = np.load

    # define data-loader
    dataset = MelDataset(
        root_dir=args.rootdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint. Each model gets its own config
    # instance, exactly as before, so neither can affect the other's config.
    generator_params = config["multiband_melgan_generator_params"]
    mb_melgan = TFMelGANGenerator(
        config=MultiBandMelGANGeneratorConfig(**generator_params),
        name="multiband_melgan_generator",
    )
    mb_melgan._build()
    mb_melgan.load_weights(args.checkpoint)

    pqmf = TFPQMF(
        config=MultiBandMelGANGeneratorConfig(**generator_params), name="pqmf"
    )

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids, mels, mel_lengths = data["utt_ids"], data["mels"], data["mel_lengths"]

        # the generator outputs sub-band signals; PQMF synthesis merges them.
        generated_subbands = mb_melgan(mels)
        generated_audios = pqmf.synthesis(generated_subbands)

        # convert to numpy; first axis indexes utterances in the batch
        # (exact trailing shape, e.g. [B, T] vs [B, T, 1] — TODO confirm).
        generated_audios = generated_audios.numpy()

        # save to outdir
        for i, audio in enumerate(generated_audios):
            utt_id = utt_ids[i].numpy().decode("utf-8")
            # keep mel_length * hop_size samples, presumably dropping audio
            # generated from batch padding.
            num_samples = mel_lengths[i].numpy() * config["hop_size"]
            sf.write(
                os.path.join(args.outdir, f"{utt_id}.wav"),
                audio[:num_samples],
                config["sampling_rate"],
                "PCM_16",
            )


if __name__ == "__main__":
    main()