Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/multiband_melgan/decode_mb_melgan.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Decode a folder of mel-spectrograms to audio with a trained Multi-band MelGAN."""
16
17
import argparse
18
import logging
19
import os
20
21
import numpy as np
22
import soundfile as sf
23
import yaml
24
from tqdm import tqdm
25
26
from tensorflow_tts.configs import MultiBandMelGANGeneratorConfig
27
from tensorflow_tts.datasets import MelDataset
28
from tensorflow_tts.models import TFPQMF, TFMelGANGenerator
29
30
31
def _parse_args():
    """Parse the command-line arguments of the Multi-band MelGAN decoder."""
    parser = argparse.ArgumentParser(
        description="Generate Audio from melspectrogram with trained melgan "
        "(See detail in examples/melgan/decode_melgan.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including mel feature (*.npy) files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", type=int, default=1, help="Use norm or raw melspectrogram."
    )
    parser.add_argument("--batch-size", type=int, default=8, help="batch_size.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file (must define format, hop_size, "
        "sampling_rate and multiband_melgan_generator_params).",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    return parser.parse_args()


def _setup_logging(verbose):
    """Configure the root logger from the ``--verbose`` level.

    ``verbose > 1`` selects DEBUG, ``verbose > 0`` selects INFO, and anything
    else keeps only WARNING and above (and says so).
    """
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARNING
    logging.basicConfig(
        level=level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if level == logging.WARNING:
        logging.warning("Skip DEBUG/INFO messages")


def _get_mel_query(rootdir, use_norm):
    """Return the glob pattern used to find mel feature files in ``rootdir``.

    Args:
        rootdir: directory passed as ``--rootdir``.
        use_norm: ``1`` selects normalized features, any other value raw ones.

    Returns:
        ``"*-fs-after-feats.npy"`` when ``rootdir`` contains the substring
        ``"fastspeech"`` (``use_norm`` is ignored in that case), otherwise
        ``"*-norm-feats.npy"`` or ``"*-raw-feats.npy"``.
    """
    # Path-name heuristic: directories whose path mentions "fastspeech" are
    # treated as holding "*-fs-after-feats.npy" files.
    if "fastspeech" in rootdir:
        return "*-fs-after-feats.npy"
    if use_norm == 1:
        return "*-norm-feats.npy"
    return "*-raw-feats.npy"


def main():
    """Run Multi-band MelGAN decoding from a folder of mel-spectrograms.

    Loads mel features from ``--rootdir``, runs the generator restored from
    ``--checkpoint``, merges its sub-band output with PQMF synthesis and writes
    one 16-bit PCM ``<utt_id>.wav`` file per utterance into ``--outdir``.

    Raises:
        ValueError: if the configuration's ``format`` is not ``"npy"``.
    """
    args = _parse_args()
    _setup_logging(args.verbose)

    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(args.outdir, exist_ok=True)

    # load config
    # NOTE(review): yaml.Loader can instantiate arbitrary Python objects; only
    # pass configuration files from a trusted source.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    # Command-line values override same-named keys from the YAML file.
    config.update(vars(args))

    if config["format"] != "npy":
        raise ValueError("Only npy is supported.")
    mel_query = _get_mel_query(args.rootdir, args.use_norm)
    mel_load_fn = np.load

    # define data-loader
    dataset = MelDataset(
        root_dir=args.rootdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # The generator and the PQMF filter bank are built from the same generator
    # parameters, so construct the config object once.
    generator_config = MultiBandMelGANGeneratorConfig(
        **config["multiband_melgan_generator_params"]
    )

    # define model and load checkpoint
    mb_melgan = TFMelGANGenerator(
        config=generator_config,
        name="multiband_melgan_generator",
    )
    mb_melgan._build()
    mb_melgan.load_weights(args.checkpoint)

    pqmf = TFPQMF(config=generator_config, name="pqmf")

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids, mels, mel_lengths = data["utt_ids"], data["mels"], data["mel_lengths"]

        # melgan inference: generate sub-band signals, then merge them with
        # PQMF synthesis.
        generated_subbands = mb_melgan(mels)
        generated_audios = pqmf.synthesis(generated_subbands)

        # convert to numpy.
        generated_audios = generated_audios.numpy()  # [B, T]

        # save to outdir
        for i, audio in enumerate(generated_audios):
            utt_id = utt_ids[i].numpy().decode("utf-8")
            # Keep mel_length * hop_size samples; presumably this drops the
            # padding introduced by batching.
            sf.write(
                os.path.join(args.outdir, f"{utt_id}.wav"),
                audio[: mel_lengths[i].numpy() * config["hop_size"]],
                config["sampling_rate"],
                "PCM_16",
            )
142
143
144
if __name__ == "__main__":
    # Run decoding only when this file is executed as a script, not on import.
    main()
146
147