Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/processor/synpaflex.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 TensorFlowTTS Team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Perform preprocessing and raw feature extraction for SynPaFlex dataset."""
16
17
import os
import re
from dataclasses import dataclass

import numpy as np
import soundfile as sf

from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
25
26
_pad = "pad"
_eos = "eos"
_punctuation = "!/\'(),-.:;? "
# Upper-case ASCII, lower-case ASCII, then accented letters and ligatures.
_letters = (
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "éèàùâêîôûçäëïöüÿœæ"
)

# Export all symbols: the padding token first, the end-of-sequence token last,
# and every punctuation mark and letter as its own one-character symbol between.
SYNPAFLEX_SYMBOLS = [_pad, *_punctuation, *_letters, _eos]

# Splits text around the first ``{...}`` span: group 1 is the text before the
# opening brace, group 2 the non-empty content inside, group 3 the remainder.
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
38
39
40
@dataclass
class SynpaflexProcessor(BaseProcessor):
    """SynPaFlex processor.

    Reads the ``|``-separated metadata file ``train_f_name`` from ``data_dir``,
    pairs each transcript with ``<data_dir>/wavs/<id>.wav`` and turns every
    entry into a sample dict holding the raw text, its symbol ids and audio.
    """

    # Name of the function in ``tensorflow_tts.utils.cleaners`` applied to the
    # text before it is mapped to symbol ids.
    cleaner_names: str = "basic_cleaners"
    # Column index of each field in a metadata line. Deliberately left
    # unannotated so ``dataclass`` keeps it as a plain class attribute (a dict
    # default on a field would be rejected as mutable); it is therefore shared
    # by all instances. ``text_norm`` is not read by this processor.
    positions = {
        "wave_file": 0,
        "text": 1,
        "text_norm": 2,
    }
    # Metadata file name, relative to ``data_dir``.
    train_f_name: str = "synpaflex.txt"

    def create_items(self):
        """Populate ``self.items`` from the metadata file in ``self.data_dir``.

        Every non-blank line becomes a ``(text, wav_path, speaker_name)`` tuple
        via :meth:`split_line`. Nothing is done when ``data_dir`` is falsy.
        """
        if self.data_dir:
            with open(
                os.path.join(self.data_dir, self.train_f_name), encoding="utf-8"
            ) as f:
                # Skip blank lines (e.g. a trailing newline at EOF): they have
                # no text column and would make ``split_line`` raise IndexError.
                self.items = [
                    self.split_line(self.data_dir, line, "|")
                    for line in f
                    if line.strip()
                ]

    def split_line(self, data_dir, line, split):
        """Parse one metadata line into an item tuple.

        Args:
            data_dir: Dataset root; audio files are looked up in its ``wavs``
                sub-directory.
            line: One line of the metadata file.
            split: Column separator.

        Returns:
            ``(text, wav_path, speaker_name)``; ``speaker_name`` is always
            ``"synpaflex"``.

        Raises:
            IndexError: If the line has fewer columns than ``positions`` needs.
        """
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text = parts[self.positions["text"]]
        wav_path = os.path.join(data_dir, "wavs", f"{wave_file}.wav")
        speaker_name = "synpaflex"
        return text, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sequence symbol (``"eos"``)."""
        return _eos

    def get_one_sample(self, item):
        """Load the audio and encode the transcript of one item.

        Args:
            item: ``(text, wav_path, speaker_name)`` tuple from ``self.items``.

        Returns:
            Dict with ``raw_text``, ``text_ids`` (int32 array), ``audio``
            (float32 array), ``utt_id``, ``speaker_name`` and ``rate``.
        """
        text, wav_path, speaker_name = item

        # soundfile already scales samples to [-1, 1]; only the dtype changes.
        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        return {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            # Strip only the final extension so dotted file names keep their
            # full (unique) stem instead of being cut at the first dot.
            "utt_id": os.path.splitext(os.path.basename(wav_path))[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }

    def text_to_sequence(self, text):
        """Convert ``text`` to a list of symbol ids ending with ``eos_id``.

        Plain text is cleaned with ``cleaner_names`` and mapped character by
        character; the content of each ``{...}`` span is passed to
        :meth:`_arpabet_to_sequence` instead. Characters without an id are
        dropped silently.
        """
        sequence = []
        # Consume the text left to right, one curly-brace span at a time.
        while text:
            m = _curly_re.match(text)
            if not m:
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            # NOTE(review): this produces "@"-prefixed symbols and
            # SYNPAFLEX_SYMBOLS contains none; if ``symbol_to_id`` is built
            # from that list (not visible here), brace contents are dropped.
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        # add eos token
        sequence += [self.eos_id]
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Apply each cleaner named in ``cleaner_names`` to ``text``, in order.

        Raises:
            Exception: If a name does not refer to a callable in ``cleaners``.
        """
        for name in cleaner_names:
            # Default to None so an unknown name reaches the explicit error
            # below instead of escaping as a bare AttributeError.
            cleaner = getattr(cleaners, name, None)
            if not callable(cleaner):
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        """Map each symbol to its id, skipping those rejected by ``_should_keep_symbol``."""
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _sequence_to_symbols(self, sequence):
        """Map each id in ``sequence`` back to its symbol."""
        return [self.id_to_symbol[s] for s in sequence]

    def _arpabet_to_sequence(self, text):
        """Map whitespace-separated ARPAbet tokens to ids, prefixing each with "@"."""
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        """Return True if ``s`` has an id and is neither ``_`` nor ``~``.

        Neither ``_`` nor ``~`` appears in SYNPAFLEX_SYMBOLS; the check is
        presumably inherited from the Tacotron text front-end.
        """
        return s in self.symbol_to_id and s != "_" and s != "~"

    def save_pretrained(self, saved_path):
        """Save the processor mapper into ``saved_path``, creating it if needed.

        Delegates to ``_save_mapper`` (defined outside this class) with the
        target ``saved_path/PROCESSOR_FILE_NAME`` and no extra fields.
        """
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
134
135