Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TensorSpeech
GitHub Repository: TensorSpeech/TensorFlowTTS
Path: blob/master/tensorflow_tts/processor/kss.py
1558 views
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perform preprocessing and raw feature extraction for KSS dataset."""
import os
import re

import numpy as np
import soundfile as sf
from dataclasses import dataclass

from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.korean import symbols as KSS_SYMBOLS
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

# Regular expression matching text enclosed in curly braces, splitting a
# string into (prefix, braced content, remainder):
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
@dataclass
class KSSProcessor(BaseProcessor):
    """KSS processor.

    Turns lines of the KSS transcript file into
    ``(text_norm, wav_path, speaker_name)`` items and converts Korean
    text into sequences of symbol ids.
    """

    # Name of the cleaner pipeline applied to transcript text.
    cleaner_names: str = "korean_cleaners"
    # Column indices inside a "|"-separated transcript line.
    # NOTE(review): column 1 is skipped — presumably the raw
    # (unnormalized) text; verify against the transcript format.
    positions = {
        "wave_file": 0,
        "text_norm": 2,
    }
    # Transcript file name, relative to ``self.data_dir``.
    train_f_name: str = "transcript.v.1.4.txt"

    def create_items(self):
        """Read the transcript file and populate ``self.items``.

        Each item is a ``(text_norm, wav_path, speaker_name)`` tuple.
        Blank lines (e.g. a trailing newline at end of file) are skipped
        so they cannot raise ``IndexError`` in :meth:`split_line`.
        """
        if self.data_dir:
            with open(
                os.path.join(self.data_dir, self.train_f_name), encoding="utf-8"
            ) as f:
                self.items = [
                    self.split_line(self.data_dir, line, "|")
                    for line in f
                    if line.strip()
                ]

    def split_line(self, data_dir, line, split):
        """Split one transcript line into ``(text_norm, wav_path, speaker_name)``.

        Args:
            data_dir: dataset root directory; wav paths resolve under
                ``data_dir/kss/``.
            line: one raw transcript line.
            split: field separator (``"|"`` for KSS).
        """
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text_norm = parts[self.positions["text_norm"]]
        wav_path = os.path.join(data_dir, "kss", wave_file)
        speaker_name = "kss"  # single fixed speaker name for this dataset
        return text_norm, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sentence token appended to every sequence."""
        return "eos"

    def save_pretrained(self, saved_path):
        """Write the processor mapper file into ``saved_path`` (created if needed)."""
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

    def get_one_sample(self, item):
        """Load audio and encode text for a single item.

        Args:
            item: ``(text, wav_path, speaker_name)`` tuple as produced by
                :meth:`split_line`.

        Returns:
            dict with raw text, int32 symbol ids, float32 audio samples,
            utterance id, speaker name and sample rate.
        """
        text, wav_path, speaker_name = item

        # soundfile already returns samples normalized to [-1, 1].
        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        # convert text to ids
        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            # utterance id: wav file name without directory or extension
            "utt_id": os.path.split(wav_path)[-1].split(".")[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }

        return sample

    def text_to_sequence(self, text):
        """Convert ``text`` to a list of symbol ids, ending with the eos id.

        Substrings wrapped in curly braces are treated as space-separated
        ARPAbet tokens; all other text is run through the configured
        cleaner pipeline before symbol lookup.
        """
        sequence = []
        # Check for curly braces and treat their contents as ARPAbet:
        while len(text):
            m = _curly_re.match(text)
            if not m:
                # No braces left: clean and encode the whole remainder.
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        # add eos tokens
        sequence += [self.eos_id]
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Apply each named cleaner from ``tensorflow_tts.utils.cleaners`` in order.

        Raises:
            Exception: if a name does not resolve to a cleaner function.
        """
        for name in cleaner_names:
            # Default to None so an unknown name reaches the explicit
            # "Unknown cleaner" error instead of a bare AttributeError.
            cleaner = getattr(cleaners, name, None)
            if not cleaner:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        """Map symbols to ids, silently dropping unknown/skipped symbols."""
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _arpabet_to_sequence(self, text):
        """Encode space-separated ARPAbet tokens, each prefixed with '@'."""
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        """Keep only known symbols, excluding '_' and '~' (special markers)."""
        return s in self.symbol_to_id and s != "_" and s != "~"