Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/tacotron2/extract_postnets.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Extract durations based-on tacotron-2 alignments for FastSpeech."""
16
17
import argparse
18
import logging
19
import os
20
from numba import jit
21
import sys
22
23
sys.path.append(".")
24
25
import matplotlib.pyplot as plt
26
import numpy as np
27
import tensorflow as tf
28
import yaml
29
from tqdm import tqdm
30
31
from examples.tacotron2.tacotron_dataset import CharactorMelDataset
32
from tensorflow_tts.configs import Tacotron2Config
33
from tensorflow_tts.models import TFTacotron2
34
35
36
@jit(nopython=True)
def get_duration_from_alignment(alignment):
    """Count, per row of ``alignment``, how many columns have their maximum there.

    The caller in this script passes a 2-D attention matrix sliced as
    ``[:real_char_length, :alignment_mel_length]``, i.e. rows are encoder
    (character) steps and columns are decoder steps.

    Args:
        alignment: 2-D numeric array.

    Returns:
        1-D int64 array of length ``alignment.shape[0]``. Entry ``k`` is the
        number of columns whose largest value sits in row ``k``. Ties resolve
        to the lowest row index, as ``list.index(max(...))`` did.
    """
    num_rows = alignment.shape[0]
    num_cols = alignment.shape[1]
    durations = np.zeros(num_rows, dtype=np.int64)
    for col in range(num_cols):
        # np.argmax returns the first maximal index. Unlike the previous
        # list(...).index(max()), it does not build a list per column.
        durations[np.argmax(alignment[:, col])] += 1
    return durations
45
46
47
def main():
    """Run a trained Tacotron-2 over a dataset and save its postnet mel outputs.

    For every batch from ``CharactorMelDataset`` under ``--rootdir``:

    - The model loaded from ``--checkpoint`` is called with the batch
      (ground-truth mels included) and ``training=True``.
    - Each utterance's postnet output is trimmed to its real mel length and
      saved as ``<outdir>/postnets/<utt_id>-postnet.npy``.
    - Durations are derived from the attention alignment only as a sanity
      check: they are scaled by the reduction factor, trimmed, and asserted to
      match the character count and the real mel length.
    - With ``--save-alignment 1`` the alignment is also plotted to
      ``<outdir>/<utt_id>_alignment.png``.
    """
    parser = argparse.ArgumentParser(
        description="Extract postnet mel-spectrograms from charactor with trained "
        "Tacotron-2 (See detail in tensorflow_tts/example/tacotron-2/extract_postnets.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/mel feature files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated mels."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="usr norm-mels for train or raw."
    )
    parser.add_argument("--batch-size", default=32, type=int, help="batch size.")
    parser.add_argument("--win-front", default=3, type=int, help="win-front.")
    parser.add_argument("--win-back", default=3, type=int, help="win-back.")
    parser.add_argument(
        "--use-window-mask", default=1, type=int, help="toggle window masking."
    )
    parser.add_argument("--save-alignment", default=0, type=int, help="save-alignment.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if log_level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")

    # Create the output directories once, up front.
    os.makedirs(args.outdir, exist_ok=True)
    outdpost = os.path.join(args.outdir, "postnets")
    os.makedirs(outdpost, exist_ok=True)

    # load config
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only pass
    # trusted config files (yaml.SafeLoader would suffice for plain configs).
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] != "npy":
        raise ValueError("Only npy is supported.")
    char_query = "*-ids.npy"
    # --use-norm is parsed as an int, so test its value. The previous
    # `args.use_norm is False` was never true, so raw feats were unreachable.
    mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
    char_load_fn = np.load
    mel_load_fn = np.load

    reduction_factor = config["tacotron2_params"]["reduction_factor"]

    # define data-loader
    dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.rootdir,
        charactor_query=char_query,
        mel_query=mel_query,
        charactor_load_fn=char_load_fn,
        mel_load_fn=mel_load_fn,
        reduction_factor=reduction_factor,
        use_fixed_shapes=True,
    )
    dataset = dataset.create(
        allow_cache=True, batch_size=args.batch_size, drop_remainder=False
    )

    # define model and load checkpoint
    tacotron2 = TFTacotron2(
        config=Tacotron2Config(**config["tacotron2_params"]),
        name="tacotron2",
    )
    tacotron2._build()  # build model to be able load_weights.
    tacotron2.load_weights(args.checkpoint)

    # apply tf.function for tacotron2.
    tacotron2 = tf.function(tacotron2, experimental_relax_shapes=True)

    for data in tqdm(dataset, desc="[Extract Postnets]"):
        utt_ids = data["utt_ids"].numpy()
        input_lengths = data["input_lengths"]
        real_mel_lengths = data["real_mel_lengths"]
        # The remaining batch entries are passed to the model as keyword
        # arguments; real_mel_lengths is removed first.
        del data["real_mel_lengths"]

        # tacotron2 inference (only postnet outputs and alignments are used).
        _, post_mel_outputs, _, alignment_historys = tacotron2(
            **data,
            use_window_mask=args.use_window_mask,
            win_front=args.win_front,
            win_back=args.win_back,
            training=True,
        )

        # convert to numpy
        alignment_historys = alignment_historys.numpy()
        post_mel_outputs = post_mel_outputs.numpy()

        for i, alignment in enumerate(alignment_historys):
            real_char_length = input_lengths[i].numpy()
            real_mel_length = real_mel_lengths[i].numpy()
            # Each alignment column covers `reduction_factor` mel frames.
            alignment_mel_length = int(np.ceil(real_mel_length / reduction_factor))
            alignment = alignment[:real_char_length, :alignment_mel_length]
            d = get_duration_from_alignment(alignment)  # [max_char_len]

            d = d * reduction_factor
            assert (
                np.sum(d) >= real_mel_length
            ), f"{d}, {np.sum(d)}, {alignment_mel_length}, {real_mel_length}"
            if np.sum(d) > real_mel_length:
                # Remove the frames added by rounding up to a whole reduction
                # step, preferring the last char, then the first, else both.
                rest = np.sum(d) - real_mel_length
                if d[-1] > rest:
                    d[-1] -= rest
                elif d[0] > rest:
                    d[0] -= rest
                else:
                    d[-1] -= rest // 2
                    d[0] -= rest - rest // 2

                assert d[-1] >= 0 and d[0] >= 0, f"{d}, {np.sum(d)}, {real_mel_length}"

            saved_name = utt_ids[i].decode("utf-8")

            # check a length compatible
            assert (
                len(d) == real_char_length
            ), f"different between len_char and len_durations, {len(d)} and {real_char_length}"

            assert (
                np.sum(d) == real_mel_length
            ), f"different between sum_durations and len_mel, {np.sum(d)} and {real_mel_length}"

            # save the postnet output, trimmed to the real mel length.
            np.save(
                os.path.join(outdpost, f"{saved_name}-postnet.npy"),
                post_mel_outputs[i, :real_mel_length].astype(np.float32),
                allow_pickle=False,
            )

            # save alignment to debug.
            if args.save_alignment == 1:
                figname = os.path.join(args.outdir, f"{saved_name}_alignment.png")
                fig = plt.figure(figsize=(8, 6))
                ax = fig.add_subplot(111)
                ax.set_title(f"Alignment of {saved_name}")
                im = ax.imshow(
                    alignment, aspect="auto", origin="lower", interpolation="none"
                )
                fig.colorbar(im, ax=ax)
                ax.set_xlabel("Decoder timestep")
                ax.set_ylabel("Encoder timestep")
                fig.tight_layout()
                fig.savefig(figname)
                plt.close(fig)
245
246
if __name__ == "__main__":
    # Start the postnet extraction only when this file is executed directly,
    # not when it is imported as a module.
    main()
248
249