Path: blob/master/examples/tacotron2/tacotron_dataset.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tacotron Related Dataset modules."""

import itertools
import logging
import os
import random

import numpy as np
import tensorflow as tf

from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files


class CharactorMelDataset(AbstractDataset):
    """Tensorflow Charactor Mel dataset.

    Pairs character-id files with mel-spectrogram files found under
    ``root_dir`` and, optionally, with one alignment file per utterance.
    When no alignment files are configured, a guided-attention matrix is
    computed on the fly for every utterance instead (``_guided_attention``).
    """

    def __init__(
        self,
        dataset,
        root_dir,
        charactor_query="*-ids.npy",
        mel_query="*-norm-feats.npy",
        align_query="",
        charactor_load_fn=np.load,
        mel_load_fn=np.load,
        mel_length_threshold=0,
        reduction_factor=1,
        mel_pad_value=0.0,
        char_pad_value=0,
        ga_pad_value=-1.0,
        g=0.2,
        use_fixed_shapes=False,
    ):
        """Initialize dataset.

        Args:
            dataset (str): Not referenced by this class; presumably kept so
                the constructor signature matches sibling datasets.
            root_dir (str): Root directory including dumped files.
            charactor_query (str): Glob used to find charactor files in
                root_dir. Assumed to be of the form "*<suffix>" containing
                ".npy"; the suffix is stripped to obtain utterance ids.
            mel_query (str): Query to find feature files in root_dir.
            align_query (str): Query to find FAL (alignment) files in
                root_dir. If empty, the stock guided attention loss is used.
            charactor_load_fn (func): Function to load a charactor file.
            mel_load_fn (func): Function to load a feature file.
            mel_length_threshold (int): Utterances whose (padded) mel length
                is not greater than this value are filtered out.
            reduction_factor (int): Reduction factor on Tacotron-2 paper.
            mel_pad_value (float): Padding value for mel-spectrogram.
            char_pad_value (int): Padding value for charactor.
            ga_pad_value (float): Padding value for guided attention.
            g (float): G value for guided attention.
            use_fixed_shapes (bool): Pad every batch to the dataset-wide
                maximum lengths (``self.max_char_length`` /
                ``self.max_mel_length``) instead of the per-batch maximum.

        Raises:
            AssertionError: If no mel files are found, or file counts differ.
            ValueError: If utterance ids cannot be derived from
                ``charactor_query`` (it does not contain ".npy").
        """
        # find all of charactor and mel files.
        charactor_files = sorted(find_files(root_dir, charactor_query))
        mel_files = sorted(find_files(root_dir, mel_query))

        # NOTE: every file is fully loaded once here only to read its length.
        mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
        char_lengths = [charactor_load_fn(f).shape[0] for f in charactor_files]

        # assert the number of files
        assert len(mel_files) != 0, f"Not found any mels files in {root_dir}."
        assert len(mel_files) == len(charactor_files) == len(mel_lengths), (
            "Number of charactor, mel and duration files are different "
            f"({len(mel_files)} vs {len(charactor_files)} vs {len(mel_lengths)})."
        )

        self.align_files = []

        # An empty align_query means "no alignment files": the per-utterance
        # guided attention matrix is then computed in `_guided_attention`.
        if align_query:
            align_files = sorted(find_files(root_dir, align_query))
            assert len(align_files) == len(mel_files), (
                f"Number of align files ({len(align_files)}) and mel files "
                f"({len(mel_files)}) are different"
            )
            logging.info("Using FAL loss")
            self.align_files = align_files
        else:
            logging.info("Using guided attention loss")

        if ".npy" not in charactor_query:
            # Previously this path left `suffix` unbound and crashed with an
            # UnboundLocalError; fail with an explicit message instead.
            raise ValueError(
                "Cannot derive utterance ids from "
                f"charactor_query={charactor_query!r}; expected a glob such "
                "as '*-ids.npy'."
            )
        # Drop the leading wildcard (assumes the query looks like
        # "*<suffix>"); removing that suffix from a basename gives the utt id.
        suffix = charactor_query[1:]
        utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]

        # set global params
        self.utt_ids = utt_ids
        self.mel_files = mel_files
        self.charactor_files = charactor_files
        self.mel_load_fn = mel_load_fn
        self.charactor_load_fn = charactor_load_fn
        self.mel_lengths = mel_lengths
        self.char_lengths = char_lengths
        self.reduction_factor = reduction_factor
        self.mel_length_threshold = mel_length_threshold
        self.mel_pad_value = mel_pad_value
        self.char_pad_value = char_pad_value
        self.ga_pad_value = ga_pad_value
        self.g = g
        self.use_fixed_shapes = use_fixed_shapes
        self.max_char_length = np.max(char_lengths)

        # Round the longest mel length up to a multiple of reduction_factor,
        # matching the per-utterance padding done in `_load_data`.
        if np.max(mel_lengths) % self.reduction_factor == 0:
            self.max_mel_length = np.max(mel_lengths)
        else:
            self.max_mel_length = (
                np.max(mel_lengths)
                + self.reduction_factor
                - np.max(mel_lengths) % self.reduction_factor
            )

    def get_args(self):
        """Return the positional args passed to `generator` by from_generator."""
        return [self.utt_ids]

    def generator(self, utt_ids):
        """Yield, per utterance, its id and the paths of its data files.

        Files are looked up by position, so ``utt_ids`` is expected to be
        ``self.utt_ids`` in its original order (see ``get_args``).
        """
        use_align = bool(self.align_files)
        for i, utt_id in enumerate(utt_ids):
            yield {
                "utt_ids": utt_id,
                "mel_files": self.mel_files[i],
                "charactor_files": self.charactor_files[i],
                # Empty string is a placeholder when no alignment files exist.
                "align_files": self.align_files[i] if use_align else "",
            }

    @tf.function
    def _load_data(self, items):
        """Load one utterance and pad its mel to a multiple of reduction_factor.

        Uses the configured ``mel_load_fn`` / ``charactor_load_fn``.
        tf.numpy_function hands file paths to them as bytes values, so
        custom loaders must accept bytes paths.
        Alignment files, when present, are loaded with ``np.load``.
        """
        mel = tf.numpy_function(self.mel_load_fn, [items["mel_files"]], tf.float32)
        charactor = tf.numpy_function(
            self.charactor_load_fn, [items["charactor_files"]], tf.int32
        )
        g_att = (
            tf.numpy_function(np.load, [items["align_files"]], tf.float32)
            if self.align_files
            else None
        )

        mel_length = len(mel)
        char_length = len(charactor)
        # padding mel to make its length is multiple of reduction factor.
        real_mel_length = mel_length
        remainder = mel_length % self.reduction_factor
        if remainder != 0:
            new_mel_length = mel_length + self.reduction_factor - remainder
            mel = tf.pad(
                mel,
                [[0, new_mel_length - mel_length], [0, 0]],
                constant_values=self.mel_pad_value,
            )
            mel_length = new_mel_length

        items = {
            "utt_ids": items["utt_ids"],
            "input_ids": charactor,
            "input_lengths": char_length,
            "speaker_ids": 0,
            "mel_gts": mel,
            "mel_lengths": mel_length,  # padded length
            "real_mel_lengths": real_mel_length,  # length before padding
            "g_attentions": g_att,  # None when guided attention is computed later
        }

        return items

    def _guided_attention(self, items):
        """Guided attention. Refer to page 3 on the paper (https://arxiv.org/abs/1710.08969).

        Returns a copy of ``items`` whose "g_attentions" entry is
        ``1 - exp(-(t / T - n / N) ** 2 / (2 * g ** 2))``. The matrix has
        shape [input_lengths, mel_lengths // reduction_factor]: n indexes
        characters and t indexes reduced mel frames.
        """
        items = items.copy()
        mel_len = items["mel_lengths"] // self.reduction_factor
        char_len = items["input_lengths"]
        xv, yv = tf.meshgrid(tf.range(char_len), tf.range(mel_len), indexing="ij")
        f32_matrix = tf.cast(yv / mel_len - xv / char_len, tf.float32)
        items["g_attentions"] = 1.0 - tf.math.exp(
            -(f32_matrix ** 2) / (2 * self.g ** 2)
        )
        return items

    def create(
        self,
        allow_cache=False,
        batch_size=1,
        is_shuffle=False,
        map_fn=None,
        reshuffle_each_iteration=True,
        drop_remainder=True,
    ):
        """Create tf.dataset function.

        Pipeline: generator -> load -> (guided attention if no alignment
        files) -> length filter -> optional cache -> optional shuffle ->
        padded_batch -> prefetch. ``map_fn`` is accepted but not used here.
        """
        output_types = self.get_output_dtypes()
        datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=(self.get_args())
        )

        # load data
        datasets = datasets.map(
            lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
        )

        # calculate guided attention only when no alignment files were given.
        if not self.align_files:
            datasets = datasets.map(
                lambda items: self._guided_attention(items),
                tf.data.experimental.AUTOTUNE,
            )

        datasets = datasets.filter(
            lambda x: x["mel_lengths"] > self.mel_length_threshold
        )

        if allow_cache:
            datasets = datasets.cache()

        if is_shuffle:
            datasets = datasets.shuffle(
                self.get_len_dataset(),
                reshuffle_each_iteration=reshuffle_each_iteration,
            )

        # define padding value.
        padding_values = {
            "utt_ids": " ",
            "input_ids": self.char_pad_value,
            "input_lengths": 0,
            "speaker_ids": 0,
            "mel_gts": self.mel_pad_value,
            "mel_lengths": 0,
            "real_mel_lengths": 0,
            "g_attentions": self.ga_pad_value,
        }

        # define padded shapes. Note the `is False` comparisons: only the
        # literal False selects per-batch (dynamic) padding.
        # NOTE(review): the mel channel count (80) is hard-coded here.
        padded_shapes = {
            "utt_ids": [],
            "input_ids": [None]
            if self.use_fixed_shapes is False
            else [self.max_char_length],
            "input_lengths": [],
            "speaker_ids": [],
            "mel_gts": [None, 80]
            if self.use_fixed_shapes is False
            else [self.max_mel_length, 80],
            "mel_lengths": [],
            "real_mel_lengths": [],
            "g_attentions": [None, None]
            if self.use_fixed_shapes is False
            else [self.max_char_length, self.max_mel_length // self.reduction_factor],
        }

        datasets = datasets.padded_batch(
            batch_size,
            padded_shapes=padded_shapes,
            padding_values=padding_values,
            drop_remainder=drop_remainder,
        )
        datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
        return datasets

    def get_output_dtypes(self):
        """Return the dtype of each field yielded by `generator`."""
        output_types = {
            "utt_ids": tf.string,
            "mel_files": tf.string,
            "charactor_files": tf.string,
            "align_files": tf.string,
        }
        return output_types

    def get_len_dataset(self):
        """Return the number of utterances (before length filtering)."""
        return len(self.utt_ids)

    def __name__(self):
        return "CharactorMelDataset"