GitHub Repository: snakers4/silero-vad
Path: blob/master/tuning/utils.py

from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import torchaudio
import warnings
import random
import torch
import gc

warnings.filterwarnings('ignore')

def read_audio(path: str,
               sampling_rate: int = 16000,
               normalize=False):

    wav, sr = torchaudio.load(path)

    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sampling_rate:
        if sr != sampling_rate:
            transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                       new_freq=sampling_rate)
            wav = transform(wav)
            sr = sampling_rate

    if normalize and wav.abs().max() != 0:
        wav = wav / wav.abs().max()

    return wav.squeeze(0)

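# Usage sketch: load a file as a mono 16 kHz 1-D tensor; the path below is a
# placeholder, not a file shipped with this repo.
#
#   wav = read_audio('example.wav', sampling_rate=16000, normalize=True)
#   print(wav.shape)  # torch.Size([num_samples])
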
def build_audiomentations_augs(p):
    from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \
        LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \
        Aliasing, AddGaussianNoise
    transforms = [Aliasing(p=1),
                  AddGaussianNoise(p=1),
                  AirAbsorption(p=1),
                  BandPassFilter(p=1),
                  BandStopFilter(p=1),
                  ClippingDistortion(p=1),
                  HighPassFilter(p=1),
                  HighShelfFilter(p=1),
                  LowPassFilter(p=1),
                  LowShelfFilter(p=1),
                  Mp3Compression(p=1),
                  PeakingFilter(p=1),
                  PitchShift(p=1),
                  RoomSimulator(p=1, leave_length_unchanged=True),
                  SevenBandParametricEQ(p=1)]
    tr = SomeOf((1, 3), transforms=transforms, p=p)
    return tr

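# Usage sketch: audiomentations transforms operate on float32 numpy waveforms;
# a random array stands in for real audio here.
#
#   augs = build_audiomentations_augs(p=0.5)
#   wav = np.random.randn(16000).astype(np.float32)
#   wav_aug = augs(wav, 16000)  # 1-3 randomly chosen transforms applied with prob 0.5
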
class SileroVadDataset(Dataset):
    def __init__(self,
                 config,
                 mode='train'):

        self.num_samples = 512  # constant, do not change
        self.sr = 16000  # constant, do not change

        self.resample_to_8k = config.tune_8k
        self.noise_loss = config.noise_loss
        self.max_train_length_sec = config.max_train_length_sec
        self.max_train_length_samples = config.max_train_length_sec * self.sr

        assert self.max_train_length_samples % self.num_samples == 0
        assert mode in ['train', 'val']

        dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
        self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
        self.index_dict = self.dataframe.to_dict('index')
        self.mode = mode
        print(f'DATASET SIZE: {len(self.dataframe)}')

        if mode == 'train':
            self.augs = build_audiomentations_augs(p=config.aug_prob)
        else:
            self.augs = None

    def __getitem__(self, idx):
        # in train mode a random sample is drawn instead of the requested index
        idx = None if self.mode == 'train' else idx
        wav, gt, mask = self.load_speech_sample(idx)

        if self.mode == 'train':
            wav = self.add_augs(wav)
            if len(wav) > self.max_train_length_samples:
                wav = wav[:self.max_train_length_samples]
                gt = gt[:int(self.max_train_length_samples / self.num_samples)]
                mask = mask[:int(self.max_train_length_samples / self.num_samples)]

        wav = torch.FloatTensor(wav)
        if self.resample_to_8k:
            transform = torchaudio.transforms.Resample(orig_freq=self.sr,
                                                       new_freq=8000)
            wav = transform(wav)
        return wav, torch.FloatTensor(gt), torch.from_numpy(mask)

    def __len__(self):
        return len(self.index_dict)

    def load_speech_sample(self, idx=None):
        if idx is None:
            idx = random.randint(0, len(self.index_dict) - 1)
        wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy()

        # zero-pad so the waveform splits evenly into chunks of num_samples
        if len(wav) % self.num_samples != 0:
            pad_num = self.num_samples - (len(wav) % self.num_samples)
            wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0)

        gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav))

        assert len(gt) == len(wav) / self.num_samples

        return wav, gt, mask

    def get_ground_truth_annotated(self, annotation, audio_length_samples):
        gt = np.zeros(audio_length_samples)

        for i in annotation:
            gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1

        # one label per chunk of num_samples: speech if more than half of the chunk is speech
        squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1)
        squeezed_predicts = (squeezed_predicts > 0.5).astype(int)
        # non-speech chunks are weighted by noise_loss in the training loss
        mask = np.ones(len(squeezed_predicts))
        mask[squeezed_predicts == 0] = self.noise_loss
        return squeezed_predicts, mask

    def add_augs(self, wav):
        while True:
            try:
                wav_aug = self.augs(wav, self.sr)
                if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()):
                    return wav
                return wav_aug
            except Exception:
                continue

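# Sketch of what the dataset expects (inferred from the code above, not from repo
# docs): `config` exposes tune_8k, noise_loss, max_train_length_sec, aug_prob and
# train_dataset_path / val_dataset_path; the feather file has one row per
# utterance with an 'audio_path' column and a 'speech_ts' column holding
# [{'start': seconds, 'end': seconds}, ...] speech annotations.
#
#   dataset = SileroVadDataset(config, mode='train')
#   wav, gt, mask = dataset[0]  # waveform, per-chunk labels, per-chunk loss weights
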
def SileroVadPadder(batch):
    wavs = [batch[i][0] for i in range(len(batch))]
    labels = [batch[i][1] for i in range(len(batch))]
    masks = [batch[i][2] for i in range(len(batch))]

    wavs = torch.nn.utils.rnn.pad_sequence(
        wavs, batch_first=True, padding_value=0)

    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=0)

    masks = torch.nn.utils.rnn.pad_sequence(
        masks, batch_first=True, padding_value=0)

    return wavs, labels, masks

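# Usage sketch: SileroVadPadder is a DataLoader collate_fn that zero-pads
# variable-length waveforms, labels and masks to a common length per batch.
# Batch size and worker count below are arbitrary.
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(SileroVadDataset(config, mode='train'),
#                       batch_size=16,
#                       collate_fn=SileroVadPadder,
#                       num_workers=4)
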
class VADDecoderRNNJIT(nn.Module):

    def __init__(self):
        super(VADDecoderRNNJIT, self).__init__()

        self.rnn = nn.LSTMCell(128, 128)
        self.decoder = nn.Sequential(nn.Dropout(0.1),
                                     nn.ReLU(),
                                     nn.Conv1d(128, 1, kernel_size=1),
                                     nn.Sigmoid())

    def forward(self, x, state=torch.zeros(0)):
        x = x.squeeze(-1)
        if len(state):
            h, c = self.rnn(x, (state[0], state[1]))
        else:
            h, c = self.rnn(x)

        x = h.unsqueeze(-1).float()
        state = torch.stack([h, c])
        x = self.decoder(x)
        return x, state

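# Usage sketch (assumed shapes, inferred from the squeeze/LSTMCell calls above):
# the decoder consumes one 128-dim encoder frame per audio chunk and carries the
# LSTM state between chunks; dummy tensors stand in for real encoder outputs.
#
#   decoder = VADDecoderRNNJIT()
#   feats = torch.randn(2, 128, 1)        # (batch, channels, 1) for one chunk
#   probs, state = decoder(feats)         # probs: (batch, 1, 1) speech probability
#   probs, state = decoder(feats, state)  # feed the state back for the next chunk
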
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def train(config,
          loader,
          jit_model,
          decoder,
          criterion,
          optimizer,
          device):

    losses = AverageMeter()
    decoder.train()

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.enable_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            # prepend left context so the first chunk also sees context_size past samples
            x = torch.nn.functional.pad(x, (context_size, 0))

            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i - context_size:i + num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.update(loss.item(), masks.numel())

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg

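# Usage sketch for one fine-tuning epoch: the criterion must return per-element
# losses (reduction='none') so the per-chunk mask weighting above works; the jit
# checkpoint path and hyperparameters below are placeholders.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   jit_model = init_jit_model('silero_vad.jit', device=device)
#   decoder = VADDecoderRNNJIT().to(device)
#   criterion = torch.nn.BCELoss(reduction='none')
#   optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
#   train_loss = train(config, loader, jit_model, decoder, criterion, optimizer, device)
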
def validate(config,
             loader,
             jit_model,
             decoder,
             criterion,
             device):

    losses = AverageMeter()
    decoder.eval()

    predicts = []
    gts = []

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
    encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder

    with torch.no_grad():
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            x = torch.nn.functional.pad(x, (context_size, 0))

            outs = []
            state = torch.zeros(0)
            for i in range(context_size, x.shape[1], num_samples):
                input_ = x[:, i - context_size:i + num_samples]
                out = stft_layer(input_)
                out = encoder_layer(out)
                out, state = decoder(out, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            # padded chunks (mask == 0) are excluded from the ROC AUC computation
            predicts.extend(stacked[masks != 0].tolist())
            gts.extend(targets[masks != 0].tolist())

            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            losses.update(loss.item(), masks.numel())

    score = roc_auc_score(gts, predicts)

    torch.cuda.empty_cache()
    gc.collect()

    return losses.avg, round(score, 3)

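# Usage sketch: evaluate the tuned decoder; returns the mask-weighted loss and
# the ROC AUC over all non-padded chunks.
#
#   val_loss, val_auc = validate(config, val_loader, jit_model, decoder, criterion, device)
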
def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    torch.set_grad_enabled(False)
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model

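# Usage sketch: load a TorchScript VAD checkpoint; the filename is a placeholder
# for whichever .jit file is being tuned. Note this helper disables autograd
# globally, which is why train() re-enables it via torch.enable_grad().
#
#   jit_model = init_jit_model('silero_vad.jit', device=torch.device('cpu'))
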
def predict(model, loader, device, sr):
    with torch.no_grad():
        all_predicts = []
        all_gts = []
        for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
            x = x.to(device)
            out = model.audio_forward(x, sr=sr)

            for i, out_chunk in enumerate(out):
                predict = out_chunk[masks[i] != 0].cpu().tolist()
                gt = targets[i, masks[i] != 0].cpu().tolist()

                all_predicts.append(predict)
                all_gts.append(gt)
    return all_predicts, all_gts

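# Usage sketch: run the jit model's audio_forward over a validation loader to
# collect per-utterance lists of chunk probabilities and ground-truth labels.
#
#   all_predicts, all_gts = predict(jit_model, val_loader, device, sr=16000)
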
def calculate_best_thresholds(all_predicts, all_gts):
    best_acc = 0
    best_ths_enter, best_ths_exit = 0, 0
    for ths_enter in tqdm(np.linspace(0, 1, 20)):
        for ths_exit in np.linspace(0, 1, 20):
            # hysteresis: the exit threshold must be strictly below the enter threshold
            if ths_exit >= ths_enter:
                continue

            accs = []
            for j, predict in enumerate(all_predicts):
                predict_bool = []
                is_speech = False
                for i in predict:
                    if i >= ths_enter:
                        is_speech = True
                        predict_bool.append(1)
                    elif i <= ths_exit:
                        is_speech = False
                        predict_bool.append(0)
                    else:
                        val = 1 if is_speech else 0
                        predict_bool.append(val)

                score = round(accuracy_score(all_gts[j], predict_bool), 4)
                accs.append(score)

            mean_acc = round(np.mean(accs), 3)
            if mean_acc > best_acc:
                best_acc = mean_acc
                best_ths_enter = round(ths_enter, 2)
                best_ths_exit = round(ths_exit, 2)
    return best_ths_enter, best_ths_exit, best_acc
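
# Usage sketch: grid-search the hysteresis thresholds (speech-enter / speech-exit)
# over the probabilities returned by predict().
#
#   all_predicts, all_gts = predict(jit_model, val_loader, device, sr=16000)
#   ths_enter, ths_exit, acc = calculate_best_thresholds(all_predicts, all_gts)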