Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/datasets/abstract_dataset.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 Minh Nguyen (@dathudeptrai)
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Abstract Dataset modules."""
16
17
import abc
18
19
import tensorflow as tf
20
21
22
class AbstractDataset(metaclass=abc.ABCMeta):
    """Abstract base for dataset loaders that are exposed as a ``tf.data.Dataset``.

    Subclasses implement the four abstract hooks below; :meth:`create` then
    assembles them into a (optionally cached, shuffled, mapped) batched and
    prefetched ``tf.data.Dataset``.
    """

    @abc.abstractmethod
    def get_args(self):
        """Return the args forwarded to the generator function.

        The returned value is passed as ``args=`` to
        ``tf.data.Dataset.from_generator`` in :meth:`create`.
        """

    @abc.abstractmethod
    def generator(self):
        """Generator function; should accept the args from ``get_args``.

        This callable is handed to ``tf.data.Dataset.from_generator`` in
        :meth:`create`.
        """

    @abc.abstractmethod
    def get_output_dtypes(self):
        """Return the output dtypes for each element produced by the generator.

        Used as ``output_types=`` for ``tf.data.Dataset.from_generator``.
        """

    @abc.abstractmethod
    def get_len_dataset(self):
        """Return the number of samples in the dataset.

        Used as the shuffle buffer size when shuffling is enabled.
        """

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Build the ``tf.data.Dataset`` pipeline for this dataset.

        Pipeline order: generator -> cache (optional) -> shuffle (optional)
        -> map (optional) -> batch -> prefetch.

        Args:
            allow_cache: If true, call ``.cache()`` on the generator dataset
                before any shuffling/mapping.
            batch_size: Number of elements per batch passed to ``.batch()``.
            is_shuffle: If true, shuffle with a buffer size equal to
                ``get_len_dataset()``.
            map_fn: Optional function applied to each element (before
                batching) via ``.map()`` with ``AUTOTUNE``. Required when
                ``batch_size > 1``.
            reshuffle_each_iteration: Forwarded to ``.shuffle()``; only used
                when ``is_shuffle`` is true.

        Returns:
            The batched and prefetched dataset.

        Raises:
            ValueError: If ``batch_size > 1`` and ``map_fn`` is ``None``.
        """
        # Validate arguments first so that an invalid call fails before any
        # subclass hook is invoked or any dataset op is constructed.
        # NOTE(review): presumably a map function is required for batching
        # because raw elements need processing before they can be stacked;
        # confirm against the concrete subclasses.
        if batch_size > 1 and map_fn is None:
            raise ValueError("map function must define when batch_size > 1.")

        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            # Buffer covers the whole dataset, as reported by the subclass.
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        if map_fn is not None:
            datasets = datasets.map(map_fn, tf.data.experimental.AUTOTUNE)

        datasets = datasets.batch(batch_size)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)

        return datasets
78
79