Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/tacotron2/tacotron_dataset.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tacotron Related Dataset modules."""

import itertools
import logging
import os
import random

import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files
class CharactorMelDataset(AbstractDataset):
    """Tensorflow Charactor Mel dataset.

    Pairs dumped charactor-id files with mel-spectrogram files found under a
    root directory and exposes them as a padded, batched ``tf.data.Dataset``
    for Tacotron-2 training. When forced-alignment (FAL) files are available
    they are loaded as attention targets; otherwise a guided-attention matrix
    is computed on the fly.
    """

    def __init__(
        self,
        dataset,
        root_dir,
        charactor_query="*-ids.npy",
        mel_query="*-norm-feats.npy",
        align_query="",
        charactor_load_fn=np.load,
        mel_load_fn=np.load,
        mel_length_threshold=0,
        reduction_factor=1,
        mel_pad_value=0.0,
        char_pad_value=0,
        ga_pad_value=-1.0,
        g=0.2,
        use_fixed_shapes=False,
    ):
        """Initialize dataset.

        Args:
            dataset: Dataset name (kept for interface compatibility; unused here).
            root_dir (str): Root directory including dumped files.
            charactor_query (str): Query to find charactor files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            align_query (str): Query to find FAL files in root_dir. If empty,
                the stock guided attention loss is used instead.
            charactor_load_fn (func): Function to load charactor file.
            mel_load_fn (func): Function to load feature file.
            mel_length_threshold (int): Threshold to remove short feature files.
            reduction_factor (int): Reduction factor on Tacotron-2 paper.
            mel_pad_value (float): Padding value for mel-spectrogram.
            char_pad_value (int): Padding value for charactor.
            ga_pad_value (float): Padding value for guided attention.
            g (float): G value for guided attention.
            use_fixed_shapes (bool): Use fixed shape for mel targets or not.
        """
        # find all of charactor and mel files.
        charactor_files = sorted(find_files(root_dir, charactor_query))
        mel_files = sorted(find_files(root_dir, mel_query))

        mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
        char_lengths = [charactor_load_fn(f).shape[0] for f in charactor_files]

        # assert the number of files
        # BUGFIX: removed stray "$" from the original "${root_dir}" message.
        assert len(mel_files) != 0, f"Not found any mels files in {root_dir}."
        assert (
            len(mel_files) == len(charactor_files) == len(mel_lengths)
        ), f"Number of charactor, mel and duration files are different \
            ({len(mel_files)} vs {len(charactor_files)} vs {len(mel_lengths)})."

        self.align_files = []

        # BUGFIX: the original gated on "len(align_query) > 1" and (elsewhere)
        # "len(self.align_files) > 1" while create() used "< 1"; with exactly
        # one align file neither FAL loading nor guided attention ran, leaving
        # g_attentions as None. All checks are now consistently "> 0" / "== 0".
        if len(align_query) > 0:
            align_files = sorted(find_files(root_dir, align_query))
            assert len(align_files) == len(
                mel_files
            ), f"Number of align files ({len(align_files)}) and mel files ({len(mel_files)}) are different"
            logging.info("Using FAL loss")
            self.align_files = align_files
        else:
            logging.info("Using guided attention loss")

        # Derive utterance ids from the charactor filenames.
        if ".npy" in charactor_query:
            suffix = charactor_query[1:]
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]
        else:
            # BUGFIX: the original left utt_ids undefined (NameError) for
            # non-.npy queries; fall back to stripping the extension.
            utt_ids = [
                os.path.splitext(os.path.basename(f))[0] for f in charactor_files
            ]

        # set global params
        self.utt_ids = utt_ids
        self.mel_files = mel_files
        self.charactor_files = charactor_files
        self.mel_load_fn = mel_load_fn
        self.charactor_load_fn = charactor_load_fn
        self.mel_lengths = mel_lengths
        self.char_lengths = char_lengths
        self.reduction_factor = reduction_factor
        self.mel_length_threshold = mel_length_threshold
        self.mel_pad_value = mel_pad_value
        self.char_pad_value = char_pad_value
        self.ga_pad_value = ga_pad_value
        self.g = g
        self.use_fixed_shapes = use_fixed_shapes
        self.max_char_length = np.max(char_lengths)

        # Round the maximum mel length up to a multiple of the reduction
        # factor so fixed-shape batches stay frame-aligned.
        max_mel = np.max(mel_lengths)
        if max_mel % self.reduction_factor == 0:
            self.max_mel_length = max_mel
        else:
            self.max_mel_length = (
                max_mel + self.reduction_factor - max_mel % self.reduction_factor
            )

    def get_args(self):
        """Return the argument list passed to :meth:`generator`."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield one metadata dict (utt id and file paths) per utterance."""
        for i, utt_id in enumerate(utt_ids):
            mel_file = self.mel_files[i]
            charactor_file = self.charactor_files[i]
            # BUGFIX: was "> 1", which silently dropped the alignment when
            # the dataset contained exactly one file.
            align_file = self.align_files[i] if len(self.align_files) > 0 else ""

            items = {
                "utt_ids": utt_id,
                "mel_files": mel_file,
                "charactor_files": charactor_file,
                "align_files": align_file,
            }

            yield items

    @tf.function
    def _load_data(self, items):
        """Load mel, charactor (and optional alignment) arrays from disk.

        Pads the mel so its length is a multiple of the reduction factor and
        returns the training-example dict. ``g_attentions`` is None when no
        FAL files are configured; :meth:`_guided_attention` fills it later.
        """
        mel = tf.numpy_function(np.load, [items["mel_files"]], tf.float32)
        charactor = tf.numpy_function(np.load, [items["charactor_files"]], tf.int32)
        # BUGFIX: was "> 1" — see generator(); now consistent with create().
        g_att = (
            tf.numpy_function(np.load, [items["align_files"]], tf.float32)
            if len(self.align_files) > 0
            else None
        )

        mel_length = len(mel)
        char_length = len(charactor)
        # padding mel to make its length a multiple of reduction factor.
        real_mel_length = mel_length
        remainder = mel_length % self.reduction_factor
        if remainder != 0:
            new_mel_length = mel_length + self.reduction_factor - remainder
            mel = tf.pad(
                mel,
                [[0, new_mel_length - mel_length], [0, 0]],
                constant_values=self.mel_pad_value,
            )
            mel_length = new_mel_length

        items = {
            "utt_ids": items["utt_ids"],
            "input_ids": charactor,
            "input_lengths": char_length,
            "speaker_ids": 0,
            "mel_gts": mel,
            "mel_lengths": mel_length,
            "real_mel_lengths": real_mel_length,
            "g_attentions": g_att,
        }

        return items

    def _guided_attention(self, items):
        """Guided attention. Refer to page 3 on the paper (https://arxiv.org/abs/1710.08969)."""
        items = items.copy()
        # Attention operates on reduced frames, hence the division.
        mel_len = items["mel_lengths"] // self.reduction_factor
        char_len = items["input_lengths"]
        xv, yv = tf.meshgrid(tf.range(char_len), tf.range(mel_len), indexing="ij")
        f32_matrix = tf.cast(yv / mel_len - xv / char_len, tf.float32)
        # 1 - exp(-(t/T - n/N)^2 / (2 g^2)): near zero on the diagonal,
        # approaching one far from it.
        items["g_attentions"] = 1.0 - tf.math.exp(
            -(f32_matrix ** 2) / (2 * self.g ** 2)
        )
        return items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
        drop_remainder=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache (bool): Cache loaded examples in memory.
            batch_size (int): Batch size for padded batching.
            is_shuffle (bool): Shuffle the full dataset each epoch.
            map_fn: Kept for interface compatibility; unused here.
            reshuffle_each_iteration (bool): Reshuffle on every epoch.
            drop_remainder (bool): Drop the final partial batch.

        Returns:
            tf.data.Dataset yielding padded example dicts.
        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        # load data
        datasets = datasets.map(
            lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
        )

        # calculate guided attention (only when no FAL files were found).
        if len(self.align_files) == 0:
            datasets = datasets.map(
                lambda items: self._guided_attention(items),
                tf.data.experimental.AUTOTUNE,
            )

        # drop utterances whose mel is too short.
        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padding value.
        padding_values = {
            "utt_ids": " ",
            "input_ids": self.char_pad_value,
            "input_lengths": 0,
            "speaker_ids": 0,
            "mel_gts": self.mel_pad_value,
            "mel_lengths": 0,
            "real_mel_lengths": 0,
            "g_attentions": self.ga_pad_value,
        }

        # define padded shapes.
        padded_shapes = {
            "utt_ids": [],
            "input_ids": [None]
            if self.use_fixed_shapes is False
            else [self.max_char_length],
            "input_lengths": [],
            "speaker_ids": [],
            "mel_gts": [None, 80]
            if self.use_fixed_shapes is False
            else [self.max_mel_length, 80],
            "mel_lengths": [],
            "real_mel_lengths": [],
            "g_attentions": [None, None]
            if self.use_fixed_shapes is False
            else [self.max_char_length, self.max_mel_length // self.reduction_factor],
        }

        datasets = datasets.padded_batch(
            batch_size,
            padded_shapes=padded_shapes,
            padding_values=padding_values,
            drop_remainder=drop_remainder,
        )
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return the dtype dict for :meth:`generator`'s yielded items."""
        output_types = {
            "utt_ids": tf.string,
            "mel_files": tf.string,
            "charactor_files": tf.string,
            "align_files": tf.string,
        }
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances found under root_dir."""
        return len(self.utt_ids)

    def __name__(self):
        return "CharactorMelDataset"