Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/datasets/mel_dataset.py
1558 views
1
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset modules."""
16
17
import logging
18
import os
19
20
import numpy as np
21
import tensorflow as tf
22
23
from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
24
from tensorflow_tts.utils import find_files
25
26
27
class MelDataset(AbstractDataset):
    """Tensorflow compatible mel dataset.

    Wraps a directory of dumped mel-spectrogram feature files as a
    ``tf.data.Dataset`` yielding dicts with keys ``"utt_ids"``, ``"mels"``
    and ``"mel_lengths"``.
    """

    def __init__(
        self,
        root_dir,
        mel_query="*-raw-feats.h5",
        mel_load_fn=np.load,
        mel_length_threshold=0,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            mel_query (str): Glob query to find feature files in root_dir.
                Expected to start with ``"*"`` (e.g. ``"*-raw-feats.h5"``);
                the remainder is stripped from basenames to form utt ids.
            mel_load_fn (func): Function to load a feature file; the loaded
                array's first axis is treated as the frame (length) axis.
            mel_length_threshold (int): Threshold to remove short feature files.

        Raises:
            AssertionError: If no file in root_dir matches mel_query.

        """
        # find all of mel files.
        mel_files = sorted(find_files(root_dir, mel_query))

        # assert the number of files
        # FIX: the message used "${root_dir}" — f-strings interpolate with
        # {...}, so the "$" was emitted literally.
        assert len(mel_files) != 0, f"Not found any mel files in {root_dir}."

        # NOTE: this loads every file once up front (O(dataset) init cost);
        # the lengths are needed to filter short utterances and to size
        # shuffle() in create().
        mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]

        # FIX: utt_ids was only computed when ".npy" appeared in mel_query,
        # leaving it undefined (NameError) for the default h5 query. The
        # suffix-strip works for any "*<suffix>" pattern, so do it always;
        # the result for "*.npy" queries is unchanged.
        suffix = mel_query[1:]
        utt_ids = [os.path.basename(f).replace(suffix, "") for f in mel_files]

        # set global params
        self.utt_ids = utt_ids
        self.mel_files = mel_files
        self.mel_lengths = mel_lengths
        self.mel_load_fn = mel_load_fn
        self.mel_length_threshold = mel_length_threshold

    def get_args(self):
        """Return the generator's argument list (the utterance ids)."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield one item dict per utterance id, in order.

        Args:
            utt_ids: Iterable of utterance ids; must parallel
                ``self.mel_files`` / ``self.mel_lengths`` by index.

        Yields:
            dict: ``{"utt_ids": str, "mels": ndarray, "mel_lengths": int}``.

        """
        for i, utt_id in enumerate(utt_ids):
            # Mels are loaded lazily here (per epoch) rather than cached,
            # unless the tf.data pipeline itself caches (see create()).
            mel = self.mel_load_fn(self.mel_files[i])
            items = {
                "utt_ids": utt_id,
                "mels": mel,
                "mel_lengths": self.mel_lengths[i],
            }
            yield items

    def get_output_dtypes(self):
        """Return the tf dtype of each field yielded by generator()."""
        output_types = {
            "utt_ids": tf.string,
            "mels": tf.float32,
            "mel_lengths": tf.int32,
        }
        return output_types

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create a batched, padded ``tf.data.Dataset`` over this dataset.

        Args:
            allow_cache (bool): Cache elements in memory after first pass.
            batch_size (int): Batch size for padded_batch.
            is_shuffle (bool): Shuffle over the whole dataset each epoch.
            map_fn (func): NOTE(review): accepted but currently never
                applied to the pipeline — confirm whether callers rely on
                it before wiring it in.
            reshuffle_each_iteration (bool): Reshuffle on every epoch.

        Returns:
            tf.data.Dataset: Batches of padded ``{"utt_ids", "mels",
            "mel_lengths"}`` tensors.

        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        # Drop utterances at or below the length threshold.
        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            # Buffer covers the full dataset, i.e. a perfect shuffle.
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded shapes
        # NOTE(review): 80 hard-codes the number of mel bins — confirm it
        # matches the feature-extraction config before reusing elsewhere.
        padded_shapes = {
            "utt_ids": [],
            "mels": [None, 80],
            "mel_lengths": [],
        }

        datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_len_dataset(self):
        """Return the number of utterances found at init time."""
        return len(self.utt_ids)

    def __name__(self):
        return "MelDataset"