Path: blob/master/tensorflow_tts/datasets/audio_dataset.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Audio modules."""

import logging
import os

import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files


class AudioDataset(AbstractDataset):
    """Tensorflow compatible audio dataset."""

    def __init__(
        self,
        root_dir,
        audio_query="*-wave.npy",
        audio_load_fn=np.load,
        audio_length_threshold=0,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            audio_query (str): Query to find feature files in root_dir.
            audio_load_fn (func): Function to load a feature file; must return
                an array whose first axis is the audio length.
            audio_length_threshold (int): Threshold to remove short feature files.

        Raises:
            AssertionError: If no file in root_dir matches audio_query.

        """
        # find all of audio files matching the query.
        audio_files = sorted(find_files(root_dir, audio_query))

        # assert the number of files BEFORE the expensive per-file load below.
        # (BUGFIX: the old message was f"...${root_dir}" — the shell-style
        # "$" inside an f-string printed a stray literal "$" — and said
        # "mel files" in an audio dataset.)
        assert len(audio_files) != 0, f"Not found any audio files in {root_dir}."

        # length of the first axis of each dumped array, used by create()
        # to filter out clips shorter than audio_length_threshold.
        audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]

        # derive utterance ids from the file names.
        if ".npy" in audio_query:
            # e.g. "*-wave.npy" -> strip the "-wave.npy" suffix.
            suffix = audio_query[1:]
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in audio_files]
        else:
            # BUGFIX: utt_ids was previously left undefined for non-.npy
            # queries, raising NameError below. Fall back to the file stem.
            utt_ids = [os.path.splitext(os.path.basename(f))[0] for f in audio_files]

        # set global params
        self.utt_ids = utt_ids
        self.audio_files = audio_files
        self.audio_lengths = audio_lengths
        self.audio_load_fn = audio_load_fn
        self.audio_length_threshold = audio_length_threshold

    def get_args(self):
        """Return the argument list handed to generator() by tf.data."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield one example dict per utterance.

        Args:
            utt_ids (list): Utterance ids selecting which files to yield.

        Yields:
            dict: {"utt_ids": str, "audios": ndarray, "audio_lengths": int}.

        """
        for i, utt_id in enumerate(utt_ids):
            audio_file = self.audio_files[i]
            # load lazily so only iterated examples touch the disk.
            audio = self.audio_load_fn(audio_file)
            audio_length = self.audio_lengths[i]

            items = {"utt_ids": utt_id, "audios": audio, "audio_lengths": audio_length}

            yield items

    def get_output_dtypes(self):
        """Return the tf dtypes of each generator output field."""
        output_types = {
            "utt_ids": tf.string,
            "audios": tf.float32,
            "audio_lengths": tf.float32,
        }
        return output_types

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache (bool): Whether to cache elements after first pass.
            batch_size (int): Batch size for padded batching.
            is_shuffle (bool): Whether to shuffle the full dataset each epoch.
            map_fn (func): Unused here; kept for interface compatibility
                with sibling datasets.
            reshuffle_each_iteration (bool): Forwarded to Dataset.shuffle.

        Returns:
            tf.data.Dataset: Padded-batched, prefetched dataset of dicts.

        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=self.get_args()
        )

        # drop clips at or below the configured length threshold.
        datasets = datasets.filter(
            lambda x: x["audio_lengths"] > self.audio_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            # buffer covers the whole dataset -> a full uniform shuffle.
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # pad variable-length audio to the longest clip in each batch;
        # scalar fields need no padding.
        padded_shapes = {
            "utt_ids": [],
            "audios": [None],
            "audio_lengths": [],
        }

        datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_len_dataset(self):
        """Return the number of utterances found at init time."""
        return len(self.utt_ids)

    def __name__(self):
        return "AudioDataset"