Path: blob/master/tensorflow_tts/processor/kss.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perform preprocessing and raw feature extraction for KSS dataset."""

import os
import re

import numpy as np
import soundfile as sf
from dataclasses import dataclass
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.korean import symbols as KSS_SYMBOLS
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


@dataclass
class KSSProcessor(BaseProcessor):
    """KSS (Korean Single Speaker) dataset processor."""

    cleaner_names: str = "korean_cleaners"
    # Column indices within one "|"-separated line of the transcript file.
    positions = {
        "wave_file": 0,
        "text_norm": 2,
    }
    train_f_name: str = "transcript.v.1.4.txt"

    def create_items(self):
        """Read the transcript file and populate ``self.items``.

        Each item is a (text_norm, wav_path, speaker_name) tuple produced
        by :meth:`split_line`. Does nothing when ``self.data_dir`` is unset.
        """
        if self.data_dir:
            with open(
                os.path.join(self.data_dir, self.train_f_name), encoding="utf-8"
            ) as f:
                self.items = [self.split_line(self.data_dir, line, "|") for line in f]

    def split_line(self, data_dir, line, split):
        """Split one transcript line into (text_norm, wav_path, speaker_name).

        Args:
            data_dir: dataset root; wav files live under ``<data_dir>/kss``.
            line: one raw line from the transcript file.
            split: field delimiter (``"|"`` for KSS).
        """
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text_norm = parts[self.positions["text_norm"]]
        wav_path = os.path.join(data_dir, "kss", wave_file)
        speaker_name = "kss"
        return text_norm, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sentence token appended to every sequence."""
        return "eos"

    def save_pretrained(self, saved_path):
        """Save the processor's symbol mapper JSON under ``saved_path``."""
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

    def get_one_sample(self, item):
        """Load one training sample (audio + text ids) from an ``items`` entry.

        Returns a dict with keys ``raw_text``, ``text_ids``, ``audio``,
        ``utt_id``, ``speaker_name`` and ``rate``.
        """
        text, wav_path, speaker_name = item

        # normalize audio signal to be [-1, 1], soundfile already norm.
        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        # convert text to ids
        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            # file stem, e.g. "1_0000" from ".../1_0000.wav"
            "utt_id": os.path.split(wav_path)[-1].split(".")[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }

        return sample

    def text_to_sequence(self, text):
        """Convert a text string to a list of symbol ids, ending with eos.

        Text enclosed in curly braces is treated as ARPAbet phonemes;
        everything else is run through the configured cleaner first.
        """
        sequence = []
        # Check for curly braces and treat their contents as ARPAbet:
        while len(text):
            m = _curly_re.match(text)
            if not m:
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        # add eos tokens
        sequence += [self.eos_id]
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Apply each named cleaner from ``tensorflow_tts.utils.cleaners``."""
        for name in cleaner_names:
            # BUGFIX: pass a default of None so that an unknown cleaner name
            # raises the intended "Unknown cleaner" Exception below instead
            # of an opaque AttributeError from getattr itself.
            cleaner = getattr(cleaners, name, None)
            if not cleaner:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        # Drop symbols absent from the mapper plus the padding/eos markers.
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _arpabet_to_sequence(self, text):
        # ARPAbet symbols are stored in the mapper with an "@" prefix.
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        return s in self.symbol_to_id and s != "_" and s != "~"