Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/training/dataset/fs2_utils.py
694 views
1
import matplotlib
2
3
matplotlib.use('Agg')
4
5
import glob
6
import importlib
7
from utils.cwt import get_lf0_cwt
8
import os
9
import torch.optim
10
import torch.utils.data
11
from utils.indexed_datasets import IndexedDataset
12
from utils.pitch_utils import norm_interp_f0
13
import numpy as np
14
from training.dataset.base_dataset import BaseDataset
15
import torch
16
import torch.optim
17
import torch.utils.data
18
import utils
19
import torch.distributions
20
from utils.hparams import hparams
21
22
23
class FastSpeechDataset(BaseDataset):
    """Dataset over binarized items (mel, f0, pitch, HuBERT features, ...).

    Items are read lazily from an ``IndexedDataset`` stored at
    ``hparams['binary_data_dir']/<prefix>``. For the ``'test'`` split with a
    non-empty ``hparams['test_input_dir']``, items are instead produced on the
    fly from raw audio files (see ``load_test_inputs``).
    """

    def __init__(self, prefix, shuffle=False):
        """
        Args:
            prefix: split name such as 'train' or 'test'; selects the
                '<prefix>_lengths.npy' size file and the indexed dataset.
            shuffle: forwarded to ``BaseDataset``.
        """
        super().__init__(shuffle)
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        self.hparams = hparams
        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
        # Opened lazily in _get_item. NOTE(review): presumably deferred so each
        # DataLoader worker opens its own handle -- confirm.
        self.indexed_ds = None

        # Global pitch statistics written at binarization time. They are also
        # published into the shared hparams dict so other components can read them.
        f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy'
        if os.path.exists(f0_stats_fn):
            hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn)
            hparams['f0_mean'] = float(hparams['f0_mean'])
            hparams['f0_std'] = float(hparams['f0_std'])
        else:
            hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None

        if prefix == 'test':
            if hparams['test_input_dir'] != '':
                # Raw audio replaces the binarized test set: indexed_ds becomes a
                # plain list of processed item dicts.
                self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
            else:
                if hparams['num_test_samples'] > 0:
                    # Restrict the test split to the first N items plus any
                    # explicitly requested ids.
                    self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
                    self.sizes = [self.sizes[i] for i in self.avail_idxs]

        if hparams['pitch_type'] == 'cwt':
            _, hparams['cwt_scales'] = get_lf0_cwt(np.ones(10))

    def _get_item(self, index):
        """Return the raw item dict for ``index``, honoring ``avail_idxs`` remapping."""
        if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
            index = self.avail_idxs[index]
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Build one training sample as a dict of tensors, truncated to ``max_frames``."""
        hparams = self.hparams
        item = self._get_item(index)
        max_frames = hparams['max_frames']
        spec = torch.Tensor(item['mel'])[:max_frames]
        # Per-frame energy: L2 norm over the last axis of exp(spec).
        energy = (spec.exp() ** 2).sum(-1).sqrt()
        mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
        f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
        # HuBERT features replace phone tokens; note they are truncated by
        # max_input_tokens, not max_frames.
        hubert = torch.Tensor(item['hubert'][:hparams['max_input_tokens']])
        pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
        sample = {
            "id": index,
            "item_name": item['item_name'],
            "hubert": hubert,
            "mel": spec,
            "pitch": pitch,
            "energy": energy,
            "f0": f0,
            "uv": uv,
            "mel2ph": mel2ph,
            # Frames whose mel row is entirely zero are treated as padding.
            "mel_nonpadding": spec.abs().sum(-1) > 0,
        }
        if self.hparams['use_spk_embed']:
            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
        if self.hparams['use_spk_id']:
            sample["spk_id"] = item['spk_id']
        # NOTE: cwt_spec / f0_mean / f0_std / f0_ph are intentionally not
        # produced here (previously marked "not used").
        return sample

    def collater(self, samples):
        """Collate samples from ``__getitem__`` into a padded batch dict.

        Text fields ('text', 'txt_tokens', 'txt_lengths') are filled only when the
        samples carry 'text' / 'txt_token'. This class's ``__getitem__`` emits
        'hubert' instead, so those fields are None and 'hubert' is collated.
        Pitch-type specific fields ('cwt_spec', 'f0_ph', ...) are collated only
        when the samples actually contain them.
        """
        if len(samples) == 0:
            return {}
        first = samples[0]
        ids = torch.LongTensor([s['id'] for s in samples])
        item_names = [s['item_name'] for s in samples]
        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch = utils.collate_1d([s['pitch'] for s in samples], 1)
        uv = utils.collate_1d([s['uv'] for s in samples])
        energy = utils.collate_1d([s['energy'] for s in samples], 0.0)
        mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
            if first['mel2ph'] is not None else None
        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])

        # Text inputs are optional: __getitem__ above does not produce them.
        text = [s['text'] for s in samples] if 'text' in first else None
        if 'txt_token' in first:
            txt_tokens = utils.collate_1d([s['txt_token'] for s in samples], 0)
            txt_lengths = torch.LongTensor([s['txt_token'].numel() for s in samples])
        else:
            txt_tokens = txt_lengths = None

        batch = {
            'id': ids,
            'item_name': item_names,
            'nsamples': len(samples),
            'text': text,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'mels': mels,
            'mel_lengths': mel_lengths,
            'mel2ph': mel2ph,
            'energy': energy,
            'pitch': pitch,
            'f0': f0,
            'uv': uv,
        }

        if 'hubert' in first:
            # assumes per-sample hubert is 2-D (frames, feature_dim) -- TODO confirm
            batch['hubert'] = utils.collate_2d([s['hubert'] for s in samples], 0.0)
        if self.hparams['use_spk_embed']:
            batch['spk_embed'] = torch.stack([s['spk_embed'] for s in samples])
        if self.hparams['use_spk_id']:
            batch['spk_ids'] = torch.LongTensor([s['spk_id'] for s in samples])
        pitch_type = self.hparams['pitch_type']
        if pitch_type == 'cwt' and 'cwt_spec' in first:
            cwt_spec = utils.collate_2d([s['cwt_spec'] for s in samples])
            f0_mean = torch.Tensor([s['f0_mean'] for s in samples])
            f0_std = torch.Tensor([s['f0_std'] for s in samples])
            batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std})
        elif pitch_type == 'ph' and 'f0_ph' in first:
            batch['f0'] = utils.collate_1d([s['f0_ph'] for s in samples])

        return batch

    def load_test_inputs(self, test_input_dir, spk_id=0):
        """Build test items directly from the raw audio files in ``test_input_dir``.

        Every ``*.wav`` and ``*.mp3`` file directly inside the directory is passed
        to the configured binarizer's ``process_item`` together with a HuBERT
        encoder.

        Args:
            test_input_dir: directory scanned (non-recursively) for audio files.
            spk_id: currently unused; kept for interface compatibility.

        Returns:
            (items, sizes): the processed item dicts and their ``item['len']`` values.
        """
        # glob order is filesystem dependent; sort for a reproducible index->file mapping.
        inp_wav_paths = sorted(glob.glob(f'{test_input_dir}/*.wav')) + sorted(glob.glob(f'{test_input_dir}/*.mp3'))
        sizes = []
        items = []

        # Resolve the binarizer class from its dotted path, e.g. 'pkg.module.ClassName'.
        binarizer_cls = hparams.get("binarizer_cls", 'basics.base_binarizer.BaseBinarizer')
        pkg = ".".join(binarizer_cls.split(".")[:-1])
        cls_name = binarizer_cls.split(".")[-1]
        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
        binarization_args = hparams['binarization_args']
        from preprocessing.hubertinfer import Hubertencoder
        # Build the (expensive) encoder once, and only if there is audio to process.
        encoder = Hubertencoder(hparams['hubert_path']) if inp_wav_paths else None
        for wav_fn in inp_wav_paths:
            item_name = os.path.basename(wav_fn)
            item = binarizer_cls.process_item(item_name, {'wav_fn': wav_fn}, encoder, binarization_args)
            items.append(item)
            sizes.append(item['len'])
        return items, sizes
179
180