Path: blob/master/tensorflow_tts/processor/synpaflex.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perform preprocessing and raw feature extraction for SynPaFlex dataset."""

import os
import re
import unicodedata

import numpy as np
import soundfile as sf
from dataclasses import dataclass
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

_pad = "pad"
_eos = "eos"
_punctuation = "!/\'(),-.:;? "
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzéèàùâêîôûçäëïöüÿœæ"

# Export all symbols:
# the pad token, punctuation, ASCII letters plus lowercase French accented
# letters/ligatures, then the eos token.
SYNPAFLEX_SYMBOLS = (
    [_pad] + list(_punctuation) + list(_letters) + [_eos]
)

# Regular expression matching text enclosed in curly braces:
# group(1) = text before the braces, group(2) = brace content,
# group(3) = remainder after the closing brace.
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


@dataclass
class SynpaflexProcessor(BaseProcessor):
    """SynPaFlex processor.

    Reads a "|"-separated metadata file, resolves audio under
    ``<data_dir>/wavs`` and converts transcripts to symbol ids. The id mapping
    (``symbol_to_id``/``id_to_symbol``/``eos_id``) comes from
    ``BaseProcessor``; presumably it is built from ``SYNPAFLEX_SYMBOLS`` by
    the caller -- confirm against the processor factory.
    """

    cleaner_names: str = "basic_cleaners"
    # Column indices inside one metadata line. Deliberately left without a
    # type annotation so the dataclass treats it as a plain class attribute,
    # not as a field (a dict field default would be rejected).
    positions = {
        "wave_file": 0,
        "text": 1,
        "text_norm": 2,
    }
    train_f_name: str = "synpaflex.txt"

    def create_items(self):
        """Read the metadata file and populate ``self.items``.

        Each non-blank line of ``<data_dir>/<train_f_name>`` is parsed by
        :meth:`split_line` into a ``(text, wav_path, speaker_name)`` tuple.
        Nothing happens when ``data_dir`` is falsy.
        """
        if self.data_dir:
            with open(
                os.path.join(self.data_dir, self.train_f_name), encoding="utf-8"
            ) as f:
                # Blank lines (e.g. a trailing newline at EOF) have no "|"
                # separator and would make split_line raise IndexError.
                self.items = [
                    self.split_line(self.data_dir, line, "|")
                    for line in f
                    if line.strip()
                ]

    def split_line(self, data_dir, line, split):
        """Parse one metadata line.

        Args:
            data_dir: dataset root; audio is looked up in ``<data_dir>/wavs``.
            line: raw metadata line.
            split: field separator.

        Returns:
            ``(text, wav_path, speaker_name)``: the raw transcript column
            (``positions["text"]``), ``<data_dir>/wavs/<wave_file>.wav`` and
            the constant speaker name ``"synpaflex"``. The ``text_norm``
            column is not used.
        """
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text = parts[self.positions["text"]]
        wav_path = os.path.join(data_dir, "wavs", f"{wave_file}.wav")
        speaker_name = "synpaflex"
        return text, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sequence symbol (``"eos"``)."""
        return _eos

    def get_one_sample(self, item):
        """Load audio and text ids for one ``(text, wav_path, speaker_name)`` item.

        Returns:
            dict with keys ``raw_text``, ``text_ids`` (int32 array ending with
            the eos id), ``audio`` (float32 array), ``utt_id`` (wav file name
            up to its first "."), ``speaker_name`` and ``rate`` (sample rate
            reported by soundfile).
        """
        text, wav_path, speaker_name = item

        # normalize audio signal to be [-1, 1], soundfile already norm.
        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        # convert text to ids
        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            "utt_id": os.path.split(wav_path)[-1].split(".")[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }

        return sample

    def text_to_sequence(self, text):
        """Convert a transcript into a list of symbol ids terminated by eos.

        Text outside ``{...}`` is cleaned with ``self.cleaner_names`` and
        mapped character by character; characters not in the symbol map are
        dropped. Text inside ``{...}`` goes through
        :meth:`_arpabet_to_sequence`.
        """
        # Compose decomposed characters (e.g. "e" + U+0301) into their
        # precomposed form ("é"). Otherwise the combining accent is a separate
        # code point that is not a known symbol and gets silently dropped.
        text = unicodedata.normalize("NFC", text)
        sequence = []
        # Check for curly braces and treat their contents as ARPAbet:
        while text:
            m = _curly_re.match(text)
            if not m:
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        # add eos tokens
        sequence += [self.eos_id]
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Apply each named cleaner from ``tensorflow_tts.utils.cleaners`` in order.

        Raises:
            Exception: if a name is not an attribute of ``cleaners``.
        """
        for name in cleaner_names:
            # A default is required: plain getattr would raise AttributeError
            # and the "Unknown cleaner" check below could never run.
            cleaner = getattr(cleaners, name, None)
            if not cleaner:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        """Map symbols to ids, skipping any rejected by ``_should_keep_symbol``."""
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _sequence_to_symbols(self, sequence):
        """Map ids back to symbols (inverse of ``_symbols_to_sequence``)."""
        return [self.id_to_symbol[s] for s in sequence]

    def _arpabet_to_sequence(self, text):
        """Map whitespace-separated tokens to ids as ``"@" + token`` symbols.

        NOTE(review): ``SYNPAFLEX_SYMBOLS`` contains no "@"-prefixed entries.
        If ``symbol_to_id`` is built from it, every token here is filtered
        out and brace content contributes nothing to the sequence.
        """
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        """Keep only symbols present in the map, excluding "_" and "~"."""
        return s in self.symbol_to_id and s != "_" and s != "~"

    def save_pretrained(self, saved_path):
        """Create ``saved_path`` if needed and save the symbol mapper into it."""
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})