# -*- coding: utf-8 -*-
# Source: TensorSpeech/TensorFlowTTS — examples/melgan/audio_mel_dataset.py
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset modules."""

import logging
import os

import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files

class AudioMelDataset(AbstractDataset):
    """Tensorflow Audio Mel dataset.

    Pairs dumped audio (waveform) files with their mel-spectrogram feature
    files found under a root directory and exposes them as a padded, batched
    ``tf.data.Dataset``.
    """

    def __init__(
        self,
        root_dir,
        audio_query="*-wave.npy",
        mel_query="*-raw-feats.npy",
        audio_load_fn=np.load,
        mel_load_fn=np.load,
        audio_length_threshold=0,
        mel_length_threshold=0,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            audio_query (str): Query to find audio files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            audio_load_fn (func): Function to load audio file.
                Must return a float32 array (consumed via tf.numpy_function).
            mel_load_fn (func): Function to load feature file.
                Must return a float32 array (consumed via tf.numpy_function).
            audio_length_threshold (int): Threshold to remove short audio files.
            mel_length_threshold (int): Threshold to remove short feature files.

        Raises:
            AssertionError: If no audio file is found, or if the number of
                audio and mel files differ.
        """
        # find all of audio and mel files.
        audio_files = sorted(find_files(root_dir, audio_query))
        mel_files = sorted(find_files(root_dir, mel_query))

        # assert the number of files
        # BUGFIX: the message used "${root_dir}" (shell-style), which rendered
        # a literal "$" in the f-string output.
        assert len(audio_files) != 0, f"Not found any audio files in {root_dir}."
        assert len(audio_files) == len(
            mel_files
        ), f"Number of audio and mel files are different ({len(audio_files)} vs {len(mel_files)})."

        # Derive utterance ids by stripping the query suffix from basenames.
        # BUGFIX: utt_ids was previously assigned only when ".npy" appeared in
        # audio_query, raising NameError for any other query; fall back to the
        # extension-stripped basename in that case.
        if ".npy" in audio_query:
            suffix = audio_query[1:]  # drop the leading "*" of the glob
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in audio_files]
        else:
            utt_ids = [os.path.splitext(os.path.basename(f))[0] for f in audio_files]

        # set global params
        self.utt_ids = utt_ids
        self.audio_files = audio_files
        self.mel_files = mel_files
        self.audio_load_fn = audio_load_fn
        self.mel_load_fn = mel_load_fn
        self.audio_length_threshold = audio_length_threshold
        self.mel_length_threshold = mel_length_threshold

    def get_args(self):
        """Return the generator arguments for tf.data.Dataset.from_generator."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield one dict per utterance with its id and file paths.

        Args:
            utt_ids (list): Utterance ids; indices must align with
                self.audio_files / self.mel_files (they do by construction,
                since all three are built from the same sorted file list).
        """
        for i, utt_id in enumerate(utt_ids):
            audio_file = self.audio_files[i]
            mel_file = self.mel_files[i]

            items = {
                "utt_ids": utt_id,
                "audio_files": audio_file,
                "mel_files": mel_file,
            }

            yield items

    @tf.function
    def _load_data(self, items):
        """Load audio/mel arrays from file paths inside the tf graph.

        Args:
            items (dict): {"utt_ids", "audio_files", "mel_files"} tensors.

        Returns:
            dict: utt_ids plus loaded "audios"/"mels" tensors and their lengths.
        """
        # BUGFIX: previously hard-coded np.load, silently ignoring the
        # audio_load_fn / mel_load_fn the caller configured in __init__.
        audio = tf.numpy_function(self.audio_load_fn, [items["audio_files"]], tf.float32)
        mel = tf.numpy_function(self.mel_load_fn, [items["mel_files"]], tf.float32)

        items = {
            "utt_ids": items["utt_ids"],
            "audios": audio,
            "mels": mel,
            # Use tf.shape rather than len(): numpy_function outputs have
            # unknown static shape, for which len() is not graph-safe.
            "mel_lengths": tf.shape(mel)[0],
            "audio_lengths": tf.shape(audio)[0],
        }

        return items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache (bool): Whether to cache loaded examples in memory.
            batch_size (int): Batch size; batch_size > 1 requires map_fn.
            is_shuffle (bool): Whether to shuffle the full dataset each epoch.
            map_fn (func): Optional per-example mapping applied before batching.
            reshuffle_each_iteration (bool): Reshuffle on every epoch.

        Returns:
            tf.data.Dataset: Padded, batched dataset of example dicts.

        Raises:
            ValueError: If batch_size > 1 and map_fn is None.
        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )
        # Disable auto-sharding: the generator is not shardable across workers.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
        datasets = datasets.with_options(options)
        # load dataset
        datasets = datasets.map(
            lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
        )

        # Drop examples at or below the configured length thresholds.
        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )
        datasets = datasets.filter(
            lambda x: x["audio_lengths"] > self.audio_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        if batch_size > 1 and map_fn is None:
            raise ValueError("map function must define when batch_size > 1.")

        if map_fn is not None:
            datasets = datasets.map(map_fn, tf.data.experimental.AUTOTUNE)

        # define padded shapes
        # NOTE(review): mels are padded to 80 bins — assumes 80-dim features;
        # confirm against the dumping config.
        padded_shapes = {
            "utt_ids": [],
            "audios": [None],
            "mels": [None, 80],
            "mel_lengths": [],
            "audio_lengths": [],
        }

        # define padded values
        padding_values = {
            "utt_ids": "",
            "audios": 0.0,
            "mels": 0.0,
            "mel_lengths": 0,
            "audio_lengths": 0,
        }

        datasets = datasets.padded_batch(
            batch_size,
            padded_shapes=padded_shapes,
            padding_values=padding_values,
            drop_remainder=True,
        )
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return the output dtypes of generator() for from_generator."""
        output_types = {
            "utt_ids": tf.string,
            "audio_files": tf.string,
            "mel_files": tf.string,
        }
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances in the dataset."""
        return len(self.utt_ids)

    def __name__(self):
        return "AudioMelDataset"