Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/processor/ljspeechu.py
1558 views
1
# -*- coding: utf-8 -*-
2
# Copyright 2020 TensorFlowTTS Team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Perform preprocessing and raw feature extraction for LJSpeech Ultimate dataset."""
16
17
import os
18
import re
19
20
import numpy as np
21
import soundfile as sf
22
from dataclasses import dataclass
23
from tensorflow_tts.processor import BaseProcessor
24
from tensorflow_tts.utils import cleaners
25
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
26
from g2p_en import G2p as grapheme_to_phn
27
28
# ARPAbet phoneme inventory accepted by this processor.
# Vowel symbols appear both bare and with a trailing lexical-stress digit
# (0 = unstressed, 1 = primary stress, 2 = secondary stress); consonants
# have a single form.
valid_symbols = [
    "AA",
    "AA0",
    "AA1",
    "AA2",
    "AE",
    "AE0",
    "AE1",
    "AE2",
    "AH",
    "AH0",
    "AH1",
    "AH2",
    "AO",
    "AO0",
    "AO1",
    "AO2",
    "AW",
    "AW0",
    "AW1",
    "AW2",
    "AY",
    "AY0",
    "AY1",
    "AY2",
    "B",
    "CH",
    "D",
    "DH",
    "EH",
    "EH0",
    "EH1",
    "EH2",
    "ER",
    "ER0",
    "ER1",
    "ER2",
    "EY",
    "EY0",
    "EY1",
    "EY2",
    "F",
    "G",
    "HH",
    "IH",
    "IH0",
    "IH1",
    "IH2",
    "IY",
    "IY0",
    "IY1",
    "IY2",
    "JH",
    "K",
    "L",
    "M",
    "N",
    "NG",
    "OW",
    "OW0",
    "OW1",
    "OW2",
    "OY",
    "OY0",
    "OY1",
    "OY2",
    "P",
    "R",
    "S",
    "SH",
    "T",
    "TH",
    "UH",
    "UH0",
    "UH1",
    "UH2",
    "UW",
    "UW0",
    "UW1",
    "UW2",
    "V",
    "W",
    "Y",
    "Z",
    "ZH",
]

# Reserved padding / end-of-sequence marker symbols.
_pad = "pad"
_eos = "eos"
_punctuation = "!'(),.:;?"  # Unlike LJSpeech, we do not use spaces since we are phoneme only and spaces lead to very bad attention performance with phonetic input.
_special = "-"

# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ["@" + s for s in valid_symbols]

# Export all symbols: pad first (id 0), eos last, so ids are stable.
LJSPEECH_U_SYMBOLS = [_pad] + list(_special) + list(_punctuation) + _arpabet + [_eos]

# Regular expression matching text enclosed in curly braces:
# group(1) = text before the braces, group(2) = brace contents (ARPAbet),
# group(3) = remainder after the closing brace.
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


# Tokens that must stay OUTSIDE the curly-brace phoneme chains in to_arpa().
_arpa_exempt = _punctuation + _special

# Module-level grapheme-to-phoneme converter (g2p_en), built once at import.
arpa_g2p = grapheme_to_phn()
133
134
135
@dataclass
class LJSpeechUltimateProcessor(BaseProcessor):
    """LJSpeech Ultimate processor.

    Reads `wav_file|normalized_text` records from the train filelist and
    produces samples whose text is an ARPAbet phoneme string with phoneme
    runs wrapped in curly braces, e.g. ``"{ HH AH0 L OW1 } !"``.
    """

    cleaner_names: str = "english_cleaners"
    # Column index of each field in a `|`-separated filelist line.
    positions = {
        "wave_file": 0,
        "text_norm": 1,
    }
    train_f_name: str = "filelist.txt"

    def create_items(self):
        """Populate ``self.items`` with (text, wav_path, speaker) tuples."""
        if self.data_dir:
            with open(
                os.path.join(self.data_dir, self.train_f_name), encoding="utf-8"
            ) as f:
                # Skip blank lines (e.g. a trailing newline at EOF) so
                # split_line never indexes into an empty record.
                self.items = [
                    self.split_line(self.data_dir, line, "|")
                    for line in f
                    if line.strip()
                ]

    def split_line(self, data_dir, line, split):
        """Split one filelist line into (text_norm, wav_path, speaker_name)."""
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text_norm = parts[self.positions["text_norm"]]
        wav_path = os.path.join(data_dir, wave_file)
        speaker_name = "ljspeech"  # single-speaker dataset; fixed name
        return text_norm, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sentence symbol appended to every sequence."""
        return _eos

    def save_pretrained(self, saved_path):
        """Save the symbol mapper to ``saved_path`` (created if missing)."""
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

    def to_arpa(self, in_str):
        """Convert raw text to an ARPAbet string.

        Phoneme runs are wrapped in curly braces while punctuation and the
        special "-" symbol stay outside them, e.g. ``"{ HH AY1 } !"``.
        """
        phn_arr = arpa_g2p(in_str)
        phn_arr = [x for x in phn_arr if x != " "]

        arpa_str = "{"
        in_chain = True

        # Iterative array-traverse approach to build ARPA string. Phonemes
        # must be inside curly braces, but punctuation must not be.
        for token in phn_arr:
            if token in _arpa_exempt and in_chain:
                arpa_str += " }"
                in_chain = False

            if token not in _arpa_exempt and not in_chain:
                arpa_str += " {"
                in_chain = True

            arpa_str += " " + token

        if in_chain:
            arpa_str += " }"

        return arpa_str

    def get_one_sample(self, item):
        """Load one (text, wav_path, speaker_name) item into a sample dict."""
        text, wav_path, speaker_name = item

        # Check if this line is already an ARPA string by searching for the
        # trademark curly brace. If not, we apply g2p conversion.
        if "{" not in text:
            text = self.to_arpa(text)

        # normalize audio signal to be [-1, 1], soundfile already norm.
        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        # convert text to ids
        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            "utt_id": os.path.split(wav_path)[-1].split(".")[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }

        return sample

    def text_to_sequence(self, text):
        """Map text (with {ARPA} segments) to symbol ids, ending with eos."""
        sequence = []
        # Check for curly braces and treat their contents as ARPAbet;
        # everything outside braces is cleaned and mapped character-wise.
        while len(text):
            m = _curly_re.match(text)
            if not m:
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        # add eos tokens
        sequence += [self.eos_id]
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Apply each named cleaner from tensorflow_tts.utils.cleaners in order.

        Raises:
            Exception: if a cleaner name does not exist in the cleaners module.
        """
        for name in cleaner_names:
            # Default to None so a missing cleaner raises our explicit error
            # below instead of an AttributeError from getattr itself.
            cleaner = getattr(cleaners, name, None)
            if not cleaner:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        """Convert symbols to ids, silently dropping out-of-vocabulary ones."""
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _arpabet_to_sequence(self, text):
        """Convert a whitespace-separated ARPA segment via "@"-prefixed symbols."""
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        """Keep only in-vocabulary symbols that are not filler tokens."""
        return s in self.symbol_to_id and s != "_" and s != "~"
253
254