# Path: blob/master/tensorflow_tts/processor/libritts.py
# (1558 views)
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perform preprocessing and raw feature extraction for LibriTTS dataset."""

import os
import re

import numpy as np
import soundfile as sf
from dataclasses import dataclass

from g2p_en import g2p as grapheme_to_phonem

from tensorflow_tts.processor.base_processor import BaseProcessor
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

g2p = grapheme_to_phonem.G2p()

# FIX: copy the phoneme list before extending it.  The original code appended
# "SIL"/"END" directly to ``g2p.phonemes``, mutating the shared G2p instance's
# own symbol table as a side effect of importing this module.
valid_symbols = list(g2p.phonemes)
valid_symbols.append("SIL")
valid_symbols.append("END")

_punctuation = "!'(),.:;? "
_arpabet = ["@" + s for s in valid_symbols]

# Full vocabulary: "@"-prefixed ARPAbet phonemes followed by raw punctuation.
LIBRITTS_SYMBOLS = _arpabet + list(_punctuation)


@dataclass
class LibriTTSProcessor(BaseProcessor):
    """LibriTTS processor.

    Reads a delimiter-separated metadata file (``train_f_name``) and turns
    each (wav file, text, speaker) entry into a training sample with audio
    and phoneme-id sequences.
    """

    mode: str = "train"
    train_f_name: str = "train.txt"
    # Column positions of file, text, speaker_name after splitting a
    # metadata line on ``self.delimiter``.
    positions = {
        "file": 0,
        "text": 1,
        "speaker_name": 2,
    }
    f_extension: str = ".wav"
    cleaner_names: str = None  # unused here; kept for BaseProcessor compatibility

    def create_items(self):
        """Populate ``self.items`` with ``[text, wav_path, speaker_name]`` entries.

        Reads ``self.train_f_name`` from ``self.data_dir``; the wav path gains
        ``self.f_extension`` when the metadata entry omits it.
        """
        with open(
            os.path.join(self.data_dir, self.train_f_name), mode="r", encoding="utf-8"
        ) as f:
            for line in f:
                parts = line.strip().split(self.delimiter)
                wav_path = os.path.join(self.data_dir, parts[self.positions["file"]])
                # Idiomatic form of the original slice comparison
                # ``wav_path[-len(ext):] != ext``.
                if not wav_path.endswith(self.f_extension):
                    wav_path += self.f_extension
                text = parts[self.positions["text"]]
                speaker_name = parts[self.positions["speaker_name"]]
                self.items.append([text, wav_path, speaker_name])

    def get_one_sample(self, item):
        """Load one item's audio and map its text to phoneme ids.

        Args:
            item: ``[text, wav_path, speaker_name]`` as built by ``create_items``.

        Returns:
            dict with raw text, int32 text ids, float32 audio, utterance id,
            speaker name, and sample rate.
        """
        text, wav_path, speaker_name = item

        audio, rate = sf.read(wav_path, dtype="float32")

        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            # FIX: derive the utterance id portably.  The original
            # ``wav_path.split("/")`` broke on Windows path separators.
            "utt_id": os.path.splitext(os.path.basename(wav_path))[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }

        return sample

    def setup_eos_token(self):
        """No explicit EOS token is used for LibriTTS."""
        return None  # because we do not use this

    def save_pretrained(self, saved_path):
        """Save the processor's symbol mapper to ``saved_path``."""
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

    def text_to_sequence(self, text):
        """Convert text to a list of symbol ids.

        In train mode the text is assumed to already be a space-separated
        phoneme sequence; otherwise it is run through g2p first.
        """
        if (
            self.mode == "train"
        ):  # in train mode text should be already transformed to phonemes
            return self.symbols_to_ids(self.clean_g2p(text.split(" ")))
        else:
            return self.inference_text_to_seq(text)

    def inference_text_to_seq(self, text: str):
        """Convert raw text to symbol ids via grapheme-to-phoneme conversion."""
        return self.symbols_to_ids(self.text_to_ph(text))

    def symbols_to_ids(self, symbols_list: list):
        """Map each symbol to its integer id via ``self.symbol_to_id``."""
        return [self.symbol_to_id[s] for s in symbols_list]

    def text_to_ph(self, text: str):
        """Run g2p over raw text and clean the resulting phoneme list."""
        return self.clean_g2p(g2p(text))

    def clean_g2p(self, g2p_text: list):
        """Prefix phonemes with "@" and normalize the final token.

        Spaces are dropped; the final token becomes "@END" when it is a
        space or "SIL", otherwise it is kept with the "@" prefix.
        """
        data = []
        for i, txt in enumerate(g2p_text):
            if i == len(g2p_text) - 1:
                if txt != " " and txt != "SIL":
                    data.append("@" + txt)
                else:
                    data.append(
                        "@END"
                    )  # TODO try learning without end token and compare results
                break
            if txt != " ":
                data.append("@" + txt)
        return data