Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/preprocessing/base_binarizer.py
694 views
1
import os
2
from webbrowser import get
3
os.environ["OMP_NUM_THREADS"] = "1"
4
import yaml
5
from utils.multiprocess_utils import chunked_multiprocess_run
6
import random
7
import json
8
# from resemblyzer import VoiceEncoder
9
from tqdm import tqdm
10
from preprocessing.data_gen_utils import get_mel2ph, get_pitch_parselmouth, build_phone_encoder,get_pitch_crepe
11
from utils.hparams import set_hparams, hparams
12
import numpy as np
13
from utils.indexed_datasets import IndexedDatasetBuilder
14
15
16
class BinarizationError(Exception):
    '''Raised when a data item cannot be binarized.'''
18
19
# Attribute names an item dict is allowed to carry by default.
BASE_ITEM_ATTRIBUTES = [
    'txt',
    'ph',
    'wav_fn',
    'tg_fn',
    'spk_id',
]
20
21
class BaseBinarizer:
    '''
    Base class for data processing.

    1. *process* and *process_data_split*:
        process the entire data set, split by split ('valid', 'test', 'train');
    2. *process_item*:
        process a single piece of data;
    3. *get_pitch_algorithm*:
        the pitch extraction function, chosen from hparams['use_crepe'];
    4. *get_align*:
        get the alignment using 'mel2ph' format (see https://arxiv.org/abs/1905.09263).
    5. phoneme encoder, speaker map, etc.

    Subclasses should define:
    1. *load_meta_data*:
        how to read the dataset(s) from files into self.items;
    2. *train_item_names*, *valid_item_names*, *test_item_names*:
        how to split the dataset;
    3. *_phone_encoder*:
        how to build the encoder stored in self.phone_encoder;
    4. *load_ph_set*:
        the phoneme set.
    '''

    def __init__(self, item_attributes=None):
        '''
        Load the metadata and fix the order in which items are processed.

        :param item_attributes: names an item dict is allowed to use as keys.
            Defaults to a fresh copy of BASE_ITEM_ATTRIBUTES, so instances
            never share (and cannot accidentally mutate) the module constant.
        :raises BinarizationError: if load_meta_data() leaves self.items empty.
        :raises NotImplementedError: if the subclass did not override
            load_meta_data().
        '''
        self.binarization_args = hparams['binarization_args']

        self.items = {}
        # every item in self.items has some attributes
        if item_attributes is None:
            item_attributes = list(BASE_ITEM_ATTRIBUTES)
        self.item_attributes = item_attributes

        self.load_meta_data()
        if not self.items:
            # Previously this surfaced as a bare IndexError from the check below.
            raise BinarizationError(
                "load_meta_data() did not populate self.items; nothing to binarize")

        # check program correctness: the keys of an item dict may only take
        # values from the given attribute list (only the first item is inspected)
        first_item = next(iter(self.items.values()))
        assert all(attr in self.item_attributes for attr in first_item.keys())
        self.item_names = sorted(self.items.keys())

        if self.binarization_args['shuffle']:
            # fixed seed keeps the shuffled order reproducible between runs
            random.seed(1234)
            random.shuffle(self.item_names)

        # set default get_pitch algorithm
        if hparams['use_crepe']:
            self.get_pitch_algorithm = get_pitch_crepe
        else:
            self.get_pitch_algorithm = get_pitch_parselmouth

    def load_meta_data(self):
        '''Fill self.items; must be overridden by subclasses.'''
        raise NotImplementedError

    @property
    def train_item_names(self):
        '''Item names of the training split; must be overridden.'''
        raise NotImplementedError

    @property
    def valid_item_names(self):
        '''Item names of the validation split; must be overridden.'''
        raise NotImplementedError

    @property
    def test_item_names(self):
        '''Item names of the test split; must be overridden.'''
        raise NotImplementedError

    def build_spk_map(self):
        '''
        Map every distinct 'spk_id' found in the items to an integer index.

        Indices follow the sorted order of the speaker names.

        :return: dict {speaker name: index}
        '''
        speakers = {self.items[item_name]['spk_id'] for item_name in self.item_names}
        spk_map = {spk_name: idx for idx, spk_name in enumerate(sorted(speakers))}
        # hparams['num_spk'] is only consulted when at least one speaker exists
        assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
        return spk_map

    def item_name2spk_id(self, item_name):
        '''Return the integer speaker index of *item_name* (needs self.spk_map).'''
        return self.spk_map[self.items[item_name]['spk_id']]

    def _phone_encoder(self):
        '''
        Build the encoder stored in self.phone_encoder (use hubert encoder).

        Must be overridden by subclasses. The phone-set based code that used
        to follow the raise was unreachable and has been dropped.
        '''
        raise NotImplementedError

    def load_ph_set(self, ph_set):
        '''Collect the phoneme set into *ph_set*; must be overridden.'''
        raise NotImplementedError

    def meta_data_iterator(self, prefix):
        '''
        Yield (item_name, meta_data) pairs for one split.

        :param prefix: 'valid' or 'test'; any other value selects the
            training split.
        '''
        if prefix == 'valid':
            item_names = self.valid_item_names
        elif prefix == 'test':
            item_names = self.test_item_names
        else:
            item_names = self.train_item_names
        for item_name in item_names:
            yield item_name, self.items[item_name]

    def process(self):
        '''
        Binarize the whole data set.

        Creates hparams['binary_data_dir'], writes 'spk_map.json' into it,
        builds the phone encoder and then processes the 'valid', 'test' and
        'train' splits in that order.
        '''
        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
        self.spk_map = self.build_spk_map()
        print("| spk_map: ", self.spk_map)
        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
        # context manager guarantees the file is flushed and closed
        with open(spk_map_fn, 'w', encoding='utf-8') as f:
            json.dump(self.spk_map, f)

        self.phone_encoder = self._phone_encoder()
        self.process_data_split('valid')
        self.process_data_split('test')
        self.process_data_split('train')

    def process_data_split(self, prefix):
        '''
        Process every item of one split and write the binary data set.

        Writes the indexed data set '<binary_data_dir>/<prefix>' and
        '<binary_data_dir>/<prefix>_lengths.npy'. For the 'train' split the
        element-wise minimum / maximum over the items' 'spec_min' /
        'spec_max' entries are additionally written back into the YAML file
        at hparams['config_path'] as 'spec_min' / 'spec_max'.
        No f0 statistics file is written.

        :param prefix: split name, see meta_data_iterator().
        :raises BinarizationError: if the 'train' split yields no item, so
            that no spectrogram statistics can be computed.
        '''
        data_dir = hparams['binary_data_dir']
        builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
        lengths = []
        total_sec = 0
        spec_min = []
        spec_max = []

        args = [
            [item_name, meta_data, self.binarization_args]
            for item_name, meta_data in self.meta_data_iterator(prefix)
        ]

        # code for single cpu processing; items are visited last-to-first
        for item_args in tqdm(reversed(args), total=len(args)):
            item = self.process_item(*item_args)
            if item is None:
                continue
            spec_min.append(item['spec_min'])
            spec_max.append(item['spec_max'])
            if not self.binarization_args['with_wav'] and 'wav' in item:
                if hparams['debug']:
                    print("del wav")
                del item['wav']
            if hparams['debug']:
                print(item)
            builder.add_item(item)
            lengths.append(item['len'])
            total_sec += item['sec']

        if prefix == 'train':
            if not spec_max:
                # np.max / np.min on an empty list fail with an opaque ValueError
                raise BinarizationError(
                    "No training item was processed; cannot compute spec_min/spec_max")
            spec_max = np.max(spec_max, 0)
            spec_min = np.min(spec_min, 0)
            print(spec_max.shape)
            with open(hparams['config_path'], encoding='utf-8') as f:
                _hparams = yaml.safe_load(f)
                _hparams['spec_max'] = spec_max.tolist()
                _hparams['spec_min'] = spec_min.tolist()
            with open(hparams['config_path'], 'w', encoding='utf-8') as f:
                yaml.safe_dump(_hparams, f)

        builder.finalize()
        np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
        print(f"| {prefix} total duration: {total_sec:.3f}s")

    def process_item(self, item_name, meta_data, binarization_args):
        '''
        Turn one metadata entry into a processed item.

        Delegates to File2Batch.temporary_dict2processed_input, passing
        self.phone_encoder along. process_data_split() skips the item when
        this returns None.
        '''
        # imported here rather than at module level, as in the original code
        from preprocessing.process_pipeline import File2Batch
        return File2Batch.temporary_dict2processed_input(
            item_name, meta_data, self.phone_encoder, binarization_args)

    def get_align(self, meta_data, mel, phone_encoded, res):
        '''Compute the alignment for one item; must be overridden.'''
        raise NotImplementedError

    def get_align_from_textgrid(self, meta_data, mel, phone_encoded, res):
        '''
        Disabled: returns None without touching *res*.

        NOTE: this part of script is *isolated* from other scripts, which means
              it may not be compatible with the current version.
        The TextGrid based alignment (get_mel2ph on meta_data['tg_fn'],
        filling res['mel2ph'] and res['dur']) sat behind an unconditional
        return and was therefore never executed; it has been removed.
        '''
        return None

    def get_f0cwt(self, f0, res):
        '''
        Disabled: returns None without touching *res*.

        NOTE: this part of script is *isolated* from other scripts, which means
              it may not be compatible with the current version.
        The continuous-wavelet f0 decomposition (filling res['cwt_spec'],
        res['cwt_scales'], res['f0_mean'] and res['f0_std']) sat behind an
        unconditional return and was therefore never executed; it has been
        removed.
        '''
        return None
233
234
235
if __name__ == "__main__":
    # Load the hyper-parameters first: the binarizer reads `hparams` in its
    # constructor.
    set_hparams()
    # NOTE: BaseBinarizer.load_meta_data raises NotImplementedError, so this
    # entry point only completes when a concrete subclass is used instead.
    binarizer = BaseBinarizer()
    binarizer.process()
238
239