# Source: TensorSpeech/TensorFlowTTS — tensorflow_tts/processor/ljspeech.py
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perform preprocessing and raw feature extraction for LJSpeech dataset."""
import os
import re

import numpy as np
import soundfile as sf
from dataclasses import dataclass

from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME
27
valid_symbols = [
28
"AA",
29
"AA0",
30
"AA1",
31
"AA2",
32
"AE",
33
"AE0",
34
"AE1",
35
"AE2",
36
"AH",
37
"AH0",
38
"AH1",
39
"AH2",
40
"AO",
41
"AO0",
42
"AO1",
43
"AO2",
44
"AW",
45
"AW0",
46
"AW1",
47
"AW2",
48
"AY",
49
"AY0",
50
"AY1",
51
"AY2",
52
"B",
53
"CH",
54
"D",
55
"DH",
56
"EH",
57
"EH0",
58
"EH1",
59
"EH2",
60
"ER",
61
"ER0",
62
"ER1",
63
"ER2",
64
"EY",
65
"EY0",
66
"EY1",
67
"EY2",
68
"F",
69
"G",
70
"HH",
71
"IH",
72
"IH0",
73
"IH1",
74
"IH2",
75
"IY",
76
"IY0",
77
"IY1",
78
"IY2",
79
"JH",
80
"K",
81
"L",
82
"M",
83
"N",
84
"NG",
85
"OW",
86
"OW0",
87
"OW1",
88
"OW2",
89
"OY",
90
"OY0",
91
"OY1",
92
"OY2",
93
"P",
94
"R",
95
"S",
96
"SH",
97
"T",
98
"TH",
99
"UH",
100
"UH0",
101
"UH1",
102
"UH2",
103
"UW",
104
"UW0",
105
"UW1",
106
"UW2",
107
"V",
108
"W",
109
"Y",
110
"Z",
111
"ZH",
112
]
113
114
_pad = "pad"
115
_eos = "eos"
116
_punctuation = "!'(),.:;? "
117
_special = "-"
118
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
119
120
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
121
_arpabet = ["@" + s for s in valid_symbols]
122
123
# Export all symbols:
124
LJSPEECH_SYMBOLS = (
125
[_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet + [_eos]
126
)
127
128
# Regular expression matching text enclosed in curly braces:
129
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
130
131
132
@dataclass
133
class LJSpeechProcessor(BaseProcessor):
134
"""LJSpeech processor."""
135
136
cleaner_names: str = "english_cleaners"
137
positions = {
138
"wave_file": 0,
139
"text": 1,
140
"text_norm": 2,
141
}
142
train_f_name: str = "metadata.csv"
143
144
def create_items(self):
145
if self.data_dir:
146
with open(
147
os.path.join(self.data_dir, self.train_f_name), encoding="utf-8"
148
) as f:
149
self.items = [self.split_line(self.data_dir, line, "|") for line in f]
150
151
def split_line(self, data_dir, line, split):
152
parts = line.strip().split(split)
153
wave_file = parts[self.positions["wave_file"]]
154
text_norm = parts[self.positions["text_norm"]]
155
wav_path = os.path.join(data_dir, "wavs", f"{wave_file}.wav")
156
speaker_name = "ljspeech"
157
return text_norm, wav_path, speaker_name
158
159
def setup_eos_token(self):
160
return _eos
161
162
def save_pretrained(self, saved_path):
163
os.makedirs(saved_path, exist_ok=True)
164
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})
165
166
def get_one_sample(self, item):
167
text, wav_path, speaker_name = item
168
169
# normalize audio signal to be [-1, 1], soundfile already norm.
170
audio, rate = sf.read(wav_path)
171
audio = audio.astype(np.float32)
172
173
# convert text to ids
174
text_ids = np.asarray(self.text_to_sequence(text), np.int32)
175
176
sample = {
177
"raw_text": text,
178
"text_ids": text_ids,
179
"audio": audio,
180
"utt_id": os.path.split(wav_path)[-1].split(".")[0],
181
"speaker_name": speaker_name,
182
"rate": rate,
183
}
184
185
return sample
186
187
def text_to_sequence(self, text):
188
sequence = []
189
# Check for curly braces and treat their contents as ARPAbet:
190
while len(text):
191
m = _curly_re.match(text)
192
if not m:
193
sequence += self._symbols_to_sequence(
194
self._clean_text(text, [self.cleaner_names])
195
)
196
break
197
sequence += self._symbols_to_sequence(
198
self._clean_text(m.group(1), [self.cleaner_names])
199
)
200
sequence += self._arpabet_to_sequence(m.group(2))
201
text = m.group(3)
202
203
# add eos tokens
204
sequence += [self.eos_id]
205
return sequence
206
207
def _clean_text(self, text, cleaner_names):
208
for name in cleaner_names:
209
cleaner = getattr(cleaners, name)
210
if not cleaner:
211
raise Exception("Unknown cleaner: %s" % name)
212
text = cleaner(text)
213
return text
214
215
def _symbols_to_sequence(self, symbols):
216
return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]
217
218
def _arpabet_to_sequence(self, text):
219
return self._symbols_to_sequence(["@" + s for s in text.split()])
220
221
def _should_keep_symbol(self, s):
222
return s in self.symbol_to_id and s != "_" and s != "~"
223
224