Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/fastspeech/fs2.py
694 views
1
from modules.commons.common_layers import *
2
from modules.commons.common_layers import Embedding
3
from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
4
EnergyPredictor, FastspeechEncoder
5
from utils.cwt import cwt2f0
6
from utils.hparams import hparams
7
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
8
9
def _fft_encoder_factory(hp):
    """Build a ``FastspeechEncoder`` from the hyper-parameter dict *hp*."""
    return FastspeechEncoder(
        hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
        num_heads=hp['num_heads'])


# Registry mapping hparams['encoder_type'] to an encoder factory.
FS_ENCODERS = {'fft': _fft_encoder_factory}
14
15
def _fft_decoder_factory(hp):
    """Build a ``FastspeechDecoder`` from the hyper-parameter dict *hp*."""
    return FastspeechDecoder(
        hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'],
        hp['num_heads'])


# Registry mapping hparams['decoder_type'] to a decoder factory.
FS_DECODERS = {'fft': _fft_decoder_factory}
19
20
21
class FastSpeech2(nn.Module):
    """FastSpeech2 backbone as adapted for diff-svc.

    Unlike the original FastSpeech2, the encoder input is a sequence of
    (hubert) feature frames rather than phoneme ids, and the duration
    predictor path is disabled: callers must supply ``mel2ph`` themselves.
    All configuration is read from the global ``hparams`` dict.
    """

    def __init__(self, dictionary, out_dims=None):
        """Build the model from ``hparams``.

        dictionary: unused here (only referenced by the commented-out
            ``build_embedding`` helper below); kept for interface
            compatibility.
        out_dims: mel output dimensionality; falls back to
            ``hparams['audio_num_mel_bins']`` when None.
        """
        super().__init__()
        # self.dictionary = dictionary
        self.padding_idx = 0
        # Conditional expression parses as:
        #   (not hparams['no_fs2']) if 'no_fs2' in hparams else True
        # i.e. build the encoder/decoder stack unless 'no_fs2' is set truthy
        # (equivalent to `not hparams.get('no_fs2', False)`).
        if not hparams['no_fs2'] if 'no_fs2' in hparams.keys() else True:
            self.enc_layers = hparams['enc_layers']
            self.dec_layers = hparams['dec_layers']
            self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams)
            self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
        self.hidden_size = hparams['hidden_size']
        # self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
        self.out_dims = out_dims
        if out_dims is None:
            self.out_dims = hparams['audio_num_mel_bins']
        # Final projection from hidden features to mel bins.
        self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
        #=========not used===========
        # if hparams['use_spk_id']:
        #     self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        #     if hparams['use_split_spk_id']:
        #         self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        #         self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        # elif hparams['use_spk_embed']:
        #     self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
        # Width of the predictor networks; 0 or negative means "use
        # hidden_size".
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        # self.dur_predictor = DurationPredictor(
        #     self.hidden_size,
        #     n_chans=predictor_hidden,
        #     n_layers=hparams['dur_predictor_layers'],
        #     dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
        #     kernel_size=hparams['dur_predictor_kernel'])
        # self.length_regulator = LengthRegulator()
        if hparams['use_pitch_embed']:
            # 300 coarse pitch bins; index 0 (padding_idx) embeds to zeros.
            self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
            if hparams['pitch_type'] == 'cwt':
                h = hparams['cwt_hidden_size']
                # 10 CWT scales, plus one extra channel for U/V when enabled.
                cwt_out_dims = 10
                if hparams['use_uv']:
                    cwt_out_dims = cwt_out_dims + 1
                self.cwt_predictor = nn.Sequential(
                    nn.Linear(self.hidden_size, h),
                    PitchPredictor(
                        h,
                        n_chans=predictor_hidden,
                        n_layers=hparams['predictor_layers'],
                        dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
                        padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
                # Predicts per-utterance f0 mean and std (2 outputs) for
                # de-normalizing the CWT spectrogram back to f0.
                self.cwt_stats_layers = nn.Sequential(
                    nn.Linear(self.hidden_size, h), nn.ReLU(),
                    nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
                )
            else:
                # Frame-level pitch predictor: 2 outputs (f0 + uv) for
                # 'frame' pitch type, 1 output otherwise.
                self.pitch_predictor = PitchPredictor(
                    self.hidden_size,
                    n_chans=predictor_hidden,
                    n_layers=hparams['predictor_layers'],
                    dropout_rate=hparams['predictor_dropout'],
                    odim=2 if hparams['pitch_type'] == 'frame' else 1,
                    padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
        if hparams['use_energy_embed']:
            # 256 coarse energy bins; index 0 (padding_idx) embeds to zeros.
            self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
            # self.energy_predictor = EnergyPredictor(
            #     self.hidden_size,
            #     n_chans=predictor_hidden,
            #     n_layers=hparams['predictor_layers'],
            #     dropout_rate=hparams['predictor_dropout'], odim=1,
            #     padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

    # def build_embedding(self, dictionary, embed_dim):
    #     num_embeddings = len(dictionary)
    #     emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
    #     return emb

    def forward(self, hubert, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=True,
                spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
        """Run the variance-adaptor pipeline.

        Returns a dict ``ret`` containing at least ``mel2ph`` and
        ``decoder_inp``; ``add_pitch``/``add_energy`` add ``f0_denorm``,
        ``pitch_pred`` and ``energy_pred``; when the decoder runs,
        ``mel_out`` is added as well.

        hubert: [B, T_txt, C] input feature frames; all-zero frames are
            treated as padding.
        mel2ph: [B, T_mel] 1-based frame-to-input alignment; 0 marks padding.
            NOTE(review): despite the ``=None`` default, this method indexes
            ``mel2ph`` unconditionally, so callers must pass it.
        skip_decoder: when True (the default) the decoder is skipped and
            only the conditioning dict is returned.
        """
        ret = {}
        # Same 'no_fs2' gate as in __init__ (see comment there).
        if not hparams['no_fs2'] if 'no_fs2' in hparams.keys() else True:
            encoder_out =self.encoder(hubert)  # [B, T, C]
        else:
            encoder_out =hubert
        # [B, T_txt, 1] mask: 1 where the hubert frame has any non-zero value.
        src_nonpadding = (hubert!=0).any(-1)[:,:,None]

        # add ref style embed
        # Not implemented
        # variance encoder
        var_embed = 0

        # encoder_out_dur denotes encoder outputs for duration predictor
        # in speech adaptation, duration predictor use old speaker embedding
        if hparams['use_spk_embed']:
            spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
        elif hparams['use_spk_id']:
            spk_embed_id = spk_embed
            if spk_embed_dur_id is None:
                spk_embed_dur_id = spk_embed_id
            if spk_embed_f0_id is None:
                spk_embed_f0_id = spk_embed_id
            spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
            spk_embed_dur = spk_embed_f0 = spk_embed
            if hparams['use_split_spk_id']:
                spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
                spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
        else:
            # NOTE(review): the speaker-embedding modules above are commented
            # out in __init__, so only this branch is constructible with the
            # visible code.
            spk_embed_dur = spk_embed_f0 = spk_embed = 0

        # add dur
        # dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding

        # mel2ph = self.add_dur(dur_inp, mel2ph, hubert, ret)
        ret['mel2ph'] = mel2ph

        # Prepend a zero frame so that mel2ph index 0 (padding) gathers zeros.
        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])

        # Expand the alignment to the channel dimension and gather:
        # length-regulate encoder frames to mel resolution.
        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
        decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]

        # [B, T_mel, 1] mask: 1 where the mel frame maps to a real input frame.
        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]

        # add pitch and energy embed
        pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
        if hparams['use_pitch_embed']:
            pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
        if hparams['use_energy_embed']:
            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)

        # Final conditioning: add speaker embedding and re-mask padding.
        ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
        if not hparams['no_fs2'] if 'no_fs2' in hparams.keys() else True:
            if skip_decoder:
                return ret
            ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)

        return ret

    def add_dur(self, dur_input, mel2ph, hubert, ret):
        """Predict durations and derive ``mel2ph`` when it is not given.

        NOTE(review): ``self.dur_predictor`` and ``self.length_regulator``
        are commented out in __init__, so calling this would raise
        AttributeError — it appears unused in this configuration (the call
        site in ``forward`` is commented out).
        """
        src_padding = (hubert==0).all(-1)
        # Straight-through trick: forward value unchanged, gradient scaled
        # by hparams['predictor_grad'].
        dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
        if mel2ph is None:
            dur, xs = self.dur_predictor.inference(dur_input, src_padding)
            ret['dur'] = xs
            ret['dur_choice'] = dur
            mel2ph = self.length_regulator(dur, src_padding).detach()
        else:
            ret['dur'] = self.dur_predictor(dur_input, src_padding)
        ret['mel2ph'] = mel2ph
        return mel2ph

    def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
        """Decode conditioning features to a padding-masked mel spectrogram."""
        x = decoder_inp  # [B, T, H]
        x = self.decoder(x)
        x = self.mel_out(x)
        return x * tgt_nonpadding

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        """Invert a CWT spectrogram to f0, pad to mel length, and normalize.

        The f0 curve is extended to ``mel2ph``'s length by repeating its
        last frame.
        """
        f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
        f0 = torch.cat(
            [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
        f0_norm = norm_f0(f0, None, hparams)
        return f0_norm

    def out2mel(self, out):
        """Identity hook: decoder output is already a mel spectrogram."""
        return out

    def add_pitch(self,decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
        """Embed the (ground-truth) pitch curve into hidden space.

        Writes ``ret['f0_denorm']`` (denormalized f0 with padding zeroed via
        ``pitch_padding``) and ``ret['pitch_pred']`` (coarse pitch indices,
        [B, T_mel, 1]); returns the pitch embedding [B, T_mel, H].
        NOTE(review): zeroes the caller's ``f0`` tensor in place at padded
        frames (see below).
        """
        # if hparams['pitch_type'] == 'ph':
        #     pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
        #     pitch_padding = (encoder_out.sum().abs() == 0)
        #     ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
        #     if f0 is None:
        #         f0 = pitch_pred[:, :, 0]
        #     ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
        #     pitch = f0_to_coarse(f0_denorm)  # start from 0 [B, T_txt]
        #     pitch = F.pad(pitch, [1, 0])
        #     pitch = torch.gather(pitch, 1, mel2ph)  # [B, T_mel]
        #     pitch_embedding = pitch_embed(pitch)
        #     return pitch_embedding

        # Straight-through trick: forward value unchanged, gradient scaled
        # by hparams['predictor_grad'].
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())

        pitch_padding = (mel2ph == 0)

        # if hparams['pitch_type'] == 'cwt':
        #     # NOTE: this part of script is *isolated* from other scripts, which means
        #     # it may not be compatible with the current version.
        #     pass
        #     # pitch_padding = None
        #     # ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
        #     # stats_out = self.cwt_stats_layers(encoder_out[:, 0, :])  # [B, 2]
        #     # mean = ret['f0_mean'] = stats_out[:, 0]
        #     # std = ret['f0_std'] = stats_out[:, 1]
        #     # cwt_spec = cwt_out[:, :, :10]
        #     # if f0 is None:
        #     #     std = std * hparams['cwt_std_scale']
        #     #     f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
        #     # if hparams['use_uv']:
        #     #     assert cwt_out.shape[-1] == 11
        #     #     uv = cwt_out[:, :, -1] > 0
        # elif hparams['pitch_ar']:
        #     ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if is_training else None)
        #     if f0 is None:
        #         f0 = pitch_pred[:, :, 0]
        # else:
        #ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
        #     if f0 is None:
        #         f0 = pitch_pred[:, :, 0]
        #     if hparams['use_uv'] and uv is None:
        #         uv = pitch_pred[:, :, 1] > 0
        ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
        # In-place mutation of the caller's f0 at padded frames.
        # pitch_padding is always a tensor here, so the guard is always true.
        if pitch_padding is not None:
            f0[pitch_padding] = 0

        pitch = f0_to_coarse(f0_denorm,hparams)  # start from 0
        ret['pitch_pred']=pitch.unsqueeze(-1)
        # print(ret['pitch_pred'].shape)
        # print(pitch.shape)
        pitch_embedding = self.pitch_embed(pitch)
        return pitch_embedding

    def add_energy(self,decoder_inp, energy, ret):
        """Embed the ground-truth energy curve into hidden space.

        Stores the raw ``energy`` in ``ret['energy_pred']`` (the predictor
        is disabled), quantizes energy to integer bins clamped to [.., 255],
        and returns the energy embedding.
        """
        # Straight-through gradient scaling (result unused since the
        # predictor call below is commented out).
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy#energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
        # if energy is None:
        #     energy = energy_pred
        energy = torch.clamp(energy * 256 // 4, max=255).long()  # energy_to_coarse
        energy_embedding = self.energy_embed(energy)
        return energy_embedding

    @staticmethod
    def mel_norm(x):
        """Affine map of mel values: -5.5 -> -1 and 0.8 -> 1."""
        return (x + 5.5) / (6.3 / 2) - 1

    @staticmethod
    def mel_denorm(x):
        """Exact inverse of ``mel_norm``."""
        return (x + 1) * (6.3 / 2) - 5.5
256
257