Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/training/train_pipeline.py
694 views
1
from utils.hparams import hparams
2
import torch
3
from torch.nn import functional as F
4
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
5
6
class Batch2Loss:
7
'''
8
pipeline: batch -> insert1 -> module1 -> insert2 -> module2 -> insert3 -> module3 -> insert4 -> module4 -> loss
9
'''
10
11
@staticmethod
12
def insert1(pitch_midi, midi_dur, is_slur, # variables
13
midi_embed, midi_dur_layer, is_slur_embed): # modules
14
'''
15
add embeddings for midi, midi_dur, slur
16
'''
17
midi_embedding = midi_embed(pitch_midi)
18
midi_dur_embedding, slur_embedding = 0, 0
19
if midi_dur is not None:
20
midi_dur_embedding = midi_dur_layer(midi_dur[:, :, None]) # [B, T, 1] -> [B, T, H]
21
if is_slur is not None:
22
slur_embedding = is_slur_embed(is_slur)
23
return midi_embedding, midi_dur_embedding, slur_embedding
24
25
@staticmethod
26
def module1(fs2_encoder, # modules
27
txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding): # variables
28
'''
29
get *encoder_out* == fs2_encoder(*txt_tokens*, some embeddings)
30
'''
31
encoder_out = fs2_encoder(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding)
32
return encoder_out
33
34
@staticmethod
def insert2(encoder_out, spk_embed_id, spk_embed_dur_id, spk_embed_f0_id, src_nonpadding, # variables
            spk_embed_proj): # modules
    '''
    1. add embeddings for spk, spk_dur, spk_f0
    2. get *dur_inp* ~= *encoder_out* + *spk_embed_dur*
    '''
    # add ref style embed
    # Not implemented
    # variance encoder
    var_embed = 0

    # encoder_out_dur denotes encoder outputs for duration predictor
    # in speech adaptation, duration predictor use old speaker embedding
    if hparams['use_spk_embed']:
        # one projected speaker embedding [B, 1, H] shared by duration, f0 and decoder
        spk_embed_dur = spk_embed_f0 = spk_embed = spk_embed_proj(spk_embed_id)[:, None, :]
    elif hparams['use_spk_id']:
        # fall back to the main speaker id when no separate dur/f0 ids are given
        if spk_embed_dur_id is None:
            spk_embed_dur_id = spk_embed_id
        if spk_embed_f0_id is None:
            spk_embed_f0_id = spk_embed_id
        spk_embed = spk_embed_proj(spk_embed_id)[:, None, :]
        spk_embed_dur = spk_embed_f0 = spk_embed
        if hparams['use_split_spk_id']:
            # NOTE(review): spk_embed_dur / spk_embed_f0 were just bound to the
            # *tensor* spk_embed above, yet are called here like embedding
            # modules. In the upstream FastSpeech2 code these are separate
            # nn.Embedding layers on the model — confirm whether this branch
            # is ever enabled, as written it would raise at runtime.
            spk_embed_dur = spk_embed_dur(spk_embed_dur_id)[:, None, :]
            spk_embed_f0 = spk_embed_f0(spk_embed_f0_id)[:, None, :]
    else:
        # no speaker conditioning: scalar zeros keep the additions below no-ops
        spk_embed_dur = spk_embed_f0 = spk_embed = 0

    # add dur: duration-predictor input, with padded positions zeroed
    dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
    return var_embed, spk_embed, spk_embed_dur, spk_embed_f0, dur_inp
66
67
@staticmethod
def module2(dur_predictor, length_regulator, # modules
            dur_input, mel2ph, txt_tokens, all_vowel_tokens, ret, midi_dur=None): # variables
    '''
    1. get *dur* ~= dur_predictor(*dur_inp*)
    2. (mel2ph is None): get *mel2ph* ~= length_regulator(*dur*)
    '''
    # token id 0 marks padding positions
    src_padding = (txt_tokens == 0)
    # straight-through trick: forward value unchanged, gradient scaled by predictor_grad
    dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())

    if mel2ph is None:
        # inference path: predict durations, then snap them to the midi grid
        dur, xs = dur_predictor.inference(dur_input, src_padding)
        ret['dur'] = xs
        # predictor outputs log(dur + 1); invert to frame counts
        dur = xs.squeeze(-1).exp() - 1.0
        # align vowel durations with the note (midi) durations:
        # a vowel followed by a non-vowel gives up the consonant's duration;
        # otherwise the vowel simply takes the full note duration.
        # NOTE(review): assumes midi_dur is not None on this path — TODO confirm callers.
        for i in range(len(dur)):
            for j in range(len(dur[i])):
                if txt_tokens[i,j] in all_vowel_tokens:
                    if j < len(dur[i])-1 and txt_tokens[i,j+1] not in all_vowel_tokens:
                        dur[i,j] = midi_dur[i,j] - dur[i,j+1]
                        if dur[i,j] < 0:
                            # consonant longer than the note: give it the whole note
                            dur[i,j] = 0
                            dur[i,j+1] = midi_dur[i,j]
                    else:
                        dur[i,j]=midi_dur[i,j]
        # +0.5 on the first token so the cumulative rounding below rounds up
        dur[:,0] = dur[:,0] + 0.5
        # round cumulative sums, then difference back to per-token integer frames
        # (keeps the total length consistent despite per-token rounding)
        dur_acc = F.pad(torch.round(torch.cumsum(dur, axis=1)), (1,0))
        dur = torch.clamp(dur_acc[:,1:]-dur_acc[:,:-1], min=0).long()
        ret['dur_choice'] = dur
        mel2ph = length_regulator(dur, src_padding).detach()
    else:
        # training path: ground-truth mel2ph is given, only predict durations for the loss
        ret['dur'] = dur_predictor(dur_input, src_padding)
    ret['mel2ph'] = mel2ph

    return mel2ph
101
102
@staticmethod
103
def insert3(encoder_out, mel2ph, var_embed, spk_embed_f0, src_nonpadding, tgt_nonpadding): # variables
104
'''
105
1. get *decoder_inp* ~= gather *encoder_out* according to *mel2ph*
106
2. get *pitch_inp* ~= *decoder_inp* + *spk_embed_f0*
107
3. get *pitch_inp_ph* ~= *encoder_out* + *spk_embed_f0*
108
'''
109
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
110
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
111
decoder_inp = decoder_inp_origin = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
112
113
pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
114
pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
115
return decoder_inp, pitch_inp, pitch_inp_ph
116
117
@staticmethod
def module3(pitch_predictor, pitch_embed, energy_predictor, energy_embed, # modules
            pitch_inp, pitch_inp_ph, f0, uv, energy, mel2ph, is_training, ret): # variables
    '''
    1. get *ret['pitch_pred']*, *ret['energy_pred']* ~= pitch_predictor(*pitch_inp*), energy_predictor(*pitch_inp*)
    2. get *pitch_embedding* ~= pitch_embed(f0_to_coarse(denorm_f0(*f0* or *pitch_pred*))
    3. get *energy_embedding* ~= energy_embed(energy_to_coarse(*energy* or *energy_pred*))
    '''
    def add_pitch(decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
        # Predict (or take the given) f0, denormalize it, quantize it to
        # coarse pitch bins and return the pitch embedding.
        if hparams['pitch_type'] == 'ph':
            # phone-level pitch: predict on encoder_out, then expand to frames via mel2ph
            # straight-through trick: forward value unchanged, gradient scaled by predictor_grad
            pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
            # NOTE(review): this reduces over the whole batch, so pitch_padding
            # is a scalar (all-or-nothing), not a per-position mask — confirm intended.
            pitch_padding = (encoder_out.sum().abs() == 0)
            ret['pitch_pred'] = pitch_pred = pitch_predictor(pitch_pred_inp)
            if f0 is None:
                # no ground truth: use the predicted f0 channel
                f0 = pitch_pred[:, :, 0]
            ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
            pitch = f0_to_coarse(f0_denorm)  # start from 0 [B, T_txt]
            # pad a zero bin so mel2ph == 0 (padding) selects it, then expand to frames
            pitch = F.pad(pitch, [1, 0])
            pitch = torch.gather(pitch, 1, mel2ph)  # [B, T_mel]
            pitch_embedding = pitch_embed(pitch)
            return pitch_embedding

        # frame-level pitch types below
        # straight-through gradient scaling, as above
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())

        pitch_padding = (mel2ph == 0)

        if hparams['pitch_type'] == 'cwt':
            # NOTE: this part of script is *isolated* from other scripts, which means
            # it may not be compatible with the current version.
            pass
            # pitch_padding = None
            # ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
            # stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2]
            # mean = ret['f0_mean'] = stats_out[:, 0]
            # std = ret['f0_std'] = stats_out[:, 1]
            # cwt_spec = cwt_out[:, :, :10]
            # if f0 is None:
            #     std = std * hparams['cwt_std_scale']
            #     f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
            #     if hparams['use_uv']:
            #         assert cwt_out.shape[-1] == 11
            #         uv = cwt_out[:, :, -1] > 0
        elif hparams['pitch_ar']:
            # autoregressive predictor: teacher-force with ground-truth f0 when training
            ret['pitch_pred'] = pitch_pred = pitch_predictor(decoder_inp, f0 if is_training else None)
            if f0 is None:
                f0 = pitch_pred[:, :, 0]
        else:
            ret['pitch_pred'] = pitch_pred = pitch_predictor(decoder_inp)
            if f0 is None:
                f0 = pitch_pred[:, :, 0]
            if hparams['use_uv'] and uv is None:
                # channel 1 of the prediction is the unvoiced logit
                uv = pitch_pred[:, :, 1] > 0
        ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
        if pitch_padding is not None:
            # zero padded frames in-place (f0_denorm above already masked via pitch_padding)
            f0[pitch_padding] = 0

        pitch = f0_to_coarse(f0_denorm)  # start from 0
        pitch_embedding = pitch_embed(pitch)
        return pitch_embedding

    def add_energy(decoder_inp, energy, ret):
        # Predict (or take the given) energy, quantize to 0..255 and embed it.
        # straight-through gradient scaling, as in add_pitch
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy_pred = energy_predictor(decoder_inp)[:, :, 0]
        if energy is None:
            energy = energy_pred
        energy = torch.clamp(energy * 256 // 4, max=255).long()  # energy_to_coarse
        energy_embedding = energy_embed(energy)
        return energy_embedding

    # add pitch and energy embed
    nframes = mel2ph.size(1)

    pitch_embedding = 0
    if hparams['use_pitch_embed']:
        # pad f0/uv to nframes by repeating each row's last value, then truncate
        if f0 is not None:
            delta_l = nframes - f0.size(1)
            if delta_l > 0:
                f0 = torch.cat((f0,torch.FloatTensor([[x[-1]] * delta_l for x in f0]).to(f0.device)),1)
            f0 = f0[:,:nframes]
        if uv is not None:
            delta_l = nframes - uv.size(1)
            if delta_l > 0:
                uv = torch.cat((uv,torch.FloatTensor([[x[-1]] * delta_l for x in uv]).to(uv.device)),1)
            uv = uv[:,:nframes]
        pitch_embedding = add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)

    energy_embedding = 0
    if hparams['use_energy_embed']:
        # same last-value padding / truncation for energy
        if energy is not None:
            delta_l = nframes - energy.size(1)
            if delta_l > 0:
                energy = torch.cat((energy,torch.FloatTensor([[x[-1]] * delta_l for x in energy]).to(energy.device)),1)
            energy = energy[:,:nframes]
        energy_embedding = add_energy(pitch_inp, energy, ret)

    return pitch_embedding, energy_embedding
213
214
@staticmethod
215
def insert4(decoder_inp, pitch_embedding, energy_embedding, spk_embed, ret, tgt_nonpadding):
216
'''
217
*decoder_inp* ~= *decoder_inp* + embeddings for spk, pitch, energy
218
'''
219
ret['decoder_inp'] = decoder_inp = (decoder_inp + pitch_embedding + energy_embedding + spk_embed) * tgt_nonpadding
220
return decoder_inp
221
222
@staticmethod
223
def module4(diff_main_loss, # modules
224
norm_spec, decoder_inp_t, ret, K_step, batch_size, device): # variables
225
'''
226
training diffusion using spec as input and decoder_inp as condition.
227
228
Args:
229
norm_spec: (normalized) spec
230
decoder_inp_t: (transposed) decoder_inp
231
Returns:
232
ret['diff_loss']
233
'''
234
t = torch.randint(0, K_step, (batch_size,), device=device).long()
235
norm_spec = norm_spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
236
ret['diff_loss'] = diff_main_loss(norm_spec, t, cond=decoder_inp_t)
237
# nonpadding = (mel2ph != 0).float()
238
# ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
239
240