Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/parallel_wavegan/decode_parallel_wavegan.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Decode trained Parallel WaveGAN from folder."""
16
17
import argparse
18
import logging
19
import os
20
21
import numpy as np
22
import soundfile as sf
23
import yaml
24
from tqdm import tqdm
25
26
from tensorflow_tts.configs import ParallelWaveGANGeneratorConfig
27
from tensorflow_tts.datasets import MelDataset
28
from tensorflow_tts.models import TFParallelWaveGANGenerator
29
30
31
def main():
    """Run Parallel WaveGAN decoding from a folder of mel-spectrogram features.

    Command-line arguments select the feature directory (``--rootdir``), the
    output directory (``--outdir``), the generator checkpoint
    (``--checkpoint``) and the YAML config (``--config``). Every feature file
    matched in ``--rootdir`` is vocoded with ``TFParallelWaveGANGenerator`` and
    written to ``<outdir>/<utt_id>.wav`` as 16-bit PCM at the config's
    ``sampling_rate``.

    Raises:
        ValueError: if the config's ``format`` entry is not ``"npy"``.
    """
    parser = argparse.ArgumentParser(
        description="Generate Audio from melspectrogram with trained parallel_wavegan "
        "(See detail in examples/parallel_wavegan/decode_parallel_wavegan.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including mel-spectrogram feature (.npy) files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", type=int, default=1, help="Use norm or raw melspectrogram."
    )
    parser.add_argument("--batch-size", type=int, default=8, help="batch_size.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        # The file is required; no fallback search is performed.
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN.
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if args.verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")

    # Create the output directory. exist_ok avoids the check-then-create race
    # of testing os.path.exists first.
    os.makedirs(args.outdir, exist_ok=True)

    # load config.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # pass trusted config files (yaml.SafeLoader is safer if the configs use
    # no custom tags).
    with open(args.config, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.Loader)
    # Command-line arguments override config entries with the same key.
    config.update(vars(args))

    if config["format"] != "npy":
        raise ValueError("Only npy is supported.")

    # Choose which feature files to decode. A rootdir whose path contains
    # "fastspeech" selects the "-fs-after-feats" files regardless of
    # --use-norm (presumably FastSpeech-generated mels).
    if "fastspeech" in args.rootdir:
        mel_query = "*-fs-after-feats.npy"
    elif args.use_norm == 1:
        mel_query = "*-norm-feats.npy"
    else:
        mel_query = "*-raw-feats.npy"
    mel_load_fn = np.load

    # define data-loader
    dataset = MelDataset(
        root_dir=args.rootdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint
    parallel_wavegan = TFParallelWaveGANGenerator(
        config=ParallelWaveGANGeneratorConfig(
            **config["parallel_wavegan_generator_params"]
        ),
        name="parallel_wavegan_generator",
    )
    # Build the model before restoring weights into it.
    parallel_wavegan._build()
    parallel_wavegan.load_weights(args.checkpoint)

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids, mels, mel_lengths = data["utt_ids"], data["mels"], data["mel_lengths"]

        # pwgan inference, converted to numpy.
        # NOTE(review): output is assumed batch-major ([B, T] or [B, T, 1]);
        # confirm against TFParallelWaveGANGenerator.inference.
        generated_audios = parallel_wavegan.inference(mels).numpy()

        # Save to outdir. Each waveform is trimmed to mel_length * hop_size
        # samples (presumably dropping batch padding).
        for i, audio in enumerate(generated_audios):
            utt_id = utt_ids[i].numpy().decode("utf-8")
            sf.write(
                os.path.join(args.outdir, f"{utt_id}.wav"),
                audio[: mel_lengths[i].numpy() * config["hop_size"]],
                config["sampling_rate"],
                "PCM_16",
            )


if __name__ == "__main__":
    main()
141
142