Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/training/pe.py
694 views
1
import matplotlib
2
matplotlib.use('Agg')
3
4
import torch
5
import numpy as np
6
import os
7
8
from training.dataset.base_dataset import BaseDataset
9
from training.task.fs2 import FastSpeech2Task
10
from modules.fastspeech.pe import PitchExtractor
11
import utils
12
from utils.indexed_datasets import IndexedDataset
13
from utils.hparams import hparams
14
from utils.plot import f0_to_figure
15
from utils.pitch_utils import norm_interp_f0, denorm_f0
16
17
18
class PeDataset(BaseDataset):
    """Frame-level dataset for pitch-extractor training.

    Reads binarized items (mel, f0, pitch, ...) from an ``IndexedDataset``
    under ``hparams['binary_data_dir']`` and provides per-item tensors plus
    a batch collater.
    """

    def __init__(self, prefix, shuffle=False):
        """
        Args:
            prefix: dataset split name ('train'/'valid'/'test'); also the
                basename of the binarized files on disk.
            shuffle: forwarded to BaseDataset.
        """
        super().__init__(shuffle)
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        self.hparams = hparams
        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
        # Opened lazily in _get_item (safer with DataLoader worker processes).
        self.indexed_ds = None

        # Global pitch statistics; published into the shared hparams dict so
        # other components (e.g. f0 denorm) can read them.
        stats_path = f'{self.data_dir}/train_f0s_mean_std.npy'
        if os.path.exists(stats_path):
            f0_mean, f0_std = np.load(stats_path)
            self.f0_mean = f0_mean
            self.f0_std = f0_std
            hparams['f0_mean'] = float(f0_mean)
            hparams['f0_std'] = float(f0_std)
        else:
            self.f0_mean = self.f0_std = None
            hparams['f0_mean'] = hparams['f0_std'] = None

        # For the test split, optionally restrict to a fixed subset of items.
        if prefix == 'test' and hparams['num_test_samples'] > 0:
            self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
            self.sizes = [self.sizes[i] for i in self.avail_idxs]

    def _get_item(self, index):
        """Fetch one raw item dict, remapping through avail_idxs if present."""
        idx = index
        if getattr(self, 'avail_idxs', None) is not None:
            idx = self.avail_idxs[idx]
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[idx]

    def __getitem__(self, index):
        """Return one sample dict with mel, (normalized) f0, uv and pitch."""
        hp = self.hparams
        item = self._get_item(index)
        n_frames = hp['max_frames']
        mel = torch.Tensor(item['mel'])[:n_frames]
        f0, uv = norm_interp_f0(item["f0"][:n_frames], hp)
        # NOTE(review): item.get("pitch") would yield None if the key were
        # absent — presumably the binarizer always writes it; verify upstream.
        pitch = torch.LongTensor(item.get("pitch"))[:n_frames]
        return {
            "id": index,
            "item_name": item['item_name'],
            "text": item['txt'],
            "mel": mel,
            "pitch": pitch,
            "f0": f0,
            "uv": uv,
        }

    def collater(self, samples):
        """Pad and stack a list of samples into one batch dict."""
        if not samples:
            return {}
        return {
            'id': torch.LongTensor([s['id'] for s in samples]),
            'item_name': [s['item_name'] for s in samples],
            'nsamples': len(samples),
            'text': [s['text'] for s in samples],
            'mels': utils.collate_2d([s['mel'] for s in samples], 0.0),
            'mel_lengths': torch.LongTensor([s['mel'].shape[0] for s in samples]),
            'pitch': utils.collate_1d([s['pitch'] for s in samples]),
            'f0': utils.collate_1d([s['f0'] for s in samples], 0.0),
            'uv': utils.collate_1d([s['uv'] for s in samples]),
        }
99
100
101
class PitchExtractionTask(FastSpeech2Task):
    """Training task for a standalone mel->f0 PitchExtractor model.

    Reuses the FastSpeech2Task training loop but swaps the model for a
    PitchExtractor and the dataset for PeDataset.
    """

    def __init__(self):
        super().__init__()
        self.dataset_cls = PeDataset

    def build_tts_model(self):
        """The "TTS model" for this task is just the pitch extractor."""
        self.model = PitchExtractor(conv_layers=hparams['pitch_extractor_conv_layers'])

    def _training_step(self, sample, batch_idx, _):
        """Run one training step.

        Returns:
            (total_loss, loss_output): scalar loss to backprop and a dict of
            per-term losses for logging (plus the batch size).
        """
        loss_output = self.run_model(self.model, sample)
        # Sum only differentiable tensor losses; skip logged-only scalars.
        total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
        loss_output['batch_size'] = sample['mels'].size()[0]
        return total_loss, loss_output

    def validation_step(self, sample, batch_idx):
        """Validate one batch; plot predicted vs. ground-truth f0 for the
        first `num_valid_plots` batches."""
        outputs = {}
        # Fixed: previous version assigned outputs['losses'] = {} and then
        # immediately overwrote it on the next line.
        outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=True)
        outputs['total_loss'] = sum(outputs['losses'].values())
        outputs['nsamples'] = sample['nsamples']
        outputs = utils.tensors_to_scalars(outputs)
        if batch_idx < hparams['num_valid_plots']:
            self.plot_pitch(batch_idx, model_out, sample)
        return outputs

    def run_model(self, model, sample, return_output=False, infer=False):
        """Forward `model` on sample['mels'] and compute pitch losses.

        Args:
            model: the PitchExtractor.
            sample: batch dict from PeDataset.collater.
            return_output: if True, also return the raw model output.
            infer: accepted for API compatibility with callers; the forward
                pass here does not branch on it.
        """
        output = model(sample['mels'])
        losses = {}
        self.add_pitch_loss(output, sample, losses)
        if not return_output:
            return losses
        return losses, output

    def plot_pitch(self, batch_idx, model_out, sample):
        """Log a figure comparing ground-truth f0 with the model's denormed
        f0 prediction for the first item of the batch."""
        gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
        self.logger.experiment.add_figure(
            f'f0_{batch_idx}',
            f0_to_figure(gt_f0[0], None, model_out['f0_denorm_pred'][0]),
            self.global_step)

    def add_pitch_loss(self, output, sample, losses):
        """Accumulate f0 losses into `losses`, masking out padded frames."""
        mel = sample['mels']
        f0 = sample['f0']
        uv = sample['uv']
        # A frame whose mel energy is exactly zero is treated as padding
        # (matches the 0.0 pad value used by the collater).
        nonpadding = (mel.abs().sum(-1) > 0).float()
        self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)
156