Source: GitHub repository TensorSpeech/TensorFlowTTS, path examples/tacotron2/decode_tacotron2.py (shared via CoCalc, 1558 views).
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decode Tacotron-2."""
16
17
import argparse
18
import logging
19
import os
20
import sys
21
22
sys.path.append(".")
23
24
import numpy as np
25
import tensorflow as tf
26
import yaml
27
from tqdm import tqdm
28
import matplotlib.pyplot as plt
29
30
from examples.tacotron2.tacotron_dataset import CharactorMelDataset
31
from tensorflow_tts.configs import Tacotron2Config
32
from tensorflow_tts.models import TFTacotron2
33
34
35
def main():
    """Decode mel-spectrograms with a trained Tacotron-2 and save them as .npy.

    Command-line driven: parses args, loads the YAML config and checkpoint,
    builds a CharactorMelDataset over --rootdir, runs batched inference, trims
    each post-net mel output at the predicted stop token, and writes
    ``<utt_id>-norm-feats.npy`` files into --outdir.

    Raises:
        ValueError: if the configured feature format is not "npy".
    """
    parser = argparse.ArgumentParser(
        description="Decode mel-spectrogram from folder ids with trained Tacotron-2 "
        "(See detail in tensorflow_tts/example/tacotron2/decode_tacotron2.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm-mels for train or raw."
    )
    parser.add_argument("--batch-size", default=8, type=int, help="batch size.")
    parser.add_argument("--win-front", default=3, type=int, help="win-front.")
    parser.add_argument("--win-back", default=3, type=int, help="win-back.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger: one basicConfig call, level chosen from --verbose
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # ensure output directory exists (race-free vs. exists()+makedirs())
    os.makedirs(args.outdir, exist_ok=True)

    # load config and fold the CLI arguments into it
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        # BUGFIX: the original tested `args.use_norm is False`, which is never
        # true because --use-norm is parsed as an int (0 is not the singleton
        # False), so raw features could never be selected. Compare to 0.
        mel_query = "*-raw-feats.npy" if args.use_norm == 0 else "*-norm-feats.npy"
        char_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.rootdir,
        charactor_query=char_query,
        mel_query=mel_query,
        charactor_load_fn=char_load_fn,
        mel_load_fn=mel_load_fn,
        reduction_factor=config["tacotron2_params"]["reduction_factor"],
    )
    dataset = dataset.create(allow_cache=True, batch_size=args.batch_size)

    # define model and load checkpoint
    tacotron2 = TFTacotron2(
        config=Tacotron2Config(**config["tacotron2_params"]),
        name="tacotron2",
    )
    tacotron2._build()  # build model to be able load_weights.
    tacotron2.load_weights(args.checkpoint)

    # setup attention window used at inference time
    tacotron2.setup_window(win_front=args.win_front, win_back=args.win_back)

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids = data["utt_ids"].numpy()

        # tacotron2 inference; only post-net mels and stop tokens are used below.
        (
            mel_outputs,
            post_mel_outputs,
            stop_outputs,
            alignment_historys,
        ) = tacotron2.inference(
            input_ids=data["input_ids"],
            input_lengths=data["input_lengths"],
            speaker_ids=data["speaker_ids"],
        )

        # convert to numpy
        post_mel_outputs = post_mel_outputs.numpy()

        for i, post_mel_output in enumerate(post_mel_outputs):
            # round(sigmoid(logit)) -> 0 while still speaking, 1 after stop;
            # count the zeros to get the real (unpadded) frame length.
            stop_token = tf.math.round(tf.nn.sigmoid(stop_outputs[i]))  # [T]
            real_length = tf.math.reduce_sum(
                tf.cast(tf.math.equal(stop_token, 0.0), tf.int32), -1
            )
            post_mel_output = post_mel_output[:real_length, :]

            saved_name = utt_ids[i].decode("utf-8")

            # save decoded mel to folder.
            np.save(
                os.path.join(args.outdir, f"{saved_name}-norm-feats.npy"),
                post_mel_output.astype(np.float32),
                allow_pickle=False,
            )
# Script entry point: run decoding only when executed directly, not on import.
if __name__ == "__main__":
    main()