Path: blob/master/tensorflow_tts/processor/thorsten.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perform preprocessing and raw feature extraction for the Thorsten dataset.

The dataset is read from a pipe-separated ``metadata.csv`` plus a ``wavs/``
directory of audio files; text is normalized with the ``german_cleaners``
cleaner by default.
"""

import os
import re
from dataclasses import dataclass

import numpy as np
import soundfile as sf
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

_pad = "pad"
_eos = "eos"
_punctuation = "!'(),.? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

# Export all symbols:
THORSTEN_SYMBOLS = (
    [_pad] + list(_special) + list(_punctuation) + list(_letters) + [_eos]
)

# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


@dataclass
class ThorstenProcessor(BaseProcessor):
    """Thorsten processor.

    Reads ``<data_dir>/<train_f_name>`` (``metadata.csv`` by default), where
    each non-blank line has at least a ``wave_file`` and a ``text_norm``
    column separated by ``|``, and converts entries into
    ``(text, wav_path, speaker_name)`` items and model-ready samples.
    """

    cleaner_names: str = "german_cleaners"
    # Column index of each field within a metadata line. Declared without an
    # annotation, so it is a plain class attribute, not a dataclass field.
    positions = {
        "wave_file": 0,
        "text_norm": 1,
    }
    train_f_name: str = "metadata.csv"

    def create_items(self):
        """Populate ``self.items`` from the metadata file in ``self.data_dir``.

        Does nothing when ``self.data_dir`` is falsy. Blank or
        whitespace-only lines are skipped.
        """
        if self.data_dir:
            with open(
                os.path.join(self.data_dir, self.train_f_name), encoding="utf-8"
            ) as f:
                # Skip empty lines (e.g. a trailing newline at EOF): they have
                # no text column and would make split_line raise IndexError.
                self.items = [
                    self.split_line(self.data_dir, line, "|")
                    for line in f
                    if line.strip()
                ]

    def split_line(self, data_dir, line, split):
        """Parse one metadata line.

        Args:
            data_dir: Dataset root; audio is expected under ``<data_dir>/wavs``.
            line: Raw metadata line; surrounding whitespace is stripped.
            split: Column separator (``"|"`` when called from ``create_items``).

        Returns:
            ``(text_norm, wav_path, speaker_name)`` where ``wav_path`` is
            ``<data_dir>/wavs/<wave_file>.wav`` and ``speaker_name`` is always
            ``"thorsten"``.

        Raises:
            IndexError: If the line has fewer columns than ``positions`` needs.
        """
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text_norm = parts[self.positions["text_norm"]]
        wav_path = os.path.join(data_dir, "wavs", f"{wave_file}.wav")
        speaker_name = "thorsten"
        return text_norm, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sequence symbol (``"eos"``)."""
        return _eos

    def save_pretrained(self, saved_path):
        """Write the processor mapping file into ``saved_path``.

        Creates the directory if needed and delegates to ``_save_mapper``
        (inherited, not shown here) with no extra entries.
        """
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

    def get_one_sample(self, item):
        """Load audio and encode text for one ``(text, wav_path, speaker)`` item.

        Returns:
            Dict with ``raw_text``, ``text_ids`` (int32 array), ``audio``
            (float32 array), ``utt_id`` (file name without its extension),
            ``speaker_name`` and ``rate`` (sample rate from the file).
        """
        text, wav_path, speaker_name = item

        # normalize audio signal to be [-1, 1], soundfile already norm.
        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        # convert text to ids
        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        # Strip only the final extension so IDs that contain dots
        # ("a.b.wav" -> "a.b") are not truncated and cannot collide.
        utt_id = os.path.splitext(os.path.basename(wav_path))[0]

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            "utt_id": utt_id,
            "speaker_name": speaker_name,
            "rate": rate,
        }

        return sample

    def text_to_sequence(self, text):
        """Convert text to a list of symbol ids, terminated by ``self.eos_id``.

        Text outside curly braces is cleaned with ``self.cleaner_names`` and
        mapped symbol by symbol; text inside braces is handed to
        ``_arpabet_to_sequence``. Symbols not accepted by
        ``_should_keep_symbol`` are dropped.
        """
        sequence = []
        # Check for curly braces and treat their contents as ARPAbet:
        while text:
            m = _curly_re.match(text)
            if not m:
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        # add eos tokens
        sequence += [self.eos_id]
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Apply each named cleaner from ``tensorflow_tts.utils.cleaners`` in order.

        Raises:
            Exception: If a cleaner name is not defined in ``cleaners``.
        """
        for name in cleaner_names:
            # Default to None so an unknown name reaches the explicit error
            # below instead of surfacing as a bare AttributeError.
            cleaner = getattr(cleaners, name, None)
            if not cleaner:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        """Map symbols to ids, silently dropping ones that should not be kept."""
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _arpabet_to_sequence(self, text):
        """Map whitespace-separated ARPAbet tokens to ids as ``"@"``-prefixed symbols.

        NOTE(review): ``THORSTEN_SYMBOLS`` contains no ``"@"``-prefixed entries,
        so if ``symbol_to_id`` is built from it (presumably, in BaseProcessor)
        braced content is filtered out entirely -- confirm this is intended.
        """
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        """Return True if ``s`` has an id and is not ``"_"`` or ``"~"``.

        The ``"_"``/``"~"`` exclusions look inherited from a symbol set where
        they were pad/eos markers; neither appears in ``THORSTEN_SYMBOLS``.
        """
        return s in self.symbol_to_id and s != "_" and s != "~"