Path: blob/master/tensorflow_tts/datasets/mel_dataset.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset modules."""

import logging
import os

import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files


class MelDataset(AbstractDataset):
    """Tensorflow compatible mel dataset.

    Collects feature files under a root directory and exposes them as a
    padded, batched ``tf.data.Dataset`` whose elements are dictionaries with
    the keys ``"utt_ids"``, ``"mels"`` and ``"mel_lengths"``.
    """

    def __init__(
        self,
        root_dir,
        mel_query="*-raw-feats.h5",
        mel_load_fn=np.load,
        mel_length_threshold=0,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            mel_query (str): Query to find feature files in root_dir. The
                utterance id of a file is its basename with ``mel_query[1:]``
                removed, i.e. the query is expected to be a single leading
                wildcard character followed by the file suffix.
            mel_load_fn (func): Function to load feature file. Its result must
                expose ``.shape``; ``shape[0]`` is used as the feature length.
            mel_length_threshold (int): Threshold to remove short feature
                files. Only items whose length is strictly greater than this
                value survive the filter applied in ``create``.

        Raises:
            AssertionError: If no file under ``root_dir`` matches ``mel_query``.

        """
        # find all of mel files.
        mel_files = sorted(find_files(root_dir, mel_query))

        # assert the number of files (checked before any file is loaded).
        assert len(mel_files) != 0, f"Not found any mel files in {root_dir}."

        # NOTE: every matching file is loaded once here just to read its length.
        mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]

        # Derive utterance ids for every query, not only ".npy" ones; previously
        # any other query (including the default) left `utt_ids` undefined and
        # raised NameError a few lines below.
        suffix = mel_query[1:]
        utt_ids = [os.path.basename(f).replace(suffix, "") for f in mel_files]

        # set global params
        self.utt_ids = utt_ids
        self.mel_files = mel_files
        self.mel_lengths = mel_lengths
        self.mel_load_fn = mel_load_fn
        self.mel_length_threshold = mel_length_threshold

    def get_args(self):
        """Return the argument list forwarded to ``generator`` by ``create``."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield one item dictionary per utterance id.

        Args:
            utt_ids: Iterable of utterance ids. Files and lengths are looked up
                by position, so it must be ordered like ``self.mel_files``.

        Yields:
            dict: ``{"utt_ids": ..., "mels": ..., "mel_lengths": ...}`` where
            ``"mels"`` is the result of ``self.mel_load_fn`` on the file.
        """
        for i, utt_id in enumerate(utt_ids):
            mel_file = self.mel_files[i]
            mel = self.mel_load_fn(mel_file)
            mel_length = self.mel_lengths[i]

            items = {"utt_ids": utt_id, "mels": mel, "mel_lengths": mel_length}

            yield items

    def get_output_dtypes(self):
        """Return the tf dtype of each key produced by ``generator``."""
        output_types = {
            "utt_ids": tf.string,
            "mels": tf.float32,
            "mel_lengths": tf.int32,
        }
        return output_types

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache (bool): Cache the filtered dataset when True.
            batch_size (int): Batch size passed to ``padded_batch``.
            is_shuffle (bool): Shuffle with a buffer covering the whole dataset
                when True.
            map_fn: Accepted for interface compatibility; not used by this
                method.
            reshuffle_each_iteration (bool): Forwarded to ``Dataset.shuffle``.

        Returns:
            tf.data.Dataset: Filtered, optionally cached/shuffled, padded,
            batched and prefetched dataset.
        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=self.get_args()
        )

        # drop items that are not strictly longer than the threshold.
        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded shapes
        # "mels" is padded along its first axis; the second axis is fixed at 80
        # (presumably the number of mel bins -- confirm against the features).
        padded_shapes = {
            "utt_ids": [],
            "mels": [None, 80],
            "mel_lengths": [],
        }

        datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_len_dataset(self):
        """Return the number of utterances found at construction time."""
        return len(self.utt_ids)

    def __name__(self):
        """Return the dataset name (defined as a method, so it must be called)."""
        return "MelDataset"