Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/training/task/SVC_task.py
694 views
1
import torch
2
3
import utils
4
from utils.hparams import hparams
5
from network.diff.net import DiffNet
6
from network.diff.diffusion import GaussianDiffusion, OfflineGaussianDiffusion
7
from training.task.fs2 import FastSpeech2Task
8
from network.vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
9
from modules.fastspeech.tts_modules import mel2ph_to_dur
10
11
from network.diff.candidate_decoder import FFT
12
from utils.pitch_utils import denorm_f0
13
from training.dataset.fs2_utils import FastSpeechDataset
14
15
import numpy as np
16
import os
17
import torch.nn.functional as F
18
19
def _build_wavenet_decoder(hp):
    """Build a WaveNet-style denoiser sized to the mel-bin count in *hp*."""
    return DiffNet(hp['audio_num_mel_bins'])


def _build_fft_decoder(hp):
    """Build an FFT (transformer) denoiser from the relevant *hp* fields."""
    return FFT(hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads'])


# Registry mapping hparams['diff_decoder_type'] -> factory(hparams) -> denoiser module.
DIFF_DECODERS = {
    'wavenet': _build_wavenet_decoder,
    'fft': _build_fft_decoder,
}
24
25
26
class SVCDataset(FastSpeechDataset):
    """FastSpeech dataset variant whose batch collation is delegated to File2Batch."""

    def collater(self, samples):
        """Collate a list of processed samples into a single training batch.

        :param samples: list of per-utterance feature dicts
        :return: batch produced by ``File2Batch.processed_input2batch``
        """
        # Imported inside the method; NOTE(review): presumably deferred to dodge a
        # circular import with the preprocessing package — confirm before hoisting.
        from preprocessing.process_pipeline import File2Batch
        batch = File2Batch.processed_input2batch(samples)
        return batch
30
31
32
class SVCTask(FastSpeech2Task):
    """Training task for diffusion-based singing voice conversion."""

    def __init__(self):
        super().__init__()
        # Use the SVC-specific dataset/collater instead of the FastSpeech2 default.
        self.dataset_cls = SVCDataset
        # Vocoder class is selected from hparams and instantiated immediately.
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
37
38
def build_tts_model(self):
39
# import torch
40
# from tqdm import tqdm
41
# v_min = torch.ones([80]) * 100
42
# v_max = torch.ones([80]) * -100
43
# for i, ds in enumerate(tqdm(self.dataset_cls('train'))):
44
# v_max = torch.max(torch.max(ds['mel'].reshape(-1, 80), 0)[0], v_max)
45
# v_min = torch.min(torch.min(ds['mel'].reshape(-1, 80), 0)[0], v_min)
46
# if i % 100 == 0:
47
# print(i, v_min, v_max)
48
# print('final', v_min, v_max)
49
mel_bins = hparams['audio_num_mel_bins']
50
self.model = GaussianDiffusion(
51
phone_encoder=self.phone_encoder,
52
out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
53
timesteps=hparams['timesteps'],
54
K_step=hparams['K_step'],
55
loss_type=hparams['diff_loss_type'],
56
spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
57
)
58
59
60
def build_optimizer(self, model):
61
self.optimizer = optimizer = torch.optim.AdamW(
62
filter(lambda p: p.requires_grad, model.parameters()),
63
lr=hparams['lr'],
64
betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
65
weight_decay=hparams['weight_decay'])
66
return optimizer
67
68
def run_model(self, model, sample, return_output=False, infer=False):
69
'''
70
steps:
71
1. run the full model, calc the main loss
72
2. calculate loss for dur_predictor, pitch_predictor, energy_predictor
73
'''
74
hubert = sample['hubert'] # [B, T_t,H]
75
target = sample['mels'] # [B, T_s, 80]
76
mel2ph = sample['mel2ph'] # [B, T_s]
77
f0 = sample['f0']
78
uv = sample['uv']
79
energy = sample['energy']
80
81
spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
82
if hparams['pitch_type'] == 'cwt':
83
# NOTE: this part of script is *isolated* from other scripts, which means
84
# it may not be compatible with the current version.
85
pass
86
# cwt_spec = sample[f'cwt_spec']
87
# f0_mean = sample['f0_mean']
88
# f0_std = sample['f0_std']
89
# sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)
90
91
# output == ret
92
# model == src.diff.diffusion.GaussianDiffusion
93
output = model(hubert, mel2ph=mel2ph, spk_embed=spk_embed,
94
ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)
95
96
losses = {}
97
if 'diff_loss' in output:
98
losses['mel'] = output['diff_loss']
99
#self.add_dur_loss(output['dur'], mel2ph, txt_tokens, sample['word_boundary'], losses=losses)
100
# if hparams['use_pitch_embed']:
101
# self.add_pitch_loss(output, sample, losses)
102
# if hparams['use_energy_embed']:
103
# self.add_energy_loss(output['energy_pred'], energy, losses)
104
if not return_output:
105
return losses
106
else:
107
return losses, output
108
109
def _training_step(self, sample, batch_idx, _):
110
log_outputs = self.run_model(self.model, sample)
111
total_loss = sum([v for v in log_outputs.values() if isinstance(v, torch.Tensor) and v.requires_grad])
112
log_outputs['batch_size'] = sample['hubert'].size()[0]
113
log_outputs['lr'] = self.scheduler.get_lr()[0]
114
return total_loss, log_outputs
115
116
def build_scheduler(self, optimizer):
117
return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)
118
119
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
120
if optimizer is None:
121
return
122
optimizer.step()
123
optimizer.zero_grad()
124
if self.scheduler is not None:
125
self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])
126
127
def validation_step(self, sample, batch_idx):
128
outputs = {}
129
hubert = sample['hubert'] # [B, T_t]
130
131
target = sample['mels'] # [B, T_s, 80]
132
energy = sample['energy']
133
# fs2_mel = sample['fs2_mels']
134
spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
135
mel2ph = sample['mel2ph']
136
137
outputs['losses'] = {}
138
139
outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False)
140
141
outputs['total_loss'] = sum(outputs['losses'].values())
142
outputs['nsamples'] = sample['nsamples']
143
outputs = utils.tensors_to_scalars(outputs)
144
if batch_idx < hparams['num_valid_plots']:
145
model_out = self.model(
146
hubert, spk_embed=spk_embed, mel2ph=mel2ph, f0=sample['f0'], uv=sample['uv'], energy=energy, ref_mels=None, infer=True
147
)
148
149
if hparams.get('pe_enable') is not None and hparams['pe_enable']:
150
gt_f0 = self.pe(sample['mels'])['f0_denorm_pred'] # pe predict from GT mel
151
pred_f0 = self.pe(model_out['mel_out'])['f0_denorm_pred'] # pe predict from Pred mel
152
else:
153
gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
154
pred_f0 = model_out.get('f0_denorm')
155
self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=pred_f0)
156
self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'], name=f'diffmel_{batch_idx}')
157
#self.plot_mel(batch_idx, sample['mels'], model_out['fs2_mel'], name=f'fs2mel_{batch_idx}')
158
if hparams['use_pitch_embed']:
159
self.plot_pitch(batch_idx, sample, model_out)
160
return outputs
161
162
def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, wdb, losses=None):
163
"""
164
the effect of each loss component:
165
hparams['dur_loss'] : align each phoneme
166
hparams['lambda_word_dur']: align each word
167
hparams['lambda_sent_dur']: align each sentence
168
169
:param dur_pred: [B, T], float, log scale
170
:param mel2ph: [B, T]
171
:param txt_tokens: [B, T]
172
:param losses:
173
:return:
174
"""
175
B, T = txt_tokens.shape
176
nonpadding = (txt_tokens != 0).float()
177
dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding
178
is_sil = torch.zeros_like(txt_tokens).bool()
179
for p in self.sil_ph:
180
is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0])
181
is_sil = is_sil.float() # [B, T_txt]
182
183
# phone duration loss
184
if hparams['dur_loss'] == 'mse':
185
losses['pdur'] = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none')
186
losses['pdur'] = (losses['pdur'] * nonpadding).sum() / nonpadding.sum()
187
losses['pdur'] = losses['pdur'] * hparams['lambda_ph_dur']
188
dur_pred = (dur_pred.exp() - 1).clamp(min=0)
189
else:
190
raise NotImplementedError
191
192
# use linear scale for sent and word duration
193
if hparams['lambda_word_dur'] > 0:
194
#idx = F.pad(wdb.cumsum(axis=1), (1, 0))[:, :-1]
195
idx = wdb.cumsum(axis=1)
196
# word_dur_g = dur_gt.new_zeros([B, idx.max() + 1]).scatter_(1, idx, midi_dur) # midi_dur can be implied by add gt-ph_dur
197
word_dur_p = dur_pred.new_zeros([B, idx.max() + 1]).scatter_add(1, idx, dur_pred)
198
word_dur_g = dur_gt.new_zeros([B, idx.max() + 1]).scatter_add(1, idx, dur_gt)
199
wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none')
200
word_nonpadding = (word_dur_g > 0).float()
201
wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum()
202
losses['wdur'] = wdur_loss * hparams['lambda_word_dur']
203
if hparams['lambda_sent_dur'] > 0:
204
sent_dur_p = dur_pred.sum(-1)
205
sent_dur_g = dur_gt.sum(-1)
206
sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean')
207
losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur']
208
209
############
210
# validation plots
211
############
212
def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None):
213
gt_wav = gt_wav[0].cpu().numpy()
214
wav_out = wav_out[0].cpu().numpy()
215
gt_f0 = gt_f0[0].cpu().numpy()
216
f0 = f0[0].cpu().numpy()
217
if is_mel:
218
gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0)
219
wav_out = self.vocoder.spec2wav(wav_out, f0=f0)
220
self.logger.experiment.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
221
self.logger.experiment.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
222
223
224
225