Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/examples/fastspeech/fastspeech_dataset.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Dataset modules."""
16
17
import itertools
18
import logging
19
import os
20
import random
21
22
import numpy as np
23
import tensorflow as tf
24
25
from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
26
from tensorflow_tts.utils import find_files
27
28
29
class CharactorDurationMelDataset(AbstractDataset):
    """Tensorflow Charactor-Duration-Mel dataset for FastSpeech training.

    Yields per-utterance items (charactor ids, ground-truth durations,
    ground-truth mel spectrograms) loaded lazily from dumped ``.npy`` files.
    """

    def __init__(
        self,
        root_dir,
        charactor_query="*-ids.npy",
        mel_query="*-norm-feats.npy",
        duration_query="*-durations.npy",
        charactor_load_fn=np.load,
        mel_load_fn=np.load,
        duration_load_fn=np.load,
        mel_length_threshold=0,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            charactor_query (str): Query to find charactor files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            duration_query (str): Query to find duration files in root_dir.
            charactor_load_fn (func): Function to load charactor file.
            mel_load_fn (func): Function to load feature file.
            duration_load_fn (func): Function to load duration file.
            mel_length_threshold (int): Threshold to remove short feature files.

        Raises:
            AssertionError: If no mel files are found, or if the numbers of
                charactor, mel and duration files disagree.
            ValueError: If ``charactor_query`` does not target ``.npy`` files,
                so utterance ids cannot be derived.

        """
        # find all of charactor, mel and duration files.
        charactor_files = sorted(find_files(root_dir, charactor_query))
        mel_files = sorted(find_files(root_dir, mel_query))
        duration_files = sorted(find_files(root_dir, duration_query))

        # assert the number of files
        # FIX: the message previously read f"...${root_dir}." — the stray "$"
        # printed a literal dollar sign in front of the interpolated path.
        assert len(mel_files) != 0, f"Not found any mels files in {root_dir}."
        assert (
            len(mel_files) == len(charactor_files) == len(duration_files)
        ), f"Number of charactor, mel and duration files are different \
            ({len(mel_files)} vs {len(charactor_files)} vs {len(duration_files)})."

        if ".npy" in charactor_query:
            suffix = charactor_query[1:]  # e.g. "*-ids.npy" -> "-ids.npy"
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]
        else:
            # FIX: utt_ids was left undefined on this path, which raised a
            # confusing NameError below; fail early with a clear message.
            raise ValueError(
                "charactor_query must match .npy files so utterance ids can be derived."
            )

        # set global params
        self.utt_ids = utt_ids
        self.mel_files = mel_files
        self.charactor_files = charactor_files
        self.duration_files = duration_files
        self.mel_load_fn = mel_load_fn
        self.charactor_load_fn = charactor_load_fn
        self.duration_load_fn = duration_load_fn
        self.mel_length_threshold = mel_length_threshold

    def get_args(self):
        """Return the arguments passed to ``generator`` by ``create``."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield one dict of file paths (plus utt id) per utterance."""
        for i, utt_id in enumerate(utt_ids):
            mel_file = self.mel_files[i]
            charactor_file = self.charactor_files[i]
            duration_file = self.duration_files[i]

            items = {
                "utt_ids": utt_id,
                "mel_files": mel_file,
                "charactor_files": charactor_file,
                "duration_files": duration_file,
            }

            yield items

    @tf.function
    def _load_data(self, items):
        """Load the numpy arrays referenced by the file paths in ``items``.

        Runs inside the tf.data pipeline via ``tf.numpy_function``.
        """
        # FIX: previously hard-coded np.load here, silently ignoring the
        # mel_load_fn / charactor_load_fn / duration_load_fn passed to
        # __init__ (they were stored but never used).
        mel = tf.numpy_function(self.mel_load_fn, [items["mel_files"]], tf.float32)
        charactor = tf.numpy_function(
            self.charactor_load_fn, [items["charactor_files"]], tf.int32
        )
        duration = tf.numpy_function(
            self.duration_load_fn, [items["duration_files"]], tf.int32
        )

        items = {
            "utt_ids": items["utt_ids"],
            "input_ids": charactor,
            "speaker_ids": 0,  # single-speaker dataset: speaker id fixed to 0
            "duration_gts": duration,
            "mel_gts": mel,
            "mel_lengths": len(mel),  # number of mel frames (time axis first)
        }

        return items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache (bool): Whether to cache loaded items in memory.
            batch_size (int): Batch size for padded batching.
            is_shuffle (bool): Whether to shuffle the dataset each epoch.
            map_fn (func): Unused; kept for interface compatibility with
                sibling dataset classes.
            reshuffle_each_iteration (bool): Reshuffle on every epoch.

        Returns:
            tf.data.Dataset: Batched, padded, prefetched dataset.

        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        # load data in parallel
        datasets = datasets.map(
            lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
        )

        # drop utterances whose mel is too short
        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded_shapes ([] = scalar, [None] = variable-length axis)
        padded_shapes = {
            "utt_ids": [],
            "input_ids": [None],
            "speaker_ids": [],
            "duration_gts": [None],
            "mel_gts": [None, None],
            "mel_lengths": [],
        }

        datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return output dtypes of ``generator`` (paths and id as strings)."""
        output_types = {
            "utt_ids": tf.string,
            "mel_files": tf.string,
            "charactor_files": tf.string,
            "duration_files": tf.string,
        }
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances in the dataset."""
        return len(self.utt_ids)

    def __name__(self):
        return "CharactorDurationMelDataset"
178
179
180
class CharactorDataset(AbstractDataset):
    """Tensorflow Charactor dataset (input ids only, e.g. for inference)."""

    def __init__(
        self, root_dir, charactor_query="*-ids.npy", charactor_load_fn=np.load,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            charactor_query (str): Query to find charactor files in root_dir.
            charactor_load_fn (func): Function to load charactor file.

        Raises:
            AssertionError: If no charactor files are found in root_dir.
            ValueError: If ``charactor_query`` does not target ``.npy`` files,
                so utterance ids cannot be derived.

        """
        # find all of charactor files.
        charactor_files = sorted(find_files(root_dir, charactor_query))

        # assert the number of files
        # FIX: the message previously read f"...${root_dir}." — the stray "$"
        # printed a literal dollar sign in front of the interpolated path.
        assert (
            len(charactor_files) != 0
        ), f"Not found any char or duration files in {root_dir}."
        if ".npy" in charactor_query:
            suffix = charactor_query[1:]  # e.g. "*-ids.npy" -> "-ids.npy"
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]
        else:
            # FIX: utt_ids was left undefined on this path, which raised a
            # confusing NameError below; fail early with a clear message.
            raise ValueError(
                "charactor_query must match .npy files so utterance ids can be derived."
            )

        # set global params
        self.utt_ids = utt_ids
        self.charactor_files = charactor_files
        self.charactor_load_fn = charactor_load_fn

    def get_args(self):
        """Return the arguments passed to ``generator`` by ``create``."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield one dict of {utt id, charactor id array} per utterance.

        Note: unlike CharactorDurationMelDataset, arrays are loaded eagerly
        here (in the generator) rather than via ``tf.numpy_function``.
        """
        for i, utt_id in enumerate(utt_ids):
            charactor_file = self.charactor_files[i]
            charactor = self.charactor_load_fn(charactor_file)

            items = {"utt_ids": utt_id, "input_ids": charactor}

            yield items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache (bool): Whether to cache loaded items in memory.
            batch_size (int): Batch size for padded batching (incomplete
                final batches are dropped).
            is_shuffle (bool): Whether to shuffle the dataset each epoch.
            map_fn (func): Unused; kept for interface compatibility with
                sibling dataset classes.
            reshuffle_each_iteration (bool): Reshuffle on every epoch.

        Returns:
            tf.data.Dataset: Batched, padded, prefetched dataset.

        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded shapes ([] = scalar, [None] = variable-length axis)
        padded_shapes = {"utt_ids": [], "input_ids": [None]}

        datasets = datasets.padded_batch(
            batch_size, padded_shapes=padded_shapes, drop_remainder=True
        )
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return output dtypes of ``generator``."""
        output_types = {"utt_ids": tf.string, "input_ids": tf.int32}
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances in the dataset."""
        return len(self.utt_ids)

    def __name__(self):
        return "CharactorDataset"
264
265