Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/preprocessing/process_pipeline.py
694 views
'''
file -> temporary_dict -> processed_input -> batch
'''
from utils.hparams import hparams
from network.vocoders.base_vocoder import VOCODERS
import numpy as np
import traceback
from pathlib import Path
from .data_gen_utils import get_pitch_parselmouth,get_pitch_crepe
from .base_binarizer import BinarizationError
import torch
import utils

class File2Batch:
    """
    Data-preparation pipeline: file -> temporary_dict -> processed_input -> batch.

    Each stage is a static method:
      * file2temporary_dict            — scan the raw data dir for audio files
      * temporary_dict2processed_input — extract mel / f0 / hubert features for one item
      * processed_input2batch          — collate a list of processed items into tensors
    """

    @staticmethod
    def file2temporary_dict():
        """Read from file, store data in temporary dicts.

        Recursively collects every ``*.wav`` and ``*.ogg`` under
        ``hparams['raw_data_dir']``.

        Returns:
            dict: maps item_name (the file path as a string) to a temp dict
            holding ``wav_fn`` (same path) and ``spk_id`` from hparams.
        """
        raw_data_dir = Path(hparams['raw_data_dir'])
        utterance_labels = []
        utterance_labels.extend(raw_data_dir.rglob("*.wav"))
        utterance_labels.extend(raw_data_dir.rglob("*.ogg"))

        all_temp_dict = {}
        for utterance_label in utterance_labels:
            # The absolute/relative path doubles as the unique item name.
            item_name = str(utterance_label)
            temp_dict = {
                'wav_fn': str(utterance_label),
                'spk_id': hparams['speaker_id'],
            }
            all_temp_dict[item_name] = temp_dict

        return all_temp_dict

    @staticmethod
    def temporary_dict2processed_input(item_name, temp_dict, encoder, binarization_args):
        """Process data in a temporary dict into a full feature dict.

        Args:
            item_name: unique name of the item (the wav path).
            temp_dict: output of :meth:`file2temporary_dict` for this item.
            encoder: hubert encoder object exposing ``encode(wav_fn)``.
            binarization_args: dict of flags ``with_f0`` / ``with_hubert`` /
                ``with_align`` selecting which features to compute.

        Returns:
            dict with mel, wav, f0/pitch, hubert and alignment features,
            or ``None`` if the item failed (a skip message is printed).
        """

        def get_pitch(wav, mel):
            # Ground-truth f0: crepe or parselmouth, selected by config.
            if hparams['use_crepe']:
                gt_f0, gt_pitch_coarse = get_pitch_crepe(wav, mel, hparams)
            else:
                gt_f0, gt_pitch_coarse = get_pitch_parselmouth(wav, mel, hparams)
            # An all-zero f0 track means pitch extraction found nothing voiced.
            if sum(gt_f0) == 0:
                raise BinarizationError("Empty **gt** f0")
            processed_input['f0'] = gt_f0
            processed_input['pitch'] = gt_pitch_coarse

        def get_align(meta_data, mel, phone_encoded, hop_size=hparams['hop_size'], audio_sample_rate=hparams['audio_sample_rate']):
            # Uniform alignment: spread mel frames evenly over the hubert
            # units (ph_durs frames per unit); mel2ph[frame] = 1-based unit index.
            mel2ph = np.zeros([mel.shape[0]], int)
            start_frame = 0
            ph_durs = mel.shape[0] / phone_encoded.shape[0]
            if hparams['debug']:
                print(mel.shape, phone_encoded.shape, mel.shape[0] / phone_encoded.shape[0])
            for i_ph in range(phone_encoded.shape[0]):
                end_frame = int(i_ph * ph_durs + ph_durs + 0.5)
                mel2ph[start_frame:end_frame + 1] = i_ph + 1
                start_frame = end_frame + 1
            processed_input['mel2ph'] = mel2ph

        # Decode audio to waveform + mel with the configured vocoder. The
        # fallback handles fully-qualified vocoder names like "pkg.ClassName".
        if hparams['vocoder'] in VOCODERS:
            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(temp_dict['wav_fn'])
        else:
            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(temp_dict['wav_fn'])
        processed_input = {
            'item_name': item_name, 'mel': mel, 'wav': wav,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0]
        }
        processed_input = {**temp_dict, **processed_input}  # merge two dicts
        # Per-bin mel extrema, kept for spectrogram normalisation statistics.
        processed_input['spec_min'] = np.min(mel, axis=0)
        processed_input['spec_max'] = np.max(mel, axis=0)
        try:
            if binarization_args['with_f0']:
                get_pitch(wav, mel)
            hubert_encoded = None
            if binarization_args['with_hubert']:
                try:
                    hubert_encoded = processed_input['hubert'] = encoder.encode(temp_dict['wav_fn'])
                except Exception:
                    traceback.print_exc()
                    raise Exception("hubert encode error")
            if binarization_args['with_align']:
                # Alignment is derived from hubert units, so with_hubert must
                # also be enabled; fail with a clear message instead of NameError.
                if hubert_encoded is None:
                    raise BinarizationError("with_align requires with_hubert")
                get_align(temp_dict, mel, hubert_encoded)
        except Exception as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {temp_dict['wav_fn']}")
            return None
        return processed_input

    @staticmethod
    def processed_input2batch(samples):
        """Collate processed items into one padded training batch.

        Args:
            samples: one batch of processed_input dicts.

        NOTE:
            the batch size is controlled by hparams['max_sentences']

        Returns:
            dict of batched tensors (empty dict for an empty sample list).
        """
        if not samples:
            return {}
        # `ids` avoids shadowing the builtin `id`; the batch key stays 'id'.
        ids = torch.LongTensor([s['id'] for s in samples])
        item_names = [s['item_name'] for s in samples]
        # Variable-length features are zero-padded to the batch maximum.
        hubert = utils.collate_2d([s['hubert'] for s in samples], 0.0)
        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch = utils.collate_1d([s['pitch'] for s in samples])
        uv = utils.collate_1d([s['uv'] for s in samples])
        energy = utils.collate_1d([s['energy'] for s in samples], 0.0)
        mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
            if samples[0]['mel2ph'] is not None else None
        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
        hubert_lengths = torch.LongTensor([s['hubert'].shape[0] for s in samples])
        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])

        batch = {
            'id': ids,
            'item_name': item_names,
            'nsamples': len(samples),
            'hubert': hubert,
            'mels': mels,
            'mel_lengths': mel_lengths,
            'mel2ph': mel2ph,
            'energy': energy,
            'pitch': pitch,
            'f0': f0,
            'uv': uv,
        }
        return batch