# Source file: examples/fastspeech2/fastspeech2_dataset.py
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset modules."""

import itertools
import logging
import os
import random

import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files


def average_by_duration(x, durs):
    """Average frame-level values over each charactor's duration span.

    Args:
        x (np.ndarray): Frame-level values (e.g. f0 or energy) of length
            at least ``durs.sum()``.
        durs (np.ndarray): Integer durations (frame counts) per charactor.

    Returns:
        np.ndarray: Charactor-level means, shape ``(len(durs),)``, float32.
            Zero-valued frames are excluded from each mean; a span with no
            non-zero frames yields 0.0 (guards against np.mean([]) == nan).
    """
    mel_len = durs.sum()
    # durs_cum[i] is the index of the first frame belonging to charactor i.
    durs_cum = np.cumsum(np.pad(durs, (1, 0)))

    # calculate charactor f0/energy
    x_char = np.zeros((durs.shape[0],), dtype=np.float32)
    for idx, start, end in zip(range(mel_len), durs_cum[:-1], durs_cum[1:]):
        values = x[start:end][np.where(x[start:end] != 0.0)[0]]
        x_char[idx] = np.mean(values) if len(values) > 0 else 0.0  # np.mean([]) = nan.

    return x_char.astype(np.float32)


def tf_average_by_duration(x, durs):
    """Graph-compatible wrapper around :func:`average_by_duration`.

    Uses ``tf.numpy_function`` so the numpy implementation can run inside a
    ``tf.data`` pipeline / ``tf.function``.
    """
    outs = tf.numpy_function(average_by_duration, [x, durs], tf.float32)
    return outs


class CharactorDurationF0EnergyMelDataset(AbstractDataset):
    """Tensorflow Charactor Duration F0 Energy Mel dataset."""

    def __init__(
        self,
        root_dir,
        charactor_query="*-ids.npy",
        mel_query="*-norm-feats.npy",
        duration_query="*-durations.npy",
        f0_query="*-raw-f0.npy",
        energy_query="*-raw-energy.npy",
        f0_stat="./dump/stats_f0.npy",
        energy_stat="./dump/stats_energy.npy",
        charactor_load_fn=np.load,
        mel_load_fn=np.load,
        duration_load_fn=np.load,
        f0_load_fn=np.load,
        energy_load_fn=np.load,
        mel_length_threshold=0,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            charactor_query (str): Query to find charactor files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            duration_query (str): Query to find duration files in root_dir.
            f0_query (str): Query to find f0 files in root_dir.
            energy_query (str): Query to find energy files in root_dir.
            f0_stat (str): str path of f0_stat.
            energy_stat (str): str path of energy_stat.
            charactor_load_fn (func): Function to load charactor file.
            mel_load_fn (func): Function to load feature file.
            duration_load_fn (func): Function to load duration file.
            f0_load_fn (func): Function to load f0 file.
            energy_load_fn (func): Function to load energy file.
            mel_length_threshold (int): Threshold to remove short feature files.

        Raises:
            ValueError: If charactor_query does not target ``.npy`` files
                (utterance ids are derived from the ``.npy`` filenames).

        """
        # find all of charactor and mel files.
        charactor_files = sorted(find_files(root_dir, charactor_query))
        mel_files = sorted(find_files(root_dir, mel_query))
        duration_files = sorted(find_files(root_dir, duration_query))
        f0_files = sorted(find_files(root_dir, f0_query))
        energy_files = sorted(find_files(root_dir, energy_query))

        # assert the number of files
        # NOTE: original message interpolated "${root_dir}" — the stray "$"
        # is dropped here since f-strings use plain "{root_dir}".
        assert len(mel_files) != 0, f"Not found any mel files in {root_dir}."
        assert (
            len(mel_files)
            == len(charactor_files)
            == len(duration_files)
            == len(f0_files)
            == len(energy_files)
        ), (
            "Number of charactor, mel, duration, f0 and energy files are different: "
            f"charactor={len(charactor_files)}, mel={len(mel_files)}, "
            f"duration={len(duration_files)}, f0={len(f0_files)}, "
            f"energy={len(energy_files)}."
        )

        if ".npy" in charactor_query:
            suffix = charactor_query[1:]
            utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]
        else:
            # Previously this silently fell through and crashed later with a
            # NameError on utt_ids; fail fast with an explicit message instead.
            raise ValueError(
                f"charactor_query must match .npy files, got {charactor_query!r}."
            )

        # set global params
        self.utt_ids = utt_ids
        self.mel_files = mel_files
        self.charactor_files = charactor_files
        self.duration_files = duration_files
        self.f0_files = f0_files
        self.energy_files = energy_files
        self.mel_load_fn = mel_load_fn
        self.charactor_load_fn = charactor_load_fn
        self.duration_load_fn = duration_load_fn
        self.f0_load_fn = f0_load_fn
        self.energy_load_fn = energy_load_fn
        self.mel_length_threshold = mel_length_threshold

        # [mean, std] arrays used to normalize raw f0/energy in _load_data.
        self.f0_stat = np.load(f0_stat)
        self.energy_stat = np.load(energy_stat)

    def get_args(self):
        """Return generator arguments (the list of utterance ids)."""
        return [self.utt_ids]

    def _norm_mean_std(self, x, mean, std):
        """Mean/std normalize x, keeping exact zeros (unvoiced frames) at zero."""
        zero_idxs = np.where(x == 0.0)[0]
        x = (x - mean) / std
        x[zero_idxs] = 0.0
        return x

    def _norm_mean_std_tf(self, x, mean, std):
        """Graph-compatible wrapper around :meth:`_norm_mean_std`."""
        x = tf.numpy_function(self._norm_mean_std, [x, mean, std], tf.float32)
        return x

    def generator(self, utt_ids):
        """Yield one dict of file paths (and utt id) per utterance.

        Relies on all file lists being sorted identically so index i refers
        to the same utterance in every list.
        """
        for i, utt_id in enumerate(utt_ids):
            mel_file = self.mel_files[i]
            charactor_file = self.charactor_files[i]
            duration_file = self.duration_files[i]
            f0_file = self.f0_files[i]
            energy_file = self.energy_files[i]

            items = {
                "utt_ids": utt_id,
                "mel_files": mel_file,
                "charactor_files": charactor_file,
                "duration_files": duration_file,
                "f0_files": f0_file,
                "energy_files": energy_file,
            }

            yield items

    @tf.function
    def _load_data(self, items):
        """Load .npy payloads for one utterance and build training features.

        Raw f0/energy are normalized with the dumped stats, then averaged per
        charactor using the duration alignment.
        """
        mel = tf.numpy_function(np.load, [items["mel_files"]], tf.float32)
        charactor = tf.numpy_function(np.load, [items["charactor_files"]], tf.int32)
        duration = tf.numpy_function(np.load, [items["duration_files"]], tf.int32)
        f0 = tf.numpy_function(np.load, [items["f0_files"]], tf.float32)
        energy = tf.numpy_function(np.load, [items["energy_files"]], tf.float32)

        f0 = self._norm_mean_std_tf(f0, self.f0_stat[0], self.f0_stat[1])
        energy = self._norm_mean_std_tf(
            energy, self.energy_stat[0], self.energy_stat[1]
        )

        # calculate charactor f0/energy
        f0 = tf_average_by_duration(f0, duration)
        energy = tf_average_by_duration(energy, duration)

        items = {
            "utt_ids": items["utt_ids"],
            "input_ids": charactor,
            "speaker_ids": 0,  # single-speaker dataset.
            "duration_gts": duration,
            "f0_gts": f0,
            "energy_gts": energy,
            "mel_gts": mel,
            "mel_lengths": len(mel),
        }

        return items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
    ):
        """Create tf.dataset function.

        Args:
            allow_cache (bool): Cache loaded examples in memory after filtering.
            batch_size (int): Padded-batch size (remainder batches dropped).
            is_shuffle (bool): Shuffle over the full dataset length.
            map_fn (func): Unused here; kept for interface compatibility.
            reshuffle_each_iteration (bool): Reshuffle on every epoch.

        Returns:
            tf.data.Dataset: Batched dataset of padded feature dicts.
        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        # load data
        datasets = datasets.map(
            lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
        )

        # drop utterances with too-short mels.
        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padded shapes
        padded_shapes = {
            "utt_ids": [],
            "input_ids": [None],
            "speaker_ids": [],
            "duration_gts": [None],
            "f0_gts": [None],
            "energy_gts": [None],
            "mel_gts": [None, None],
            "mel_lengths": [],
        }

        datasets = datasets.padded_batch(
            batch_size, padded_shapes=padded_shapes, drop_remainder=True
        )
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return the generator's output dtypes (all path/id strings)."""
        output_types = {
            "utt_ids": tf.string,
            "mel_files": tf.string,
            "charactor_files": tf.string,
            "duration_files": tf.string,
            "f0_files": tf.string,
            "energy_files": tf.string,
        }
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances in the dataset."""
        return len(self.utt_ids)

    def __name__(self):
        return "CharactorDurationF0EnergyMelDataset"