Path: blob/master/tensorflow_tts/processor/ljspeech.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perform preprocessing and raw feature extraction for LJSpeech dataset."""

import os
import re

import numpy as np
import soundfile as sf
from dataclasses import dataclass
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

# ARPAbet phoneme inventory (stress-marked vowels carry a 0/1/2 suffix).
valid_symbols = [
    "AA",
    "AA0",
    "AA1",
    "AA2",
    "AE",
    "AE0",
    "AE1",
    "AE2",
    "AH",
    "AH0",
    "AH1",
    "AH2",
    "AO",
    "AO0",
    "AO1",
    "AO2",
    "AW",
    "AW0",
    "AW1",
    "AW2",
    "AY",
    "AY0",
    "AY1",
    "AY2",
    "B",
    "CH",
    "D",
    "DH",
    "EH",
    "EH0",
    "EH1",
    "EH2",
    "ER",
    "ER0",
    "ER1",
    "ER2",
    "EY",
    "EY0",
    "EY1",
    "EY2",
    "F",
    "G",
    "HH",
    "IH",
    "IH0",
    "IH1",
    "IH2",
    "IY",
    "IY0",
    "IY1",
    "IY2",
    "JH",
    "K",
    "L",
    "M",
    "N",
    "NG",
    "OW",
    "OW0",
    "OW1",
    "OW2",
    "OY",
    "OY0",
    "OY1",
    "OY2",
    "P",
    "R",
    "S",
    "SH",
    "T",
    "TH",
    "UH",
    "UH0",
    "UH1",
    "UH2",
    "UW",
    "UW0",
    "UW1",
    "UW2",
    "V",
    "W",
    "Y",
    "Z",
    "ZH",
]

_pad = "pad"
_eos = "eos"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ["@" + s for s in valid_symbols]

# Export all symbols:
LJSPEECH_SYMBOLS = (
    [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet + [_eos]
)

# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


@dataclass
class LJSpeechProcessor(BaseProcessor):
    """LJSpeech processor.

    Parses the LJSpeech ``metadata.csv`` (pipe-separated: wave-file id, raw
    text, normalized text), loads the matching .wav files, and converts text
    to integer id sequences over ``LJSPEECH_SYMBOLS``.  Text wrapped in curly
    braces (e.g. ``{HH AH0 L OW1}``) is interpreted as ARPAbet phonemes.
    """

    cleaner_names: str = "english_cleaners"
    # Column indices inside a metadata.csv record.
    positions = {
        "wave_file": 0,
        "text": 1,
        "text_norm": 2,
    }
    train_f_name: str = "metadata.csv"

    def create_items(self):
        """Populate ``self.items`` with (text_norm, wav_path, speaker) tuples."""
        if self.data_dir:
            with open(
                os.path.join(self.data_dir, self.train_f_name), encoding="utf-8"
            ) as f:
                # Skip blank lines (e.g. a trailing newline at end of file);
                # an empty record would raise IndexError in split_line.
                self.items = [
                    self.split_line(self.data_dir, line, "|")
                    for line in f
                    if line.strip()
                ]

    def split_line(self, data_dir, line, split):
        """Parse one metadata record into (normalized text, wav path, speaker name)."""
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text_norm = parts[self.positions["text_norm"]]
        wav_path = os.path.join(data_dir, "wavs", f"{wave_file}.wav")
        speaker_name = "ljspeech"  # LJSpeech is a single-speaker corpus.
        return text_norm, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sentence token appended to every sequence."""
        return _eos

    def save_pretrained(self, saved_path):
        """Save the processor's symbol mapper under ``saved_path``."""
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

    def get_one_sample(self, item):
        """Load one (text, wav_path, speaker) item into a training sample dict.

        Returns a dict with keys: raw_text, text_ids (int32 ndarray), audio
        (float32 ndarray), utt_id, speaker_name, rate.
        """
        text, wav_path, speaker_name = item

        # normalize audio signal to be [-1, 1], soundfile already norm.
        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        # convert text to ids
        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            # Utterance id is the wav file name without its extension.
            "utt_id": os.path.splitext(os.path.basename(wav_path))[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }

        return sample

    def text_to_sequence(self, text):
        """Convert ``text`` to a list of symbol ids, ending with the eos id.

        Curly-brace spans are treated as space-separated ARPAbet phonemes;
        everything else is cleaned with ``self.cleaner_names`` first.
        """
        sequence = []
        # Check for curly braces and treat their contents as ARPAbet:
        while text:
            m = _curly_re.match(text)
            if not m:
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        # add eos tokens
        sequence += [self.eos_id]
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Run ``text`` through each named cleaner from tensorflow_tts.utils.cleaners."""
        for name in cleaner_names:
            # Default of None is required: plain getattr would raise
            # AttributeError for an unknown cleaner, making the explicit
            # error below unreachable.
            cleaner = getattr(cleaners, name, None)
            if not cleaner:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        """Map symbols to ids, silently dropping symbols not in the table."""
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _arpabet_to_sequence(self, text):
        """Map space-separated ARPAbet phonemes to ids via their '@'-prefixed symbols."""
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        """Keep only symbols present in the mapper; '_' and '~' are legacy pad/eos."""
        return s in self.symbol_to_id and s != "_" and s != "~"