Path: blob/master/examples/fastspeech/fastspeech_dataset.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset modules."""

import itertools
import logging
import os
import random

import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files


class CharactorDurationMelDataset(AbstractDataset):
    """Tensorflow Charactor/Duration/Mel dataset.

    Yields batches with keys: "utt_ids", "input_ids", "speaker_ids",
    "duration_gts", "mel_gts", "mel_lengths".
    """

    def __init__(
        self,
        root_dir,
        charactor_query="*-ids.npy",
        mel_query="*-norm-feats.npy",
        duration_query="*-durations.npy",
        charactor_load_fn=np.load,
        mel_load_fn=np.load,
        duration_load_fn=np.load,
        mel_length_threshold=0,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            charactor_query (str): Query to find charactor files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            duration_query (str): Query to find duration files in root_dir.
            charactor_load_fn (func): Function to load charactor file.
            mel_load_fn (func): Function to load feature file.
            duration_load_fn (func): Function to load duration file.
            mel_length_threshold (int): Threshold to remove short feature files.

        """
        # find all of charactor, mel and duration files.
        charactor_files = sorted(find_files(root_dir, charactor_query))
        mel_files = sorted(find_files(root_dir, mel_query))
        duration_files = sorted(find_files(root_dir, duration_query))

        # assert the number of files
        # BUG FIX: messages used bash-style "${root_dir}" inside an f-string,
        # which printed a literal "$"; plain {root_dir} is the f-string form.
        assert len(mel_files) != 0, f"Not found any mels files in {root_dir}."
        assert (
            len(mel_files) == len(charactor_files) == len(duration_files)
        ), f"Number of charactor, mel and duration files are different \
({len(mel_files)} vs {len(charactor_files)} vs {len(duration_files)})."

        # NOTE(review): if charactor_query does not contain ".npy", utt_ids is
        # never bound and the assignment below raises NameError — presumably
        # all supported queries are *.npy; confirm against callers.
        if ".npy" in charactor_query:
            suffix = charactor_query[1:]
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]

        # set global params
        self.utt_ids = utt_ids
        self.mel_files = mel_files
        self.charactor_files = charactor_files
        self.duration_files = duration_files
        self.mel_load_fn = mel_load_fn
        self.charactor_load_fn = charactor_load_fn
        self.duration_load_fn = duration_load_fn
        self.mel_length_threshold = mel_length_threshold

    def get_args(self):
        """Return arguments passed to the generator (the utterance ids)."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield per-utterance dicts of file paths; arrays are loaded later in _load_data."""
        for i, utt_id in enumerate(utt_ids):
            mel_file = self.mel_files[i]
            charactor_file = self.charactor_files[i]
            duration_file = self.duration_files[i]

            items = {
                "utt_ids": utt_id,
                "mel_files": mel_file,
                "charactor_files": charactor_file,
                "duration_files": duration_file,
            }

            yield items

    @tf.function
    def _load_data(self, items):
        """Load arrays from the file paths produced by generator().

        Returns a dict with the charactor ids, duration targets, mel targets
        and the mel length (first dimension of the mel array).
        """
        # BUG FIX: these calls hard-coded np.load, silently ignoring the
        # *_load_fn functions passed to __init__; use the configured loaders.
        mel = tf.numpy_function(self.mel_load_fn, [items["mel_files"]], tf.float32)
        charactor = tf.numpy_function(
            self.charactor_load_fn, [items["charactor_files"]], tf.int32
        )
        duration = tf.numpy_function(
            self.duration_load_fn, [items["duration_files"]], tf.int32
        )

        items = {
            "utt_ids": items["utt_ids"],
            "input_ids": charactor,
            "speaker_ids": 0,
            "duration_gts": duration,
            "mel_gts": mel,
            # explicit tf.shape instead of len(); len() only worked via
            # AutoGraph's implicit conversion inside tf.function.
            "mel_lengths": tf.shape(mel)[0],
        }

        return items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create a batched, padded tf.data.Dataset over the dumped files."""
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=self.get_args()
        )

        # load data (file paths -> arrays), in parallel.
        datasets = datasets.map(
            lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
        )

        # drop utterances whose mel is too short.
        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded_shapes ([] = scalar, [None] = pad along that axis).
        padded_shapes = {
            "utt_ids": [],
            "input_ids": [None],
            "speaker_ids": [],
            "duration_gts": [None],
            "mel_gts": [None, None],
            "mel_lengths": [],
        }

        datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return dtypes of the generator() output (all file-path strings plus id)."""
        output_types = {
            "utt_ids": tf.string,
            "mel_files": tf.string,
            "charactor_files": tf.string,
            "duration_files": tf.string,
        }
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances in the dataset."""
        return len(self.utt_ids)

    def __name__(self):
        return "CharactorDurationMelDataset"


class CharactorDataset(AbstractDataset):
    """Tensorflow Charactor dataset (inference-time inputs only)."""

    def __init__(
        self, root_dir, charactor_query="*-ids.npy", charactor_load_fn=np.load,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            charactor_query (str): Query to find charactor files in root_dir.
            charactor_load_fn (func): Function to load charactor file.

        """
        # find all of charactor files.
        charactor_files = sorted(find_files(root_dir, charactor_query))

        # assert the number of files
        # BUG FIX: message used bash-style "${root_dir}" inside an f-string.
        assert (
            len(charactor_files) != 0
        ), f"Not found any char or duration files in {root_dir}."
        if ".npy" in charactor_query:
            suffix = charactor_query[1:]
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]

        # set global params
        self.utt_ids = utt_ids
        self.charactor_files = charactor_files
        self.charactor_load_fn = charactor_load_fn

    def get_args(self):
        """Return arguments passed to the generator (the utterance ids)."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield dicts of utterance id and loaded charactor-id arrays."""
        for i, utt_id in enumerate(utt_ids):
            charactor_file = self.charactor_files[i]
            charactor = self.charactor_load_fn(charactor_file)

            items = {"utt_ids": utt_id, "input_ids": charactor}

            yield items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create a batched, padded tf.data.Dataset of charactor ids."""
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=self.get_args()
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded shapes ([] = scalar, [None] = pad along that axis).
        padded_shapes = {"utt_ids": [], "input_ids": [None]}

        datasets = datasets.padded_batch(
            batch_size, padded_shapes=padded_shapes, drop_remainder=True
        )
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return dtypes of the generator() output."""
        output_types = {"utt_ids": tf.string, "input_ids": tf.int32}
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances in the dataset."""
        return len(self.utt_ids)

    def __name__(self):
        return "CharactorDataset"