Path: blob/master/examples/melgan/audio_mel_dataset.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset modules."""

import logging
import os

import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files


class AudioMelDataset(AbstractDataset):
    """Tensorflow Audio Mel dataset.

    Pairs dumped audio files with dumped mel-feature files found under a root
    directory and exposes them as a padded, batched ``tf.data.Dataset`` whose
    elements are dicts with the keys ``utt_ids``, ``audios``, ``mels``,
    ``mel_lengths`` and ``audio_lengths``.
    """

    def __init__(
        self,
        root_dir,
        audio_query="*-wave.npy",
        mel_query="*-raw-feats.npy",
        audio_load_fn=np.load,
        mel_load_fn=np.load,
        audio_length_threshold=0,
        mel_length_threshold=0,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            audio_query (str): Query to find audio files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            audio_load_fn (func): Function to load audio file.
            mel_load_fn (func): Function to load feature file.
            audio_length_threshold (int): Threshold to remove short audio files.
            mel_length_threshold (int): Threshold to remove short feature files.

        Raises:
            AssertionError: If no audio file is found, or if the number of
                audio files differs from the number of mel files.
        """
        # find all of audio and mel files.
        # Both lists are sorted so that index i of one is paired with index i
        # of the other (presumably the file names share a common utterance-id
        # prefix -- verify against the dump layout).
        audio_files = sorted(find_files(root_dir, audio_query))
        mel_files = sorted(find_files(root_dir, mel_query))

        # assert the number of files
        assert len(audio_files) != 0, f"Not found any audio files in {root_dir}."
        assert len(audio_files) == len(
            mel_files
        ), f"Number of audio and mel files are different ({len(audio_files)} vs {len(mel_files)})."

        if ".npy" in audio_query:
            # Drop the first character of the query (the "*" in the default
            # query) and strip the remaining suffix from each file name.
            suffix = audio_query[1:]
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in audio_files]
        else:
            # Previously `utt_ids` was left unbound for non-".npy" queries and
            # the assignment below raised; fall back to the file name without
            # its extension.
            utt_ids = [
                os.path.splitext(os.path.basename(f))[0] for f in audio_files
            ]

        # set global params
        self.utt_ids = utt_ids
        self.audio_files = audio_files
        self.mel_files = mel_files
        self.audio_load_fn = audio_load_fn
        self.mel_load_fn = mel_load_fn
        self.audio_length_threshold = audio_length_threshold
        self.mel_length_threshold = mel_length_threshold

    def get_args(self):
        """Return the argument list that ``create`` forwards to ``generator``."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield one dict of utterance id and file paths per utterance.

        Args:
            utt_ids (iterable): Utterance ids; the position of each id is used
                to index ``self.audio_files`` and ``self.mel_files``.

        Yields:
            dict: ``{"utt_ids", "audio_files", "mel_files"}`` for one utterance.
        """
        for i, utt_id in enumerate(utt_ids):
            audio_file = self.audio_files[i]
            mel_file = self.mel_files[i]

            items = {
                "utt_ids": utt_id,
                "audio_files": audio_file,
                "mel_files": mel_file,
            }

            yield items

    @tf.function
    def _load_data(self, items):
        """Load the audio and mel arrays referenced by one generator item.

        Args:
            items (dict): Element produced by ``generator`` with the keys
                ``utt_ids``, ``audio_files`` and ``mel_files``.

        Returns:
            dict: ``utt_ids``, the loaded ``audios`` and ``mels`` (both
            requested as ``tf.float32``) and their first-axis lengths.
        """
        # Use the loaders configured in __init__ (np.load by default) rather
        # than a hard-coded np.load, so custom loaders are actually honoured.
        # NOTE: tf.numpy_function hands string tensors to the callable as
        # bytes objects, so a custom loader must accept a bytes path.
        audio = tf.numpy_function(
            self.audio_load_fn, [items["audio_files"]], tf.float32
        )
        mel = tf.numpy_function(self.mel_load_fn, [items["mel_files"]], tf.float32)

        # len() on tensors here relies on tf.function's AutoGraph conversion.
        items = {
            "utt_ids": items["utt_ids"],
            "audios": audio,
            "mels": mel,
            "mel_lengths": len(mel),
            "audio_lengths": len(audio),
        }

        return items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache (bool): Whether to cache the loaded, filtered elements.
            batch_size (int): Batch size passed to ``padded_batch``.
            is_shuffle (bool): Whether to shuffle, using the dataset length as
                the shuffle buffer size.
            map_fn (func): Optional function mapped over each element before
                batching; required when ``batch_size > 1``.
            reshuffle_each_iteration (bool): Forwarded to ``Dataset.shuffle``.

        Returns:
            tf.data.Dataset: Padded, batched and prefetched dataset.

        Raises:
            ValueError: If ``batch_size > 1`` and ``map_fn`` is None.
        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )
        # Turn off automatic sharding for this dataset.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.OFF
        )
        datasets = datasets.with_options(options)
        # load dataset
        datasets = datasets.map(
            lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
        )

        # Keep only elements strictly longer than the configured thresholds.
        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )
        datasets = datasets.filter(
            lambda x: x["audio_lengths"] > self.audio_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        if batch_size > 1 and map_fn is None:
            raise ValueError("map function must define when batch_size > 1.")

        if map_fn is not None:
            datasets = datasets.map(map_fn, tf.data.experimental.AUTOTUNE)

        # define padded shapes
        # The second mel dimension is fixed to 80 by this padded shape.
        padded_shapes = {
            "utt_ids": [],
            "audios": [None],
            "mels": [None, 80],
            "mel_lengths": [],
            "audio_lengths": [],
        }

        # define padded values
        padding_values = {
            "utt_ids": "",
            "audios": 0.0,
            "mels": 0.0,
            "mel_lengths": 0,
            "audio_lengths": 0,
        }

        # drop_remainder=True discards a final batch smaller than batch_size.
        datasets = datasets.padded_batch(
            batch_size,
            padded_shapes=padded_shapes,
            padding_values=padding_values,
            drop_remainder=True,
        )
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return the dtypes of the dict elements yielded by ``generator``."""
        output_types = {
            "utt_ids": tf.string,
            "audio_files": tf.string,
            "mel_files": tf.string,
        }
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances found at construction time."""
        return len(self.utt_ids)

    def __name__(self):
        """Return the dataset name."""
        return "AudioMelDataset"