Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/network/hubert/hubert_model.py
694 views
1
import copy
2
import os
3
import random
4
from typing import Optional, Tuple
5
6
import librosa
7
import numpy as np
8
import torch
9
import torch.nn as nn
10
import torch.nn.functional as t_func
11
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
12
13
from utils import hparams
14
15
16
class Hubert(nn.Module):
    """HuBERT backbone: convolutional feature extractor + Transformer encoder.

    `encode` yields 768-dim hidden states; `proj` maps them to 256-dim
    features which `logits` compares against the label embedding table.

    Args:
        num_label_embeddings: number of discrete target labels
            (rows of ``label_embedding``).
        mask: when True, apply SpecAugment-style span masking during training.
    """

    def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
        super().__init__()
        self._mask = mask
        self.feature_extractor = FeatureExtractor()
        self.feature_projection = FeatureProjection()
        self.positional_embedding = PositionalConvEmbedding()
        self.norm = nn.LayerNorm(768)
        self.dropout = nn.Dropout(0.1)
        # 12 encoder layers: d_model=768, 12 heads, FFN dim 3072, GELU, batch-first
        self.encoder = TransformerEncoder(
            nn.TransformerEncoderLayer(
                768, 12, 3072, activation="gelu", batch_first=True
            ),
            12,
        )
        self.proj = nn.Linear(768, 256)

        # learned vector substituted at masked time steps (see `mask`)
        self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
        self.label_embedding = nn.Embedding(num_label_embeddings, 256)

    def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mask random time spans of `x` in place (training only).

        Returns `(x, mask)`; `mask` is None when not training or masking is
        disabled.  Note this method shadows the `_mask` constructor flag's name.
        """
        mask = None
        if self.training and self._mask:
            # span length 10, mask probability 0.8, at least 2 spans per batch row
            mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
            x[mask] = self.masked_spec_embed.to(x.dtype)
        return x, mask

    def encode(
            self, x: torch.Tensor, layer: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the full pipeline; optionally stop after `layer` encoder layers.

        Args:
            x: raw waveform batch fed to the conv extractor
               (presumably (batch, 1, samples) — TODO confirm against callers).
            layer: if given, only the first `layer` Transformer layers run.

        Returns:
            `(hidden_states, mask)` where `mask` is None outside training.
        """
        x = self.feature_extractor(x)
        # conv output is (batch, channels, frames); transpose to (batch, frames, channels)
        x = self.feature_projection(x.transpose(1, 2))
        x, mask = self.mask(x)
        x = x + self.positional_embedding(x)
        x = self.dropout(self.norm(x))
        x = self.encoder(x, output_layer=layer)
        return x, mask

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        """Cosine similarity of each frame to every label embedding, scaled by 1/0.1."""
        logits = torch.cosine_similarity(
            x.unsqueeze(2),
            self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
            dim=-1,
        )
        # temperature 0.1 sharpens the distribution
        return logits / 0.1

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return `(logits, mask)` for a batch of waveforms."""
        x, mask = self.encode(x)
        x = self.proj(x)
        logits = self.logits(x)
        return logits, mask
67
68
69
class HubertSoft(Hubert):
    """Soft-unit variant of :class:`Hubert`: emits continuous 256-dim units."""

    def __init__(self):
        super().__init__()

    # @torch.inference_mode()  # kept disabled, as in the original
    def units(self, wav: torch.Tensor) -> torch.Tensor:
        """Project encoder states of `wav` into the 256-dim unit space."""
        # symmetric padding of (400 - 320) // 2 = 40 samples on each side
        half_window_overhang = (400 - 320) // 2
        wav = t_func.pad(wav, (half_window_overhang, half_window_overhang))
        hidden, _ = self.encode(wav)
        return self.proj(hidden)

    def forward(self, wav: torch.Tensor):
        """Alias for :meth:`units`."""
        return self.units(wav)
81
82
83
class FeatureExtractor(nn.Module):
    """Seven-layer 1-D conv stack over raw audio.

    Overall temporal stride is 5 * 2**6 = 320; only the first layer is
    group-normalized.  Submodule names (conv0..conv6, norm0) are part of the
    checkpoint format and must not change.
    """

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
        self.norm0 = nn.GroupNorm(512, 512)
        self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
        self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, 1, samples) audio to (batch, 512, frames) features."""
        # first layer: conv -> per-channel GroupNorm -> GELU
        x = t_func.gelu(self.norm0(self.conv0(x)))
        # remaining layers: conv -> GELU only
        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            x = t_func.gelu(conv(x))
        return x
104
105
106
class FeatureProjection(nn.Module):
    """Project 512-dim conv features to the 768-dim encoder width.

    Pipeline: LayerNorm(512) -> Linear(512, 768) -> Dropout(0.1).
    """

    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(512)
        self.projection = nn.Linear(512, 768)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply normalize -> project -> dropout to (..., 512) input."""
        return self.dropout(self.projection(self.norm(x)))
118
119
120
class PositionalConvEmbedding(nn.Module):
    """Convolutional positional embedding (grouped conv, weight-normalized).

    Takes and returns (batch, frames, 768); positional information comes from
    a kernel-128 grouped Conv1d whose weight is normalized along dim=2.
    """

    def __init__(self):
        super().__init__()
        conv = nn.Conv1d(
            768,
            768,
            kernel_size=128,
            padding=64,  # == kernel_size // 2
            groups=16,
        )
        # attribute name "conv" is load-bearing for checkpoint keys
        self.conv = nn.utils.weight_norm(conv, name="weight", dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute positional features; output length matches input length."""
        hidden = self.conv(x.transpose(1, 2))
        # even kernel + symmetric padding yields one extra frame; drop it
        hidden = t_func.gelu(hidden[:, :, :-1])
        return hidden.transpose(1, 2)
136
137
138
class TransformerEncoder(nn.Module):
    """Stack of independent copies of one template encoder layer.

    Unlike ``nn.TransformerEncoder`` this supports an early exit via
    ``output_layer`` in :meth:`forward`.
    """

    def __init__(
            self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
    ) -> None:
        super().__init__()
        # deep-copy so every layer has its own parameters
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers

    def forward(
            self,
            src: torch.Tensor,
            mask: torch.Tensor = None,
            src_key_padding_mask: torch.Tensor = None,
            output_layer: Optional[int] = None,
    ) -> torch.Tensor:
        """Run `src` through the first `output_layer` layers (all when None)."""
        hidden = src
        # slicing with None yields the full layer list
        for encoder in self.layers[:output_layer]:
            hidden = encoder(
                hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )
        return hidden
161
162
163
def _compute_mask(
164
shape: Tuple[int, int],
165
mask_prob: float,
166
mask_length: int,
167
device: torch.device,
168
min_masks: int = 0,
169
) -> torch.Tensor:
170
batch_size, sequence_length = shape
171
172
if mask_length < 1:
173
raise ValueError("`mask_length` has to be bigger than 0.")
174
175
if mask_length > sequence_length:
176
raise ValueError(
177
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
178
)
179
180
# compute number of masked spans in batch
181
num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
182
num_masked_spans = max(num_masked_spans, min_masks)
183
184
# make sure num masked indices <= sequence_length
185
if num_masked_spans * mask_length > sequence_length:
186
num_masked_spans = sequence_length // mask_length
187
188
# SpecAugment mask to fill
189
mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
190
191
# uniform distribution to sample from, make sure that offset samples are < sequence_length
192
uniform_dist = torch.ones(
193
(batch_size, sequence_length - (mask_length - 1)), device=device
194
)
195
196
# get random indices to mask
197
mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
198
199
# expand masked indices to masked spans
200
mask_indices = (
201
mask_indices.unsqueeze(dim=-1)
202
.expand((batch_size, num_masked_spans, mask_length))
203
.reshape(batch_size, num_masked_spans * mask_length)
204
)
205
offsets = (
206
torch.arange(mask_length, device=device)[None, None, :]
207
.expand((batch_size, num_masked_spans, mask_length))
208
.reshape(batch_size, num_masked_spans * mask_length)
209
)
210
mask_idxs = mask_indices + offsets
211
212
# scatter indices to mask
213
mask = mask.scatter(1, mask_idxs, True)
214
215
return mask
216
217
218
def hubert_soft(
        path: str
) -> HubertSoft:
    r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
    Args:
        path (str): path of a pretrained model

    Returns:
        HubertSoft: the model in eval mode on CUDA when available, else CPU.
    """
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hubert = HubertSoft()
    # map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts;
    # without it torch.load raises when deserializing CUDA tensors.
    checkpoint = torch.load(path, map_location=dev)
    # drop DataParallel's "module." prefix if the checkpoint was saved with it
    consume_prefix_in_state_dict_if_present(checkpoint, "module.")
    hubert.load_state_dict(checkpoint)
    hubert.eval().to(dev)
    return hubert
232
233
234
def get_units(hbt_soft, raw_wav_path, dev=torch.device('cuda')):
    """Extract HuBERT-Soft units for one audio file.

    Args:
        hbt_soft: a loaded HubertSoft model (see ``hubert_soft``).
        raw_wav_path: path to an audio file with sample rate >= 16 kHz.
        dev: requested device; silently falls back to CPU when CUDA is
            requested but unavailable.

    Returns:
        The units tensor produced by ``hbt_soft.units`` on `dev`
        (presumably (1, frames, 256) — TODO confirm against HubertSoft.proj).
    """
    wav, sr = librosa.load(raw_wav_path, sr=None)
    # NOTE: assert is stripped under `python -O`; kept for interface compatibility
    assert (sr >= 16000)
    if len(wav.shape) > 1:
        wav = librosa.to_mono(wav)
    if sr != 16000:
        # keyword arguments: the positional (y, orig_sr, target_sr) form was
        # removed in librosa >= 0.10 and would raise TypeError there
        wav16 = librosa.resample(wav, orig_sr=sr, target_sr=16000)
    else:
        wav16 = wav
    dev = torch.device("cuda" if (dev == torch.device('cuda') and torch.cuda.is_available()) else "cpu")
    # best-effort cache release before inference; no-op without CUDA
    torch.cuda.is_available() and torch.cuda.empty_cache()
    with torch.inference_mode():
        units = hbt_soft.units(torch.FloatTensor(wav16.astype(float)).unsqueeze(0).unsqueeze(0).to(dev))
    return units
248
249
250
def get_end_file(dir_path, end):
    """Recursively collect files under `dir_path` whose names end with `end`.

    Dot-prefixed files and directories are skipped entirely; returned paths
    always use forward slashes, regardless of platform.
    """
    matches = []
    for root, dirs, files in os.walk(dir_path):
        # prune hidden directories in place so os.walk never descends into them
        dirs[:] = [d for d in dirs if not d.startswith('.')]
        for name in files:
            if name.startswith('.') or not name.endswith(end):
                continue
            matches.append(os.path.join(root, name).replace("\\", "/"))
    return matches
259
260
261
if __name__ == '__main__':
    from pathlib import Path

    # NOTE(review): `dev` is computed but never used below — get_units relies on
    # its own default device argument; confirm whether this was meant to be passed.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Path to the hubert model checkpoint.
    # NOTE(review): `Path(...).home()` is a classmethod that returns the user's
    # home directory regardless of its receiver, so this rglob actually picks the
    # first *.pt found anywhere under $HOME — confirm this is intended rather
    # than searching hparams['hubert_path'] itself.
    hbt_model = hubert_soft(str(list(Path(hparams['hubert_path']).home().rglob('*.pt'))[0]))
    # No need to change this: for every wav under the raw data root, the
    # corresponding units are saved as an .npy next to the source file.
    file_lists = list(Path(hparams['raw_data_dir']).rglob('*.wav'))
    nums = len(file_lists)
    count = 0
    for wav_path in file_lists:
        npy_path = wav_path.with_suffix(".npy")
        # [0] drops the batch dimension before saving
        npy_content = get_units(hbt_model, wav_path).cpu().numpy()[0]
        np.save(str(npy_path), npy_content)
        count += 1
        print(f"hubert process:{round(count * 100 / nums, 2)}%")
277
278