# Path: blob/master/tensorflow_tts/datasets/abstract_dataset.py
# (code-viewer residue: 1558 views)
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract Dataset modules."""

import abc

import tensorflow as tf


class AbstractDataset(metaclass=abc.ABCMeta):
    """Abstract Dataset module for Dataset Loader.

    Subclasses implement the four abstract hooks below; ``create`` then
    assembles them into a ``tf.data.Dataset`` pipeline.
    """

    @abc.abstractmethod
    def get_args(self):
        """Return args for generator function.

        The returned value is forwarded unchanged as the ``args`` argument
        of ``tf.data.Dataset.from_generator`` in ``create``.
        """

    @abc.abstractmethod
    def generator(self):
        """Generator function, should have args from get_args function."""

    @abc.abstractmethod
    def get_output_dtypes(self):
        """Return output dtypes for each element from generator."""

    @abc.abstractmethod
    def get_len_dataset(self):
        """Return number of samples on dataset.

        ``create`` uses this value as the shuffle buffer size.
        """

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        The pipeline is assembled in this order: generator -> cache
        (optional) -> shuffle (optional) -> map (optional) -> batch ->
        prefetch.

        Args:
            allow_cache: If true, cache the generated elements.
            batch_size: Number of elements per batch.
            is_shuffle: If true, shuffle with a buffer of
                ``get_len_dataset()`` elements.
            map_fn: Optional function applied to every element before
                batching. Required when ``batch_size > 1``.
            reshuffle_each_iteration: Forwarded to ``Dataset.shuffle``;
                only used when ``is_shuffle`` is true.

        Returns:
            The batched and prefetched ``tf.data.Dataset``.

        Raises:
            ValueError: If ``batch_size > 1`` and ``map_fn`` is None.
        """
        # Validate the configuration first so that a bad call fails before
        # any pipeline stage is constructed.
        if batch_size > 1 and map_fn is None:
            raise ValueError("map function must define when batch_size > 1.")

        datasets = tf.data.Dataset.from_generator(
            self.generator,
            output_types=self.get_output_dtypes(),
            args=self.get_args(),
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            # Buffer as large as the dataset itself.
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        if map_fn is not None:
            datasets = datasets.map(
                map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
            )

        datasets = datasets.batch(batch_size)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)

        return datasets