# Path: examples/fastspeech2_libritts/fastspeech2_dataset.py
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset modules."""

import os

import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files


def average_by_duration(x, durs):
    """Average frame-level values over each charactor's duration span.

    Args:
        x (np.ndarray): 1-D frame-level values (e.g. raw f0 or energy).
        durs (np.ndarray): Integer frame durations, one entry per charactor.

    Returns:
        np.ndarray: Charactor-level averages, shape (len(durs),), float32.
    """
    # durs_cum[i] is the index of the first frame belonging to charactor i,
    # so (durs_cum[i], durs_cum[i + 1]) bounds charactor i's frame span.
    durs_cum = np.cumsum(np.pad(durs, (1, 0)))

    # calculate charactor f0/energy
    x_char = np.zeros((durs.shape[0],), dtype=np.float32)
    for idx, (start, end) in enumerate(zip(durs_cum[:-1], durs_cum[1:])):
        # Average only the non-zero frames of the span (zeros mark
        # unvoiced/invalid frames, see _norm_mean_std below).
        values = x[start:end][np.where(x[start:end] != 0.0)[0]]
        # Guard the empty case explicitly: np.mean([]) would yield NaN.
        x_char[idx] = np.mean(values) if len(values) > 0 else 0.0

    return x_char.astype(np.float32)


def tf_average_by_duration(x, durs):
    """Graph-compatible wrapper of average_by_duration via tf.numpy_function."""
    outs = tf.numpy_function(average_by_duration, [x, durs], tf.float32)
    return outs


class CharactorDurationF0EnergyMelDataset(AbstractDataset):
    """Tensorflow Charactor Duration F0 Energy Mel dataset."""

    def __init__(
        self,
        root_dir,
        charactor_query="*-ids.npy",
        mel_query="*-norm-feats.npy",
        duration_query="*-durations.npy",
        f0_query="*-raw-f0.npy",
        energy_query="*-raw-energy.npy",
        f0_stat="./dump/stats_f0.npy",
        energy_stat="./dump/stats_energy.npy",
        charactor_load_fn=np.load,
        mel_load_fn=np.load,
        duration_load_fn=np.load,
        f0_load_fn=np.load,
        energy_load_fn=np.load,
        mel_length_threshold=0,
        speakers_map=None,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            charactor_query (str): Query to find charactor files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            duration_query (str): Query to find duration files in root_dir.
            f0_query (str): Query to find f0 files in root_dir.
            energy_query (str): Query to find energy files in root_dir.
            f0_stat (str): str path of f0_stat.
            energy_stat (str): str path of energy_stat.
            charactor_load_fn (func): Function to load charactor file.
            mel_load_fn (func): Function to load feature file.
            duration_load_fn (func): Function to load duration file.
            f0_load_fn (func): Function to load f0 file.
            energy_load_fn (func): Function to load energy file.
            mel_length_threshold (int): Threshold to remove short feature files.
            speakers_map (dict): Speakers map generated in dataset preprocessing.

        """
        # find all of charactor and mel files.
        charactor_files = sorted(find_files(root_dir, charactor_query))
        mel_files = sorted(find_files(root_dir, mel_query))
        duration_files = sorted(find_files(root_dir, duration_query))
        f0_files = sorted(find_files(root_dir, f0_query))
        energy_files = sorted(find_files(root_dir, energy_query))

        # assert the number of files
        # (fixed: "${root_dir}" inside the f-string printed a literal "$")
        assert len(mel_files) != 0, f"Not found any mels files in {root_dir}."
        assert (
            len(mel_files)
            == len(charactor_files)
            == len(duration_files)
            == len(f0_files)
            == len(energy_files)
        ), "Number of charactor, mel, duration, f0 and energy files are different"

        assert (
            speakers_map is not None
        ), "No speakers map found. Did you set --dataset_mapping?"

        # Utterance id = file basename with the query suffix stripped.
        # NOTE(review): utt_ids is only bound when the query ends in ".npy";
        # a non-npy query would raise NameError below — TODO confirm intended.
        if ".npy" in charactor_query:
            suffix = charactor_query[1:]
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]

        # set global params
        self.utt_ids = utt_ids
        self.mel_files = mel_files
        self.charactor_files = charactor_files
        self.duration_files = duration_files
        self.f0_files = f0_files
        self.energy_files = energy_files
        self.mel_load_fn = mel_load_fn
        self.charactor_load_fn = charactor_load_fn
        self.duration_load_fn = duration_load_fn
        self.f0_load_fn = f0_load_fn
        self.energy_load_fn = energy_load_fn
        self.mel_length_threshold = mel_length_threshold
        self.speakers_map = speakers_map
        # Utterance ids look like "<speaker>_<rest>"; map the speaker prefix
        # to its integer id via the preprocessing-time speakers_map.
        self.speakers = [self.speakers_map[i.split("_")[0]] for i in self.utt_ids]
        print("Speaker: utt_id", list(zip(self.speakers, self.utt_ids)))
        # Stats arrays are indexed as [mean, std] in _load_data.
        self.f0_stat = np.load(f0_stat)
        self.energy_stat = np.load(energy_stat)

    def get_args(self):
        """Return generator arguments (the utterance-id list)."""
        return [self.utt_ids]

    def _norm_mean_std(self, x, mean, std):
        """Mean/std-normalize x, keeping exact zeros at zero.

        Zero entries mark unvoiced/invalid frames and must stay zero after
        normalization so downstream averaging can skip them.
        """
        zero_idxs = np.where(x == 0.0)[0]
        x = (x - mean) / std
        x[zero_idxs] = 0.0
        return x

    def _norm_mean_std_tf(self, x, mean, std):
        """Graph-compatible wrapper of _norm_mean_std."""
        x = tf.numpy_function(self._norm_mean_std, [x, mean, std], tf.float32)
        return x

    def generator(self, utt_ids):
        """Yield per-utterance dicts of file paths and speaker ids."""
        for i, utt_id in enumerate(utt_ids):
            mel_file = self.mel_files[i]
            charactor_file = self.charactor_files[i]
            duration_file = self.duration_files[i]
            f0_file = self.f0_files[i]
            energy_file = self.energy_files[i]
            speaker_id = self.speakers[i]

            items = {
                "utt_ids": utt_id,
                "mel_files": mel_file,
                "charactor_files": charactor_file,
                "duration_files": duration_file,
                "f0_files": f0_file,
                "energy_files": energy_file,
                "speaker_ids": speaker_id,
            }

            yield items

    @tf.function
    def _load_data(self, items):
        """Load .npy payloads for one utterance and build the training dict.

        f0/energy are mean/std-normalized (zeros preserved), then averaged
        per charactor using the duration alignment.
        """
        mel = tf.numpy_function(np.load, [items["mel_files"]], tf.float32)
        charactor = tf.numpy_function(np.load, [items["charactor_files"]], tf.int32)
        duration = tf.numpy_function(np.load, [items["duration_files"]], tf.int32)
        f0 = tf.numpy_function(np.load, [items["f0_files"]], tf.float32)
        energy = tf.numpy_function(np.load, [items["energy_files"]], tf.float32)

        f0 = self._norm_mean_std_tf(f0, self.f0_stat[0], self.f0_stat[1])
        energy = self._norm_mean_std_tf(
            energy, self.energy_stat[0], self.energy_stat[1]
        )

        # calculate charactor f0/energy
        f0 = tf_average_by_duration(f0, duration)
        energy = tf_average_by_duration(energy, duration)

        items = {
            "utt_ids": items["utt_ids"],
            "input_ids": charactor,
            "speaker_ids": items["speaker_ids"],
            "duration_gts": duration,
            "f0_gts": f0,
            "energy_gts": energy,
            "mel_gts": mel,
            "mel_lengths": len(mel),
        }

        return items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function."""
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        # load data
        datasets = datasets.map(
            lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
        )

        # drop utterances whose mel is too short to be useful
        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded shapes
        padded_shapes = {
            "utt_ids": [],
            "input_ids": [None],
            "speaker_ids": [],
            "duration_gts": [None],
            "f0_gts": [None],
            "energy_gts": [None],
            "mel_gts": [None, None],
            "mel_lengths": [],
        }

        datasets = datasets.padded_batch(
            batch_size, padded_shapes=padded_shapes, drop_remainder=True
        )
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return the dtype dict matching the generator's yielded items."""
        output_types = {
            "utt_ids": tf.string,
            "mel_files": tf.string,
            "charactor_files": tf.string,
            "duration_files": tf.string,
            "f0_files": tf.string,
            "energy_files": tf.string,
            "speaker_ids": tf.int32,
        }
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances."""
        return len(self.utt_ids)

    def __name__(self):
        return "CharactorDurationF0EnergyMelDataset"