Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/training/task/fs2.py
694 views
1
import matplotlib
2
3
matplotlib.use('Agg')
4
5
from utils import audio
6
import matplotlib.pyplot as plt
7
from preprocessing.data_gen_utils import get_pitch_parselmouth
8
from training.dataset.fs2_utils import FastSpeechDataset
9
from utils.cwt import cwt2f0
10
from utils.pl_utils import data_loader
11
import os
12
from multiprocessing.pool import Pool
13
from tqdm import tqdm
14
from modules.fastspeech.tts_modules import mel2ph_to_dur
15
from utils.hparams import hparams
16
from utils.plot import spec_to_figure, dur_to_figure, f0_to_figure
17
from utils.pitch_utils import denorm_f0
18
from modules.fastspeech.fs2 import FastSpeech2
19
from training.task.tts import TtsTask
20
import torch
21
import torch.optim
22
import torch.utils.data
23
import torch.nn.functional as F
24
import utils
25
import torch.distributions
26
import numpy as np
27
from modules.commons.ssim import ssim
28
29
class FastSpeech2Task(TtsTask):
    """FastSpeech2 training/validation/inference task.

    Several hooks (`build_tts_model`, `_training_step`, `validation_step`,
    `run_model`) are stubbed out with an early ``return`` and are meant to be
    rewritten by subclasses (the original bodies are kept below the ``return``
    as reference).
    """

    def __init__(self):
        super(FastSpeech2Task, self).__init__()
        self.dataset_cls = FastSpeechDataset
        self.mse_loss_fn = torch.nn.MSELoss()
        # 'mel_loss' is a '|'-separated list of "name" or "name:weight" items,
        # e.g. "l1:0.5|ssim:0.5"; weight defaults to 1.0 when omitted.
        mel_losses = hparams['mel_loss'].split("|")
        self.loss_and_lambda = {}
        for i, l in enumerate(mel_losses):
            if l == '':
                continue
            if ':' in l:
                l, lbd = l.split(":")
                lbd = float(lbd)
            else:
                lbd = 1.0
            self.loss_and_lambda[l] = lbd
        print("| Mel losses:", self.loss_and_lambda)
        #self.sil_ph = self.phone_encoder.sil_phonemes()

    @data_loader
    def train_dataloader(self):
        """Build the (optionally endless) shuffled training dataloader."""
        train_dataset = self.dataset_cls(hparams['train_set_name'], shuffle=True)
        return self.build_dataloader(train_dataset, True, self.max_tokens, self.max_sentences,
                                     endless=hparams['endless_ds'])

    @data_loader
    def val_dataloader(self):
        """Build the validation dataloader (no shuffling, eval batch limits)."""
        valid_dataset = self.dataset_cls(hparams['valid_set_name'], shuffle=False)
        return self.build_dataloader(valid_dataset, False, self.max_eval_tokens, self.max_eval_sentences)

    @data_loader
    def test_dataloader(self):
        """Build the test dataloader; batch_by_size=False keeps item order."""
        test_dataset = self.dataset_cls(hparams['test_set_name'], shuffle=False)
        return self.build_dataloader(test_dataset, False, self.max_eval_tokens,
                                     self.max_eval_sentences, batch_by_size=False)

    def build_tts_model(self):
        '''
        rewrite
        '''
        # Stub: subclasses construct and assign self.model here.
        return
        # self.model = FastSpeech2(self.phone_encoder)

    def build_model(self):
        """Construct the TTS model, optionally load a checkpoint, print its arch."""
        self.build_tts_model()
        if hparams['load_ckpt'] != '':
            self.load_ckpt(hparams['load_ckpt'], strict=True)
        utils.print_arch(self.model)
        return self.model

    def _training_step(self, sample, batch_idx, _):
        '''
        rewrite
        '''
        # Stub: subclasses compute the loss dict here.
        return
        # loss_output = self.run_model(self.model, sample)
        # total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
        # loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        # return total_loss, loss_output

    def validation_step(self, sample, batch_idx):
        '''
        rewrite
        '''
        # Stub: subclasses run the model, collect losses and plot here.
        return
        # outputs = {}
        # outputs['losses'] = {}
        # outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True)
        # outputs['total_loss'] = sum(outputs['losses'].values())
        # outputs['nsamples'] = sample['nsamples']
        # mel_out = self.model.out2mel(model_out['mel_out'])
        # outputs = utils.tensors_to_scalars(outputs)
        # if batch_idx < hparams['num_valid_plots']:
        #     self.plot_mel(batch_idx, sample['mels'], mel_out)
        #     self.plot_dur(batch_idx, sample, model_out)
        #     if hparams['use_pitch_embed']:
        #         self.plot_pitch(batch_idx, sample, model_out)
        # return outputs

    def _validation_end(self, outputs):
        """Aggregate per-batch validation losses into sample-weighted averages.

        :param outputs: list of dicts with 'losses', 'total_loss', 'nsamples'
        :return: dict mapping loss name -> average rounded to 4 decimals
        """
        all_losses_meter = {
            'total_loss': utils.AvgrageMeter(),
        }
        for output in outputs:
            n = output['nsamples']
            for k, v in output['losses'].items():
                if k not in all_losses_meter:
                    all_losses_meter[k] = utils.AvgrageMeter()
                all_losses_meter[k].update(v, n)
            all_losses_meter['total_loss'].update(output['total_loss'], n)
        return {k: round(v.avg, 4) for k, v in all_losses_meter.items()}

    def run_model(self, model, sample, return_output=False):
        '''
        rewrite
        '''
        # Stub: returns immediately; the code below is the original (now
        # unreachable) reference implementation kept for subclass rewrites.
        return
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        target = sample['mels']  # [B, T_s, 80]
        mel2ph = sample['mel2ph']  # [B, T_s]
        f0 = sample['f0']
        uv = sample['uv']
        energy = sample['energy']
        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
        if hparams['pitch_type'] == 'cwt':
            cwt_spec = sample[f'cwt_spec']
            f0_mean = sample['f0_mean']
            f0_std = sample['f0_std']
            sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)

        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=False)

        losses = {}
        self.add_mel_loss(output['mel_out'], target, losses)
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output['energy_pred'], energy, losses)
        if not return_output:
            return losses
        else:
            return losses, output

    ############
    # losses
    ############
    def add_mel_loss(self, mel_out, target, losses, postfix='', mel_mix_loss=None):
        """Add each configured mel loss (weighted by its lambda) to `losses`.

        Only 'l1' and 'ssim' are implemented; 'mse' and 'gdl' raise.
        """
        if mel_mix_loss is None:
            for loss_name, lbd in self.loss_and_lambda.items():
                if 'l1' == loss_name:
                    l = self.l1_loss(mel_out, target)
                elif 'mse' == loss_name:
                    raise NotImplementedError
                elif 'ssim' == loss_name:
                    l = self.ssim_loss(mel_out, target)
                elif 'gdl' == loss_name:
                    raise NotImplementedError
                # NOTE(review): an unrecognized loss_name would reuse the
                # previous iteration's `l` (or NameError on the first) —
                # config is assumed to contain only known names.
                losses[f'{loss_name}{postfix}'] = l * lbd
        else:
            raise NotImplementedError

    def l1_loss(self, decoder_output, target):
        """Masked mean-absolute-error over non-padded mel frames.

        # decoder_output : B x T x n_mel
        # target : B x T x n_mel
        """
        l1_loss = F.l1_loss(decoder_output, target, reduction='none')
        # Zero-out padding frames, then average over the valid region only.
        weights = self.weights_nonzero_speech(target)
        l1_loss = (l1_loss * weights).sum() / weights.sum()
        return l1_loss

    def ssim_loss(self, decoder_output, target, bias=6.0):
        """Masked 1-SSIM loss; `bias` shifts mels to positive range for SSIM.

        # decoder_output : B x T x n_mel
        # target : B x T x n_mel
        """
        assert decoder_output.shape == target.shape
        weights = self.weights_nonzero_speech(target)
        # Add a channel dim ([B, 1, T, n_mel]) as expected by ssim().
        decoder_output = decoder_output[:, None] + bias
        target = target[:, None] + bias
        ssim_loss = 1 - ssim(decoder_output, target, size_average=False)
        ssim_loss = (ssim_loss * weights).sum() / weights.sum()
        return ssim_loss

    def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, losses=None):
        """

        :param dur_pred: [B, T], float, log scale
        :param mel2ph: [B, T]
        :param txt_tokens: [B, T]
        :param losses:
        :return:
        """
        B, T = txt_tokens.shape
        nonpadding = (txt_tokens != 0).float()
        dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding
        # Mark silence phonemes — word boundaries for the word-duration loss.
        # NOTE(review): self.sil_ph is commented out in __init__ — confirm a
        # subclass sets it before this path runs.
        is_sil = torch.zeros_like(txt_tokens).bool()
        for p in self.sil_ph:
            is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0])
        is_sil = is_sil.float()  # [B, T_txt]

        # phone duration loss
        if hparams['dur_loss'] == 'mse':
            losses['pdur'] = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none')
            losses['pdur'] = (losses['pdur'] * nonpadding).sum() / nonpadding.sum()
            # Back to linear scale for the word/sentence losses below.
            dur_pred = (dur_pred.exp() - 1).clamp(min=0)
        elif hparams['dur_loss'] == 'mog':
            # BUG FIX: was `return NotImplementedError` (returned the class
            # object instead of signalling the unimplemented branch).
            raise NotImplementedError
        elif hparams['dur_loss'] == 'crf':
            losses['pdur'] = -self.model.dur_predictor.crf(
                dur_pred, dur_gt.long().clamp(min=0, max=31), mask=nonpadding > 0, reduction='mean')
        losses['pdur'] = losses['pdur'] * hparams['lambda_ph_dur']

        # use linear scale for sent and word duration
        if hparams['lambda_word_dur'] > 0:
            # Word id = running count of silences, zeroed on the silences
            # themselves; scatter_add sums phone durations per word.
            word_id = (is_sil.cumsum(-1) * (1 - is_sil)).long()
            word_dur_p = dur_pred.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_pred)[:, 1:]
            word_dur_g = dur_gt.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_gt)[:, 1:]
            wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none')
            word_nonpadding = (word_dur_g > 0).float()
            wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum()
            losses['wdur'] = wdur_loss * hparams['lambda_word_dur']
        if hparams['lambda_sent_dur'] > 0:
            sent_dur_p = dur_pred.sum(-1)
            sent_dur_g = dur_gt.sum(-1)
            sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean')
            losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur']

    def add_pitch_loss(self, output, sample, losses):
        """Add pitch losses according to hparams['pitch_type'] ('ph'/'cwt'/'frame')."""
        if hparams['pitch_type'] == 'ph':
            nonpadding = (sample['txt_tokens'] != 0).float()
            pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss
            losses['f0'] = (pitch_loss_fn(output['pitch_pred'][:, :, 0], sample['f0'],
                                          reduction='none') * nonpadding).sum() \
                           / nonpadding.sum() * hparams['lambda_f0']
            return
        mel2ph = sample['mel2ph']  # [B, T_s]
        f0 = sample['f0']
        uv = sample['uv']
        nonpadding = (mel2ph != 0).float()
        if hparams['pitch_type'] == 'cwt':
            cwt_spec = sample[f'cwt_spec']
            f0_mean = sample['f0_mean']
            f0_std = sample['f0_std']
            # First 10 channels are the CWT spectrogram; channel 11 (if any)
            # is the unvoiced logit.
            cwt_pred = output['cwt'][:, :, :10]
            f0_mean_pred = output['f0_mean']
            f0_std_pred = output['f0_std']
            losses['C'] = self.cwt_loss(cwt_pred, cwt_spec) * hparams['lambda_f0']
            if hparams['use_uv']:
                assert output['cwt'].shape[-1] == 11
                uv_pred = output['cwt'][:, :, -1]
                losses['uv'] = (F.binary_cross_entropy_with_logits(uv_pred, uv, reduction='none') * nonpadding) \
                                   .sum() / nonpadding.sum() * hparams['lambda_uv']
            losses['f0_mean'] = F.l1_loss(f0_mean_pred, f0_mean) * hparams['lambda_f0']
            losses['f0_std'] = F.l1_loss(f0_std_pred, f0_std) * hparams['lambda_f0']
            if hparams['cwt_add_f0_loss']:
                # Also supervise the f0 curve reconstructed from predicted CWT.
                f0_cwt_ = self.model.cwt2f0_norm(cwt_pred, f0_mean_pred, f0_std_pred, mel2ph)
                self.add_f0_loss(f0_cwt_[:, :, None], f0, uv, losses, nonpadding=nonpadding)
        elif hparams['pitch_type'] == 'frame':
            self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)

    def add_f0_loss(self, p_pred, f0, uv, losses, nonpadding):
        """Frame-level f0 (+ optional UV) loss.

        :param p_pred: [B, T, 1 or 2] — channel 0 = f0, channel 1 = uv logit
        """
        assert p_pred[..., 0].shape == f0.shape
        if hparams['use_uv']:
            assert p_pred[..., 1].shape == uv.shape
            losses['uv'] = (F.binary_cross_entropy_with_logits(
                p_pred[:, :, 1], uv, reduction='none') * nonpadding).sum() \
                           / nonpadding.sum() * hparams['lambda_uv']
            # Exclude unvoiced frames from the f0 regression below.
            nonpadding = nonpadding * (uv == 0).float()

        f0_pred = p_pred[:, :, 0]
        if hparams['pitch_loss'] in ['l1', 'l2']:
            pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss
            losses['f0'] = (pitch_loss_fn(f0_pred, f0, reduction='none') * nonpadding).sum() \
                           / nonpadding.sum() * hparams['lambda_f0']
        elif hparams['pitch_loss'] == 'ssim':
            # BUG FIX: was `return NotImplementedError` — raise so the
            # unimplemented branch cannot pass silently.
            raise NotImplementedError

    def cwt_loss(self, cwt_p, cwt_g):
        """CWT spectrogram loss; variant chosen by hparams['cwt_loss']."""
        if hparams['cwt_loss'] == 'l1':
            return F.l1_loss(cwt_p, cwt_g)
        if hparams['cwt_loss'] == 'l2':
            return F.mse_loss(cwt_p, cwt_g)
        if hparams['cwt_loss'] == 'ssim':
            return self.ssim_loss(cwt_p, cwt_g, 20)

    def add_energy_loss(self, energy_pred, energy, losses):
        """Masked MSE on frame energy, scaled by lambda_energy, stored as 'e'."""
        nonpadding = (energy != 0).float()
        loss = (F.mse_loss(energy_pred, energy, reduction='none') * nonpadding).sum() / nonpadding.sum()
        loss = loss * hparams['lambda_energy']
        losses['e'] = loss

    ############
    # validation plots
    ############
    def plot_mel(self, batch_idx, spec, spec_out, name=None):
        """Log GT and predicted mels side by side (concatenated on mel axis)."""
        spec_cat = torch.cat([spec, spec_out], -1)
        name = f'mel_{batch_idx}' if name is None else name
        vmin = hparams['mel_vmin']
        vmax = hparams['mel_vmax']
        self.logger.experiment.add_figure(name, spec_to_figure(spec_cat[0], vmin, vmax), self.global_step)

    def plot_dur(self, batch_idx, sample, model_out):
        """Log a GT-vs-predicted phone-duration figure for the first item."""
        T_txt = sample['txt_tokens'].shape[1]
        dur_gt = mel2ph_to_dur(sample['mel2ph'], T_txt)[0]
        dur_pred = self.model.dur_predictor.out2dur(model_out['dur']).float()
        txt = self.phone_encoder.decode(sample['txt_tokens'][0].cpu().numpy())
        txt = txt.split(" ")
        self.logger.experiment.add_figure(
            f'dur_{batch_idx}', dur_to_figure(dur_gt, dur_pred, txt), self.global_step)

    def plot_pitch(self, batch_idx, sample, model_out):
        """Log GT-vs-predicted f0 curves (layout depends on pitch_type)."""
        f0 = sample['f0']
        if hparams['pitch_type'] == 'ph':
            mel2ph = sample['mel2ph']
            f0 = self.expand_f0_ph(f0, mel2ph)
            f0_pred = self.expand_f0_ph(model_out['pitch_pred'][:, :, 0], mel2ph)
            self.logger.experiment.add_figure(
                f'f0_{batch_idx}', f0_to_figure(f0[0], None, f0_pred[0]), self.global_step)
            return
        f0 = denorm_f0(f0, sample['uv'], hparams)
        if hparams['pitch_type'] == 'cwt':
            # cwt
            cwt_out = model_out['cwt']
            cwt_spec = cwt_out[:, :, :10]
            cwt = torch.cat([cwt_spec, sample['cwt_spec']], -1)
            self.logger.experiment.add_figure(f'cwt_{batch_idx}', spec_to_figure(cwt[0]), self.global_step)
            # f0
            f0_pred = cwt2f0(cwt_spec, model_out['f0_mean'], model_out['f0_std'], hparams['cwt_scales'])
            if hparams['use_uv']:
                assert cwt_out.shape[-1] == 11
                uv_pred = cwt_out[:, :, -1] > 0
                f0_pred[uv_pred > 0] = 0
            f0_cwt = denorm_f0(sample['f0_cwt'], sample['uv'], hparams)
            self.logger.experiment.add_figure(
                f'f0_{batch_idx}', f0_to_figure(f0[0], f0_cwt[0], f0_pred[0]), self.global_step)
        elif hparams['pitch_type'] == 'frame':
            # f0
            #uv_pred = model_out['pitch_pred'][:, :, 0] > 0
            pitch_pred = denorm_f0(model_out['pitch_pred'][:, :, 0], sample['uv'], hparams)
            self.logger.experiment.add_figure(
                f'f0_{batch_idx}', f0_to_figure(f0[0], None, pitch_pred[0]), self.global_step)

    ############
    # infer
    ############
    def test_step(self, sample, batch_idx):
        """Run inference on one batch (teacher-forced dur/f0) and save results."""
        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
        hubert = sample['hubert']
        mel2ph, uv, f0 = None, None, None
        ref_mels = None
        if hparams['profile_infer']:
            pass
        else:
            # if hparams['use_gt_dur']:
            mel2ph = sample['mel2ph']
            #if hparams['use_gt_f0']:
            f0 = sample['f0']
            uv = sample['uv']
            #print('Here using gt f0!!')
        # NOTE(review): both branches below are currently identical — kept
        # verbatim; presumably the midi path was meant to differ.
        if hparams.get('use_midi') is not None and hparams['use_midi']:
            outputs = self.model(
                hubert, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=ref_mels, infer=True)
        else:
            outputs = self.model(
                hubert, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=ref_mels, infer=True)
        sample['outputs'] = self.model.out2mel(outputs['mel_out'])
        sample['mel2ph_pred'] = outputs['mel2ph']
        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            sample['f0'] = self.pe(sample['mels'])['f0_denorm_pred']  # pe predict from GT mel
            sample['f0_pred'] = self.pe(sample['outputs'])['f0_denorm_pred']  # pe predict from Pred mel
        else:
            sample['f0'] = denorm_f0(sample['f0'], sample['uv'], hparams)
            sample['f0_pred'] = outputs.get('f0_denorm')
        return self.after_infer(sample)

    def after_infer(self, predictions):
        """Vocode predicted mels, save wavs/plots asynchronously via a Pool."""
        if self.saving_result_pool is None and not hparams['profile_infer']:
            self.saving_result_pool = Pool(min(int(os.getenv('N_PROC', os.cpu_count())), 16))
            self.saving_results_futures = []
        predictions = utils.unpack_dict_to_list(predictions)
        t = tqdm(predictions)
        for num_predictions, prediction in enumerate(t):
            for k, v in prediction.items():
                if type(v) is torch.Tensor:
                    prediction[k] = v.cpu().numpy()

            item_name = prediction.get('item_name')
            #text = prediction.get('text').replace(":", "%3A")[:80]

            # remove paddings
            mel_gt = prediction["mels"]
            mel_gt_mask = np.abs(mel_gt).sum(-1) > 0
            mel_gt = mel_gt[mel_gt_mask]
            mel2ph_gt = prediction.get("mel2ph")
            mel2ph_gt = mel2ph_gt[mel_gt_mask] if mel2ph_gt is not None else None
            mel_pred = prediction["outputs"]
            mel_pred_mask = np.abs(mel_pred).sum(-1) > 0
            mel_pred = mel_pred[mel_pred_mask]
            mel_gt = np.clip(mel_gt, hparams['mel_vmin'], hparams['mel_vmax'])
            mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax'])

            mel2ph_pred = prediction.get("mel2ph_pred")
            if mel2ph_pred is not None:
                if len(mel2ph_pred) > len(mel_pred_mask):
                    mel2ph_pred = mel2ph_pred[:len(mel_pred_mask)]
                mel2ph_pred = mel2ph_pred[mel_pred_mask]

            f0_gt = prediction.get("f0")
            # Deliberately reuses GT f0 for vocoding the prediction
            # (the original "f0_pred" lookup is kept commented out).
            f0_pred = f0_gt  # prediction.get("f0_pred")
            if f0_pred is not None:
                f0_gt = f0_gt[mel_gt_mask]
                if len(f0_pred) > len(mel_pred_mask):
                    f0_pred = f0_pred[:len(mel_pred_mask)]
                f0_pred = f0_pred[mel_pred_mask]
            text = None
            str_phs = None
            # if self.phone_encoder is not None and 'txt_tokens' in prediction:
            #     str_phs = self.phone_encoder.decode(prediction['txt_tokens'], strip_padding=True)
            # def resize2d(source, target_len):
            #     source[source<0.001] = np.nan
            #     target = np.interp(np.linspace(0, len(source)-1, num=target_len,endpoint=True), np.arange(0, len(source)), source)
            #     return np.nan_to_num(target)
            # def resize3d(source, target_len):
            #     newsource=[]
            #     for i in range(source.shape[1]):
            #         newsource.append(resize2d(source[:,i],target_len))
            #     return np.array(newsource).transpose()
            # print(mel_pred.shape)
            # print(f0_pred.shape)
            # mel_pred=resize3d(mel_pred,int(mel_pred.shape[0]/44100*24000))
            # f0_pred=resize2d(f0_pred,int(f0_pred.shape[0]/44100*24000))
            # print(mel_pred.shape)
            # print(f0_pred.shape)
            gen_dir = os.path.join(hparams['work_dir'],
                                   f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
            wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
            if not hparams['profile_infer']:
                os.makedirs(gen_dir, exist_ok=True)
                os.makedirs(f'{gen_dir}/wavs', exist_ok=True)
                os.makedirs(f'{gen_dir}/plot', exist_ok=True)
                os.makedirs(os.path.join(hparams['work_dir'], 'P_mels_npy'), exist_ok=True)
                os.makedirs(os.path.join(hparams['work_dir'], 'G_mels_npy'), exist_ok=True)
                self.saving_results_futures.append(
                    self.saving_result_pool.apply_async(self.save_result, args=[
                        wav_pred, mel_pred, 'P', item_name, text, gen_dir, str_phs, mel2ph_pred, f0_gt, f0_pred]))

                if mel_gt is not None and hparams['save_gt']:
                    wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt)
                    self.saving_results_futures.append(
                        self.saving_result_pool.apply_async(self.save_result, args=[
                            wav_gt, mel_gt, 'G', item_name, text, gen_dir, str_phs, mel2ph_gt, f0_gt, f0_pred]))
                if hparams['save_f0']:
                    # NOTE(review): wav_gt/mel_gt are only vocoded when
                    # 'save_gt' is on — 'save_f0' without 'save_gt' would
                    # NameError here; confirm configs always pair them.
                    import matplotlib.pyplot as plt
                    # f0_pred_, _ = get_pitch(wav_pred, mel_pred, hparams)
                    f0_pred_ = f0_pred
                    f0_gt_, _ = get_pitch_parselmouth(wav_gt, mel_gt, hparams)
                    fig = plt.figure()
                    plt.plot(f0_pred_, label=r'$f0_P$')
                    plt.plot(f0_gt_, label=r'$f0_G$')
                    if hparams.get('pe_enable') is not None and hparams['pe_enable']:
                        # f0_midi = prediction.get("f0_midi")
                        # f0_midi = f0_midi[mel_gt_mask]
                        # plt.plot(f0_midi, label=r'$f0_M$')
                        pass
                    plt.legend()
                    plt.tight_layout()
                    plt.savefig(f'{gen_dir}/plot/[F0][{item_name}]{text}.png', format='png')
                    plt.close(fig)

                t.set_description(
                    f"Pred_shape: {mel_pred.shape}, gt_shape: {mel_gt.shape}")
            else:
                if 'gen_wav_time' not in self.stats:
                    self.stats['gen_wav_time'] = 0
                self.stats['gen_wav_time'] += len(wav_pred) / hparams['audio_sample_rate']
                print('gen_wav_time: ', self.stats['gen_wav_time'])

        return {}

    @staticmethod
    def save_result(wav_out, mel, prefix, item_name, text, gen_dir, str_phs=None, mel2ph=None, gt_f0=None, pred_f0=None):
        """Save one item's wav, mel .npy, and an annotated mel/f0 plot.

        Runs in a worker process of the saving pool; `prefix` is 'P' (pred)
        or 'G' (ground truth) and selects the `{prefix}_mels_npy` directory.
        """
        item_name = item_name.replace('/', '-')
        base_fn = f'[{item_name}][{prefix}]'

        if text is not None:
            base_fn += text
        base_fn += ('-' + hparams['exp_name'])
        np.save(os.path.join(hparams['work_dir'], f'{prefix}_mels_npy', item_name), mel)
        audio.save_wav(wav_out, f'{gen_dir}/wavs/{base_fn}.wav', 24000,  # hparams['audio_sample_rate'],
                       norm=hparams['out_wav_norm'])
        fig = plt.figure(figsize=(14, 10))
        spec_vmin = hparams['mel_vmin']
        spec_vmax = hparams['mel_vmax']
        heatmap = plt.pcolor(mel.T, vmin=spec_vmin, vmax=spec_vmax)
        fig.colorbar(heatmap)
        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            # Map f0 in [100, 800] Hz onto the 80 mel bins for overlay.
            gt_f0 = (gt_f0 - 100) / (800 - 100) * 80 * (gt_f0 > 0)
            pred_f0 = (pred_f0 - 100) / (800 - 100) * 80 * (pred_f0 > 0)
            plt.plot(pred_f0, c='white', linewidth=1, alpha=0.6)
            plt.plot(gt_f0, c='red', linewidth=1, alpha=0.6)
        else:
            f0, _ = get_pitch_parselmouth(wav_out, mel, hparams)
            f0 = (f0 - 100) / (800 - 100) * 80 * (f0 > 0)
            plt.plot(f0, c='white', linewidth=1, alpha=0.6)
        if mel2ph is not None and str_phs is not None:
            decoded_txt = str_phs.split(" ")
            dur = mel2ph_to_dur(torch.LongTensor(mel2ph)[None, :], len(decoded_txt))[0].numpy()
            dur = [0] + list(np.cumsum(dur))
            for i in range(len(dur) - 1):
                shift = (i % 20) + 1
                plt.text(dur[i], shift, decoded_txt[i])
                plt.hlines(shift, dur[i], dur[i + 1], colors='b' if decoded_txt[i] != '|' else 'black')
                plt.vlines(dur[i], 0, 5, colors='b' if decoded_txt[i] != '|' else 'black',
                           alpha=1, linewidth=1)
        plt.tight_layout()
        plt.savefig(f'{gen_dir}/plot/{base_fn}.png', format='png', dpi=1000)
        plt.close(fig)

    ##############
    # utils
    ##############
    @staticmethod
    def expand_f0_ph(f0, mel2ph):
        """Expand phone-level f0 to frame level by gathering along mel2ph.

        mel2ph uses 0 for padding, so f0 is left-padded by one slot first.
        """
        f0 = denorm_f0(f0, None, hparams)
        f0 = F.pad(f0, [1, 0])
        f0 = torch.gather(f0, 1, mel2ph)  # [B, T_mel]
        return f0
536
537
538
# Script entry point: launch this task via the framework's runner
# (``start`` is inherited from the base task class; not defined in this file).
if __name__ == '__main__':
    FastSpeech2Task.start()
540
541