Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/tacotron2/extract_duration.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Extract durations based on Tacotron-2 alignments for FastSpeech."""
16
17
import argparse
18
import logging
19
import os
20
from numba import jit
21
import sys
22
23
sys.path.append(".")
24
25
import matplotlib.pyplot as plt
26
import numpy as np
27
import tensorflow as tf
28
import yaml
29
from tqdm import tqdm
30
31
from examples.tacotron2.tacotron_dataset import CharactorMelDataset
32
from tensorflow_tts.configs import Tacotron2Config
33
from tensorflow_tts.models import TFTacotron2
34
35
36
@jit(nopython=True)
def get_duration_from_alignment(alignment):
    """Count how many decoder steps attend most strongly to each encoder step.

    Args:
        alignment: 2-D attention matrix indexed as ``alignment[char, frame]``
            (rows are encoder/character steps, columns are decoder steps).

    Returns:
        1-D int64 array ``D`` of length ``alignment.shape[0]``; ``D[c]`` is the
        number of decoder columns whose maximum attention weight falls on row
        ``c``. Ties resolve to the first (lowest) row index.
    """
    num_chars = alignment.shape[0]
    num_frames = alignment.shape[1]
    durations = np.zeros(num_chars, dtype=np.int64)
    for frame in range(num_frames):
        # np.argmax returns the first index of the maximum, exactly like the
        # former list(...).index(max) lookup, but without materializing a
        # Python list and scanning it twice for every decoder step.
        durations[np.argmax(alignment[:, frame])] += 1
    return durations
45
46
47
def _parse_args():
    """Build the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Extract durations from characters with trained Tacotron-2 "
        "(See detail in examples/tacotron2/extract_duration.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/mel-feature files.",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save extracted durations.",
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm",
        default=1,
        type=int,
        help="use norm-mels (1) or raw-mels (0).",
    )
    parser.add_argument("--batch-size", default=8, type=int, help="batch size.")
    parser.add_argument("--win-front", default=2, type=int, help="win-front.")
    parser.add_argument("--win-back", default=2, type=int, help="win-back.")
    parser.add_argument(
        "--use-window-mask", default=1, type=int, help="toggle window masking."
    )
    parser.add_argument("--save-alignment", default=0, type=int, help="save-alignment.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    return parser.parse_args()


def _setup_logging(verbose):
    """Configure root logging: >1 -> DEBUG, 1 -> INFO, otherwise WARN."""
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(level=level, format=log_format)
    if level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")


def _fit_durations_to_length(d, real_mel_length):
    """Trim ``d`` (in place) so that its sum equals ``real_mel_length``.

    Durations are counted in reduced decoder steps over ``ceil(real / r)``
    columns and then multiplied by the reduction factor ``r``, so their sum can
    overshoot the real number of mel frames. The surplus is removed from the
    last token if it is long enough, else from the first token, else split
    between the two. Returns ``d`` for convenience.
    """
    total = np.sum(d)
    if total > real_mel_length:
        rest = total - real_mel_length
        if d[-1] > rest:
            d[-1] -= rest
        elif d[0] > rest:
            d[0] -= rest
        else:
            d[-1] -= rest // 2
            d[0] -= rest - rest // 2
    return d


def _save_alignment_plot(alignment, saved_name, outdir):
    """Write a PNG heat-map of ``alignment`` to ``outdir`` for debugging."""
    figname = os.path.join(outdir, f"{saved_name}_alignment.png")
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)
    ax.set_title(f"Alignment of {saved_name}")
    im = ax.imshow(alignment, aspect="auto", origin="lower", interpolation="none")
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Decoder timestep")
    ax.set_ylabel("Encoder timestep")
    fig.tight_layout()
    fig.savefig(figname)
    # Close this specific figure so memory does not grow across utterances.
    plt.close(fig)


def main():
    """Running extract tacotron-2 durations.

    Loads a trained Tacotron-2 checkpoint, runs it over the dataset under
    ``--rootdir``, converts each attention alignment to per-character
    durations (in mel frames) and saves them as ``<utt_id>-durations.npy``
    in ``--outdir``; optionally also saves an alignment plot per utterance.
    """
    args = _parse_args()
    _setup_logging(args.verbose)

    # check directory existence
    os.makedirs(args.outdir, exist_ok=True)

    # load config
    with open(args.config) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only pass trusted config files here (yaml.SafeLoader is safer).
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] != "npy":
        raise ValueError("Only npy is supported.")
    char_query = "*-ids.npy"
    # --use-norm is parsed as an int, so test its truthiness. The previous
    # `args.use_norm is False` comparison could never be true, which made
    # raw features impossible to select.
    mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
    char_load_fn = np.load
    mel_load_fn = np.load

    reduction_factor = config["tacotron2_params"]["reduction_factor"]

    # define data-loader
    dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.rootdir,
        charactor_query=char_query,
        mel_query=mel_query,
        charactor_load_fn=char_load_fn,
        mel_load_fn=mel_load_fn,
        reduction_factor=reduction_factor,
        use_fixed_shapes=True,
    )
    dataset = dataset.create(
        allow_cache=True, batch_size=args.batch_size, drop_remainder=False
    )

    # define model and load checkpoint
    tacotron2 = TFTacotron2(
        config=Tacotron2Config(**config["tacotron2_params"]),
        name="tacotron2",
    )
    tacotron2._build()  # build model to be able load_weights.
    tacotron2.load_weights(args.checkpoint)

    # apply tf.function for tacotron2.
    tacotron2 = tf.function(tacotron2, experimental_relax_shapes=True)

    for data in tqdm(dataset, desc="[Extract Duration]"):
        utt_ids = data["utt_ids"].numpy()
        input_lengths = data["input_lengths"]
        real_mel_lengths = data["real_mel_lengths"]
        # Removed so it is not forwarded to tacotron2 through **data below.
        del data["real_mel_lengths"]

        # tacotron2 inference; training=True is kept from the original script
        # (presumably to run the decoder on the ground-truth mels — confirm).
        _, _, _, alignment_historys = tacotron2(
            **data,
            use_window_mask=args.use_window_mask,
            win_front=args.win_front,
            win_back=args.win_back,
            training=True,
        )

        # convert to numpy
        alignment_historys = alignment_historys.numpy()

        for i, alignment in enumerate(alignment_historys):
            real_char_length = input_lengths[i].numpy()
            real_mel_length = real_mel_lengths[i].numpy()
            # The decoder emits `reduction_factor` frames per step, so the
            # alignment only has ceil(real / r) meaningful columns.
            alignment_mel_length = int(np.ceil(real_mel_length / reduction_factor))
            alignment = alignment[:real_char_length, :alignment_mel_length]
            d = get_duration_from_alignment(alignment)  # [real_char_length]

            d = d * reduction_factor
            assert (
                np.sum(d) >= real_mel_length
            ), f"{d}, {np.sum(d)}, {alignment_mel_length}, {real_mel_length}"
            d = _fit_durations_to_length(d, real_mel_length)
            assert d[-1] >= 0 and d[0] >= 0, f"{d}, {np.sum(d)}, {real_mel_length}"

            saved_name = utt_ids[i].decode("utf-8")

            # check a length compatible
            assert (
                len(d) == real_char_length
            ), f"different between len_char and len_durations, {len(d)} and {real_char_length}"

            assert (
                np.sum(d) == real_mel_length
            ), f"different between sum_durations and len_mel, {np.sum(d)} and {real_mel_length}"

            # save D to folder.
            np.save(
                os.path.join(args.outdir, f"{saved_name}-durations.npy"),
                d.astype(np.int32),
                allow_pickle=False,
            )

            # save alignment to debug.
            if args.save_alignment == 1:
                _save_alignment_plot(alignment, saved_name, args.outdir)


if __name__ == "__main__":
    main()
237
238