Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/tacotron2/export_align.py
1558 views
1
import os
2
import shutil
3
from tqdm import tqdm
4
import argparse
5
6
from scipy.ndimage import zoom
7
from skimage.data import camera
8
import numpy as np
9
from scipy.spatial.distance import cdist
10
11
12
def safemkdir(dirn):
    """Create directory ``dirn`` if it does not already exist.

    Missing parent directories are created as well, and an already existing
    directory is left untouched.

    Args:
        dirn: Path of the directory to create.

    Raises:
        FileExistsError: If ``dirn`` exists but is not a directory.
    """
    # exist_ok avoids the race between a separate isdir() check and mkdir().
    os.makedirs(dirn, exist_ok=True)
15
16
17
from pathlib import Path
18
19
20
def duration_to_alignment(in_duration):
    """Expand per-row durations into a hard 0/1 alignment matrix.

    Row ``i`` of the result holds ones in a contiguous run of
    ``in_duration[i]`` columns; the runs are laid out one after another so
    every column contains exactly one 1.

    Args:
        in_duration: 1-D sequence of non-negative integer durations.

    Returns:
        float32 array of shape ``(len(in_duration), sum(in_duration))``.
        An empty input yields an array of shape ``(0, 0)``.

    Raises:
        ValueError: If any duration is negative.
    """
    durations = np.asarray(in_duration, dtype=np.int64)
    num_chars = len(durations)
    # Force a Python int: summing an empty float sequence yields 0.0, which
    # np.zeros would reject as a dimension.
    total_len = int(durations.sum())

    attention = np.zeros(shape=(num_chars, total_len), dtype=np.float32)

    # Column j belongs to the row whose duration run covers it; repeating each
    # row index by its duration gives that mapping for all columns at once.
    row_of_column = np.repeat(np.arange(num_chars), durations)
    attention[row_of_column, np.arange(total_len)] = 1.0

    return attention
34
35
36
def rescale_alignment(in_alignment, in_targcharlen):
    """Stretch an alignment along its first axis and snap it back to 0/1.

    Args:
        in_alignment: 2-D alignment array.
        in_targcharlen: Desired size of the first axis.

    Returns:
        Tuple ``(zoomed, pivot_points)``: the rescaled, binarized array and
        the ``(row, col)`` positions that were set to 1.0, in row-major
        order. The positions are collected before any explicit resize.
    """
    source_rows = in_alignment.shape[0]
    row_scale = in_targcharlen / source_rows

    zoomed = zoom(in_alignment, (row_scale, 1.0), mode="nearest")

    # Interpolation leaves fractional values; everything under 0.5 becomes 0
    # and everything else becomes 1.
    below_half = zoomed < 0.5
    pivot_points = [tuple(cell) for cell in np.argwhere(~below_half).tolist()]
    zoomed[below_half] = 0.0
    zoomed[~below_half] = 1.0

    if zoomed.shape[0] != in_targcharlen:
        print("Zooming didn't rshape well, explicitly reshaping")
        zoomed.resize((in_targcharlen, in_alignment.shape[1]))

    return zoomed, pivot_points
59
60
61
def gather_dist(in_mtr, in_points):
    """Compute Euclidean distances from every cell of a matrix to every point.

    Args:
        in_mtr: 2-D array; only its shape is used.
        in_points: Sequence of ``(row, col)`` points.

    Returns:
        Array of shape ``(rows * cols, len(in_points))`` where row ``k``
        corresponds to cell ``(k // cols, k % cols)``.
    """
    n_rows, n_cols = in_mtr.shape[0], in_mtr.shape[1]
    # Row-major list of every (row, col) coordinate in the matrix.
    cell_coords = [(r, c) for r in range(n_rows) for c in range(n_cols)]
    return cdist(cell_coords, in_points, "euclidean")
71
72
73
def create_guided(in_align, in_pvt, looseness):
    """Build a guided-attention map from pivot points.

    Each cell receives the squared, looseness-scaled distance to its nearest
    pivot point, clipped to ``[0, 1]``.

    Args:
        in_align: 2-D alignment array; only its shape is used here.
        in_pvt: Sequence of ``(row, col)`` pivot points.
        looseness: Percentage-style looseness factor; lower values give a
            tighter map.

    Returns:
        float32 array with the same shape as ``in_align``.
    """
    # Gathering all distances in one call is dramatically faster than
    # measuring them cell by cell.
    dist_arr = gather_dist(in_align, in_pvt)

    # Scale looseness by the attention size (sum of both dimensions). The
    # division by 100 lets the user pass values such as 3.35.
    n_rows, n_cols = in_align.shape[0], in_align.shape[1]
    real_loose = (looseness / 100) * (n_rows + n_cols)

    # Rows of dist_arr follow row-major cell order, so the per-cell minimum
    # reshapes straight back into the alignment's shape.
    nearest = dist_arr.min(axis=1)
    penalty = np.power(nearest / real_loose, 2)
    new_att = penalty.reshape(in_align.shape).astype(np.float32)

    return np.clip(new_att, 0.0, 1.0)
95
96
97
def get_pivot_points(in_att):
    """List the ``(row, col)`` positions whose value exceeds 0.8.

    Args:
        in_att: 2-D attention/alignment array.

    Returns:
        List of ``(row, col)`` tuples in row-major order.
    """
    hit_rows, hit_cols = np.nonzero(in_att > 0.8)
    return list(zip(hit_rows.tolist(), hit_cols.tolist()))
104
105
106
def main():
    """Turn dumped duration files into guided-attention alignment files.

    For each of the ``train`` and ``valid`` subfolders of ``--dump-dir``,
    every ``.npy`` file in ``durations/`` is loaded together with its
    counterpart in ``ids/``, converted to a guided attention map, and saved
    as float32 under ``alignments/``.
    """
    parser = argparse.ArgumentParser(
        description="Postprocess durations to become alignments"
    )
    parser.add_argument(
        "--dump-dir",
        default="dump",
        type=str,
        help="Path of dump directory",
    )
    parser.add_argument(
        "--looseness",
        default=3.5,
        type=float,
        help="Looseness of the generated guided attention map. Lower values = tighter",
    )
    args = parser.parse_args()
    dump_dir = args.dump_dir
    dump_sets = ["train", "valid"]

    for d_set in dump_sets:
        full_fol = os.path.join(dump_dir, d_set)
        align_path = os.path.join(full_fol, "alignments")

        ids_path = os.path.join(full_fol, "ids")
        durations_path = os.path.join(full_fol, "durations")

        safemkdir(align_path)

        for duration_fn in tqdm(os.listdir(durations_path)):
            # Match on the extension only; a substring test would also pick
            # up names such as "x.npy.bak".
            if not duration_fn.endswith(".npy"):
                continue

            id_fn = duration_fn.replace("-durations", "-ids")

            id_path = os.path.join(ids_path, id_fn)
            duration_path = os.path.join(durations_path, duration_fn)

            duration_arr = np.load(duration_path)
            id_arr = np.load(id_path)

            id_true_size = len(id_arr)

            align = duration_to_alignment(duration_arr)

            # The duration count may differ from the id count; stretch the
            # alignment to the id length in that case.
            if align.shape[0] != id_true_size:
                align, points = rescale_alignment(align, id_true_size)
            else:
                points = get_pivot_points(align)

            if len(points) == 0:
                print("WARNING points are empty for", id_fn)
                # The distance computation cannot handle an empty point list
                # and would abort the whole export, so skip this utterance.
                continue

            align = create_guided(align, points, args.looseness)

            align_fn = id_fn.replace("-ids", "-alignment")
            align_full_fn = os.path.join(align_path, align_fn)

            np.save(align_full_fn, align.astype("float32"))
165
166
167
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
169
170