Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/datasets/audio_dataset.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Audio modules."""
16
17
import logging
18
import os
19
20
import numpy as np
21
import tensorflow as tf
22
23
from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
24
from tensorflow_tts.utils import find_files
25
26
27
class AudioDataset(AbstractDataset):
    """Tensorflow compatible audio dataset."""

    def __init__(
        self,
        root_dir,
        audio_query="*-wave.npy",
        audio_load_fn=np.load,
        audio_length_threshold=0,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            audio_query (str): Query to find audio files in root_dir.
            audio_load_fn (func): Function to load an audio file.
            audio_length_threshold (int): Threshold to remove short audio files;
                only items whose length is strictly greater are kept by `create`.

        Raises:
            AssertionError: If no file under root_dir matches audio_query.

        """
        # find all of the audio files.
        audio_files = sorted(find_files(root_dir, audio_query))

        # assert the number of files before anything is loaded.
        assert len(audio_files) != 0, f"Not found any audio files in {root_dir}."

        # every file is loaded once here only to record its first-axis length.
        audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]

        if ".npy" in audio_query:
            # drop the leading character of the query (the "*" wildcard in the
            # default query) and strip the remaining suffix from each basename.
            suffix = audio_query[1:]
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in audio_files]
        else:
            # previously `utt_ids` was left undefined for non-.npy queries and the
            # constructor crashed; fall back to the basename without extension.
            utt_ids = [
                os.path.splitext(os.path.basename(f))[0] for f in audio_files
            ]

        # set global params
        self.utt_ids = utt_ids
        self.audio_files = audio_files
        self.audio_lengths = audio_lengths
        self.audio_load_fn = audio_load_fn
        self.audio_length_threshold = audio_length_threshold

    def get_args(self):
        """Return the argument list handed to `generator` by `create`."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield one item dict per utterance id.

        Args:
            utt_ids (iterable): Utterance ids; the position of each id is used
                to index `self.audio_files` and `self.audio_lengths`.

        Yields:
            dict: Keys "utt_ids", "audios" and "audio_lengths".

        """
        for i, utt_id in enumerate(utt_ids):
            # the audio is loaded lazily, at iteration time.
            audio_file = self.audio_files[i]
            audio = self.audio_load_fn(audio_file)
            audio_length = self.audio_lengths[i]

            items = {"utt_ids": utt_id, "audios": audio, "audio_lengths": audio_length}

            yield items

    def get_output_dtypes(self):
        """Return the tf dtype of every key yielded by `generator`."""
        output_types = {
            "utt_ids": tf.string,
            "audios": tf.float32,
            "audio_lengths": tf.float32,
        }
        return output_types

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache (bool): Whether to cache the filtered dataset.
            batch_size (int): Number of items per padded batch.
            is_shuffle (bool): Whether to shuffle the dataset.
            map_fn (func): Accepted for interface compatibility; it is not
                used by this method.
            reshuffle_each_iteration (bool): Forwarded to `Dataset.shuffle`.

        Returns:
            tf.data.Dataset: Padded, batched and prefetched dataset.

        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        # keep only the items strictly longer than the threshold.
        datasets = datasets.filter(
            lambda x: x["audio_lengths"] > self.audio_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            # the buffer size is the number of utterances found at init time,
            # i.e. the count before the length filter above is applied.
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded shapes
        padded_shapes = {
            "utt_ids": [],
            "audios": [None],
            "audio_lengths": [],
        }

        datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_len_dataset(self):
        """Return the number of utterance ids collected at init time."""
        return len(self.utt_ids)

    def __name__(self):
        return "AudioDataset"
130