Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/melgan/decode_melgan.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Decode trained Melgan from folder."""
16
17
import argparse
18
import logging
19
import os
20
import sys
21
22
sys.path.append(".")
23
24
import numpy as np
25
import soundfile as sf
26
import yaml
27
from tqdm import tqdm
28
29
from tensorflow_tts.configs import MelGANGeneratorConfig
30
from tensorflow_tts.datasets import MelDataset
31
from tensorflow_tts.models import TFMelGANGenerator
32
33
34
def _build_parser():
    """Create the command-line parser for MelGAN decoding."""
    parser = argparse.ArgumentParser(
        description="Generate Audio from melspectrogram with trained melgan "
        "(See detail in example/melgan/decode_melgan.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including mel-spectrogram feature files "
        "(*-norm-feats.npy or *-raw-feats.npy).",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", type=int, default=1, help="Use norm or raw melspectrogram."
    )
    parser.add_argument("--batch-size", type=int, default=8, help="batch_size.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    return parser


def _configure_logging(verbose):
    """Configure root logging: >1 is DEBUG, 1 is INFO, anything lower is WARNING."""
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARNING

    logging.basicConfig(
        level=level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")


def main():
    """Run melgan decoding from folder.

    Parses the command line, loads the YAML configuration (command-line
    values override same-named config keys), builds a ``MelDataset`` over
    ``--rootdir``, restores the generator weights from ``--checkpoint`` and
    writes one ``<utt_id>.wav`` (PCM_16) per utterance into ``--outdir``.

    The configuration must provide the keys ``format``,
    ``melgan_generator_params``, ``hop_size`` and ``sampling_rate``.

    Raises:
        ValueError: if ``config["format"]`` is anything other than ``"npy"``.
    """
    args = _build_parser().parse_args()
    _configure_logging(args.verbose)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.outdir, exist_ok=True)

    # load config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # point --config at trusted files (or switch to yaml.SafeLoader once it is
    # confirmed that no config relies on Python-specific tags).
    with open(args.config, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] != "npy":
        raise ValueError("Only npy is supported.")
    mel_query = "*-norm-feats.npy" if args.use_norm == 1 else "*-raw-feats.npy"
    mel_load_fn = np.load

    # define data-loader
    dataset = MelDataset(
        root_dir=args.rootdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint
    melgan = TFMelGANGenerator(
        config=MelGANGeneratorConfig(**config["melgan_generator_params"]),
        name="melgan_generator",
    )
    melgan._build()  # the model is built before the checkpoint weights are loaded
    melgan.load_weights(args.checkpoint)

    # Loop-invariant config lookups, hoisted out of the decode loop.
    hop_size = config["hop_size"]
    sampling_rate = config["sampling_rate"]

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids, mels, mel_lengths = data["utt_ids"], data["mels"], data["mel_lengths"]

        # melgan inference, then convert to numpy (batch-major, one row per
        # utterance; the original note gives the shape as [B, T]).
        generated_audios = melgan(mels).numpy()

        # save to outdir
        for i, audio in enumerate(generated_audios):
            utt_id = utt_ids[i].numpy().decode("utf-8")
            # Keep only the samples backed by real mel frames: each frame
            # contributes hop_size samples, the rest comes from batch padding.
            num_samples = mel_lengths[i].numpy() * hop_size
            sf.write(
                os.path.join(args.outdir, f"{utt_id}.wav"),
                audio[:num_samples],
                sampling_rate,
                "PCM_16",
            )
139
140
# Script entry point: run decoding only when executed directly, not on import.
if __name__ == "__main__":
    main()
142
143