GitHub Repository: snakers4/silero-vad
Path: blob/master/src/silero_vad/utils_vad.py

import torch
import torchaudio
from typing import Callable, List
import warnings

languages = ['ru', 'en', 'de', 'es']


class OnnxWrapper():

    def __init__(self, path, force_onnx_cpu=False):
        import numpy as np
        global np
        import onnxruntime

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
        else:
            self.session = onnxruntime.InferenceSession(path, sess_options=opts)

        self.reset_states()
        if '16k' in path:
            warnings.warn('This model supports only a 16000 sampling rate!')
            self.sample_rates = [16000]
        else:
            self.sample_rates = [8000, 16000]

    def _validate_input(self, x, sr: int):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        if sr != 16000 and (sr % 16000 == 0):
            step = sr // 16000
            x = x[:, ::step]
            sr = 16000

        if sr not in self.sample_rates:
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiples of 16000)")
        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        return x, sr

    def reset_states(self, batch_size=1):
        self._state = torch.zeros((2, batch_size, 128)).float()
        self._context = torch.zeros(0)
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):

        x, sr = self._validate_input(x, sr)
        num_samples = 512 if sr == 16000 else 256

        if x.shape[-1] != num_samples:
            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

        batch_size = x.shape[0]
        context_size = 64 if sr == 16000 else 32

        if not self._last_batch_size:
            self.reset_states(batch_size)
        if (self._last_sr) and (self._last_sr != sr):
            self.reset_states(batch_size)
        if (self._last_batch_size) and (self._last_batch_size != batch_size):
            self.reset_states(batch_size)

        if not len(self._context):
            self._context = torch.zeros(batch_size, context_size)

        x = torch.cat([self._context, x], dim=1)
        if sr in [8000, 16000]:
            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
            ort_outs = self.session.run(None, ort_inputs)
            out, state = ort_outs
            self._state = torch.from_numpy(state)
        else:
            raise ValueError(f"Unsupported sampling rate: {sr}")

        self._context = x[..., -context_size:]
        self._last_sr = sr
        self._last_batch_size = batch_size

        out = torch.from_numpy(out)
        return out

    def audio_forward(self, x, sr: int):
        outs = []
        x, sr = self._validate_input(x, sr)
        self.reset_states()
        num_samples = 512 if sr == 16000 else 256

        if x.shape[1] % num_samples:
            pad_num = num_samples - (x.shape[1] % num_samples)
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

        for i in range(0, x.shape[1], num_samples):
            wavs_batch = x[:, i:i+num_samples]
            out_chunk = self.__call__(wavs_batch, sr)
            outs.append(out_chunk)

        stacked = torch.cat(outs, dim=1)
        return stacked.cpu()
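

# Illustrative usage sketch, not part of the original file: feeds an already
# loaded 16 kHz waveform to OnnxWrapper in the fixed-size chunks it expects
# (512 samples at 16 kHz, 256 at 8 kHz). The 'silero_vad.onnx' path is an
# assumed placeholder; point it at wherever the ONNX model lives in your setup.
def _example_onnx_wrapper_streaming(wav: torch.Tensor, model_path: str = 'silero_vad.onnx'):
    model = OnnxWrapper(model_path, force_onnx_cpu=True)
    probs = []
    for start in range(0, wav.shape[-1] - 512 + 1, 512):
        chunk = wav[start:start + 512]            # one 32 ms window at 16 kHz
        probs.append(model(chunk, 16000).item())  # per-chunk speech probability
    return probs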


class Validator():
    def __init__(self, url, force_onnx_cpu):
        self.onnx = url.endswith('.onnx')
        torch.hub.download_url_to_file(url, 'inf.model')
        if self.onnx:
            import onnxruntime
            if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
                self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
            else:
                self.model = onnxruntime.InferenceSession('inf.model')
        else:
            self.model = init_jit_model(model_path='inf.model')

    def __call__(self, inputs: torch.Tensor):
        with torch.no_grad():
            if self.onnx:
                ort_inputs = {'input': inputs.cpu().numpy()}
                outs = self.model.run(None, ort_inputs)
                outs = [torch.Tensor(x) for x in outs]
            else:
                outs = self.model(inputs)

        return outs


def read_audio(path: str,
               sampling_rate: int = 16000):
    list_backends = torchaudio.list_audio_backends()

    assert len(list_backends) > 0, 'The list of available backends is empty, please install a backend manually. \
                                    \n Recommendations: \n \tSox (UNIX OS) \n \tSoundfile (Windows OS, UNIX OS) \n \tffmpeg (Windows OS, UNIX OS)'

    try:
        effects = [
            ['channels', '1'],
            ['rate', str(sampling_rate)]
        ]

        wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
    except Exception:
        wav, sr = torchaudio.load(path)

    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sr != sampling_rate:
        transform = torchaudio.transforms.Resample(orig_freq=sr,
                                                   new_freq=sampling_rate)
        wav = transform(wav)
        sr = sampling_rate

    assert sr == sampling_rate
    return wav.squeeze(0)


def save_audio(path: str,
               tensor: torch.Tensor,
               sampling_rate: int = 16000):
    torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
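

# Illustrative sketch, not part of the original file: round-trips audio through
# read_audio / save_audio above, resampling to 16 kHz mono on the way in and
# writing 16-bit PCM on the way out. 'input.wav' and 'output.wav' are
# placeholder paths.
def _example_read_and_save(input_path: str = 'input.wav', output_path: str = 'output.wav'):
    wav = read_audio(input_path, sampling_rate=16000)  # 1-D float tensor, mono, 16 kHz
    save_audio(output_path, wav, sampling_rate=16000)
    return wav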


def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model


def make_visualization(probs, step):
    import pandas as pd
    pd.DataFrame({'probs': probs},
                 index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
                                                                   kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
                                                                   xlabel='seconds',
                                                                   ylabel='speech probability',
                                                                   colormap='tab20')


@torch.no_grad()
def get_speech_timestamps(audio: torch.Tensor,
                          model,
                          threshold: float = 0.5,
                          sampling_rate: int = 16000,
                          min_speech_duration_ms: int = 250,
                          max_speech_duration_s: float = float('inf'),
                          min_silence_duration_ms: int = 100,
                          speech_pad_ms: int = 30,
                          return_seconds: bool = False,
                          time_resolution: int = 1,
                          visualize_probs: bool = False,
                          progress_tracking_callback: Callable[[float], None] = None,
                          neg_threshold: float = None,
                          window_size_samples: int = 512,
                          min_silence_at_max_speech: float = 98,
                          use_max_poss_sil_at_max_speech: bool = True):

    """
    This method is used for splitting long audios into speech chunks using silero VAD

    Parameters
    ----------
    audio: torch.Tensor, one dimensional
        One dimensional float torch.Tensor, other types are cast to torch if possible

    model: preloaded .jit/.onnx silero VAD model

    threshold: float (default - 0.5)
        Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
        It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

    sampling_rate: int (default - 16000)
        Currently silero VAD models support 8000 and 16000 (or multiples of 16000) sample rates

    min_speech_duration_ms: int (default - 250 milliseconds)
        Final speech chunks shorter than min_speech_duration_ms are thrown out

    max_speech_duration_s: int (default - inf)
        Maximum duration of speech chunks in seconds
        Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent aggressive cutting.
        Otherwise, they will be split aggressively just before max_speech_duration_s.

    min_silence_duration_ms: int (default - 100 milliseconds)
        At the end of each speech chunk, wait for min_silence_duration_ms before separating it

    speech_pad_ms: int (default - 30 milliseconds)
        Final speech chunks are padded by speech_pad_ms on each side

    return_seconds: bool (default - False)
        whether to return timestamps in seconds (default - samples)

    time_resolution: int (default - 1)
        time resolution of speech coordinates when requested as seconds

    visualize_probs: bool (default - False)
        whether to draw the probability histogram or not

    progress_tracking_callback: Callable[[float], None] (default - None)
        callback function taking progress in percent as an argument

    neg_threshold: float (default = threshold - 0.15)
        Negative threshold (noise or exit threshold). If the model's current state is SPEECH, values BELOW this value are considered as NON-SPEECH.

    min_silence_at_max_speech: float (default - 98 milliseconds)
        Minimum silence duration in ms which is used to avoid abrupt cuts when max_speech_duration_s is reached

    use_max_poss_sil_at_max_speech: bool (default - True)
        Whether to use the maximum possible silence at max_speech_duration_s or not. If not, the last silence is used.

    window_size_samples: int (default - 512 samples)
        !!! DEPRECATED, DOES NOTHING !!!

    Returns
    ----------
    speeches: list of dicts
        list containing the beginnings and ends of speech chunks (samples or seconds based on return_seconds)
    """

    if not torch.is_tensor(audio):
        try:
            audio = torch.Tensor(audio)
        except Exception:
            raise TypeError("Audio cannot be cast to tensor. Cast it manually")

    if len(audio.shape) > 1:
        for i in range(len(audio.shape)):  # trying to squeeze empty dimensions
            audio = audio.squeeze(0)
        if len(audio.shape) > 1:
            raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")

    if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
        step = sampling_rate // 16000
        sampling_rate = 16000
        audio = audio[::step]
        warnings.warn('Sampling rate is a multiple of 16000, casting to 16000 manually!')
    else:
        step = 1

    if sampling_rate not in [8000, 16000]:
        raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiples of 16000) sample rates")

    window_size_samples = 512 if sampling_rate == 16000 else 256
    hop_size_samples = int(window_size_samples)

    model.reset_states()
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * min_silence_at_max_speech / 1000

    audio_length_samples = len(audio)

    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, hop_size_samples):
        chunk = audio[current_start_sample: current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob = model(chunk, sampling_rate).item()
        speech_probs.append(speech_prob)
        # calculate progress and send it to the callback function
        progress = current_start_sample + hop_size_samples
        if progress > audio_length_samples:
            progress = audio_length_samples
        progress_percent = (progress / audio_length_samples) * 100
        if progress_tracking_callback:
            progress_tracking_callback(progress_percent)

    triggered = False
    speeches = []
    current_speech = {}

    if neg_threshold is None:
        neg_threshold = max(threshold - 0.15, 0.01)
    temp_end = 0  # to save potential segment end (and tolerate some silence)
    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
    possible_ends = []

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            if temp_end != 0:
                sil_dur = (hop_size_samples * i) - temp_end
                if sil_dur > min_silence_samples_at_max_speech:
                    possible_ends.append((temp_end, sil_dur))
            temp_end = 0
            if next_start < prev_end:
                next_start = hop_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech['start'] = hop_size_samples * i
            continue

        if triggered and (hop_size_samples * i) - current_speech['start'] > max_speech_samples:
            if possible_ends:
                if use_max_poss_sil_at_max_speech:
                    prev_end, dur = max(possible_ends, key=lambda x: x[1])  # use the longest possible silence segment in the current speech chunk
                else:
                    prev_end, dur = possible_ends[-1]  # use the last possible silence segment
                current_speech['end'] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                next_start = prev_end + dur
                if next_start < prev_end + hop_size_samples * i:  # previously reached silence (< neg_thres) and is still not speech (< thres)
                    # triggered = False
                    current_speech['start'] = next_start
                else:
                    triggered = False
                    # current_speech['start'] = next_start
                prev_end = next_start = temp_end = 0
                possible_ends = []
            else:
                current_speech['end'] = hop_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                possible_ends = []
            continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = hop_size_samples * i
            # if ((hop_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
            #     prev_end = temp_end
            if (hop_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech['end'] = temp_end
                if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                possible_ends = []
                continue

    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
        current_speech['end'] = audio_length_samples
        speeches.append(current_speech)

    for i, speech in enumerate(speeches):
        if i == 0:
            speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i+1]['start'] - speech['end']
            if silence_duration < 2 * speech_pad_samples:
                speech['end'] += int(silence_duration // 2)
                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
            else:
                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
        else:
            speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))

    if return_seconds:
        audio_length_seconds = audio_length_samples / sampling_rate
        for speech_dict in speeches:
            speech_dict['start'] = max(round(speech_dict['start'] / sampling_rate, time_resolution), 0)
            speech_dict['end'] = min(round(speech_dict['end'] / sampling_rate, time_resolution), audio_length_seconds)
    elif step > 1:
        for speech_dict in speeches:
            speech_dict['start'] *= step
            speech_dict['end'] *= step

    if visualize_probs:
        make_visualization(speech_probs, hop_size_samples / sampling_rate)

    return speeches
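

# Illustrative end-to-end sketch, not part of the original file: detects speech
# in a file using the helpers above. Loading the model via init_jit_model with a
# 'silero_vad.jit' path is an assumption for this example; the packaged library
# normally hands you a preloaded model (e.g. via torch.hub or load_silero_vad).
def _example_get_speech_timestamps(wav_path: str, model_path: str = 'silero_vad.jit'):
    model = init_jit_model(model_path)
    wav = read_audio(wav_path, sampling_rate=16000)
    # returns a list like [{'start': 0.5, 'end': 2.3}, ...] because return_seconds=True
    return get_speech_timestamps(wav, model, threshold=0.5, sampling_rate=16000, return_seconds=True)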


class VADIterator:
    def __init__(self,
                 model,
                 threshold: float = 0.5,
                 sampling_rate: int = 16000,
                 min_silence_duration_ms: int = 100,
                 speech_pad_ms: int = 30
                 ):

        """
        Class for stream imitation

        Parameters
        ----------
        model: preloaded .jit/.onnx silero VAD model

        threshold: float (default - 0.5)
            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

        sampling_rate: int (default - 16000)
            Currently silero VAD models support 8000 and 16000 sample rates

        min_silence_duration_ms: int (default - 100 milliseconds)
            At the end of each speech chunk, wait for min_silence_duration_ms before separating it

        speech_pad_ms: int (default - 30 milliseconds)
            Final speech chunks are padded by speech_pad_ms on each side
        """

        self.model = model
        self.threshold = threshold
        self.sampling_rate = sampling_rate

        if sampling_rate not in [8000, 16000]:
            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')

        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
        self.reset_states()

    def reset_states(self):

        self.model.reset_states()
        self.triggered = False
        self.temp_end = 0
        self.current_sample = 0

    @torch.no_grad()
    def __call__(self, x, return_seconds=False, time_resolution: int = 1):
        """
        x: torch.Tensor
            audio chunk (see examples in repo)

        return_seconds: bool (default - False)
            whether to return timestamps in seconds (default - samples)

        time_resolution: int (default - 1)
            time resolution of speech coordinates when requested as seconds
        """

        if not torch.is_tensor(x):
            try:
                x = torch.Tensor(x)
            except Exception:
                raise TypeError("Audio cannot be cast to tensor. Cast it manually")

        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
        self.current_sample += window_size_samples

        speech_prob = self.model(x, self.sampling_rate).item()

        if (speech_prob >= self.threshold) and self.temp_end:
            self.temp_end = 0

        if (speech_prob >= self.threshold) and not self.triggered:
            self.triggered = True
            speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}

        if (speech_prob < self.threshold - 0.15) and self.triggered:
            if not self.temp_end:
                self.temp_end = self.current_sample
            if self.current_sample - self.temp_end < self.min_silence_samples:
                return None
            else:
                speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
                self.temp_end = 0
                self.triggered = False
                return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}

        return None
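

# Illustrative streaming sketch, not part of the original file: pushes fixed-size
# chunks through VADIterator, which yields {'start': ...} / {'end': ...} events as
# speech begins and ends. The model is assumed to be preloaded elsewhere; the
# 512/256 window sizes match what the models expect at 16 kHz / 8 kHz.
def _example_vad_iterator(wav: torch.Tensor, model, sampling_rate: int = 16000):
    vad_iterator = VADIterator(model, sampling_rate=sampling_rate)
    window_size_samples = 512 if sampling_rate == 16000 else 256
    events = []
    for start in range(0, len(wav), window_size_samples):
        chunk = wav[start: start + window_size_samples]
        if len(chunk) < window_size_samples:
            break  # this sketch simply drops the trailing partial chunk
        event = vad_iterator(chunk, return_seconds=True)
        if event is not None:
            events.append(event)
    vad_iterator.reset_states()  # reset model state before reusing on new audio
    return events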


def collect_chunks(tss: List[dict],
                   wav: torch.Tensor,
                   seconds: bool = False,
                   sampling_rate: int = None) -> torch.Tensor:
    """Collect audio chunks from a longer audio clip

    This method extracts audio chunks from an audio clip, using a list of
    provided coordinates, and concatenates them together. Coordinates can be
    passed either as sample numbers or in seconds, in which case the audio
    sampling rate is also needed.

    Parameters
    ----------
    tss: List[dict]
        Coordinate list of the clips to collect from the audio.
    wav: torch.Tensor, one dimensional
        One dimensional float torch.Tensor, containing the audio to clip.
    seconds: bool (default - False)
        Whether input coordinates are passed as seconds or samples.
    sampling_rate: int (default - None)
        Input audio sampling rate. Required if seconds is True.

    Returns
    -------
    torch.Tensor, one dimensional
        One dimensional float torch.Tensor of the concatenated clipped audio
        chunks.

    Raises
    ------
    ValueError
        Raised if sampling_rate is not provided when seconds is True.

    """
    if seconds and not sampling_rate:
        raise ValueError('sampling_rate must be provided when seconds is True')

    chunks = list()
    _tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss

    for i in _tss:
        chunks.append(wav[i['start']:i['end']])

    return torch.cat(chunks)
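

# Illustrative sketch, not part of the original file: keeps only the detected
# speech by combining get_speech_timestamps with collect_chunks. The model is
# assumed to be preloaded elsewhere; timestamps stay in samples, so seconds=False
# (the default) is used.
def _example_keep_speech_only(wav: torch.Tensor, model, sampling_rate: int = 16000):
    speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=sampling_rate)
    return collect_chunks(speech_timestamps, wav)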


def drop_chunks(tss: List[dict],
                wav: torch.Tensor,
                seconds: bool = False,
                sampling_rate: int = None) -> torch.Tensor:
    """Drop audio chunks from a longer audio clip

    This method extracts audio chunks from an audio clip, using a list of
    provided coordinates, and drops them. Coordinates can be passed either as
    sample numbers or in seconds, in which case the audio sampling rate is also
    needed.

    Parameters
    ----------
    tss: List[dict]
        Coordinate list of the clips to drop from the audio.
    wav: torch.Tensor, one dimensional
        One dimensional float torch.Tensor, containing the audio to clip.
    seconds: bool (default - False)
        Whether input coordinates are passed as seconds or samples.
    sampling_rate: int (default - None)
        Input audio sampling rate. Required if seconds is True.

    Returns
    -------
    torch.Tensor, one dimensional
        One dimensional float torch.Tensor of the input audio minus the dropped
        chunks.

    Raises
    ------
    ValueError
        Raised if sampling_rate is not provided when seconds is True.

    """
    if seconds and not sampling_rate:
        raise ValueError('sampling_rate must be provided when seconds is True')

    chunks = list()
    cur_start = 0

    _tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss

    for i in _tss:
        chunks.append(wav[cur_start: i['start']])
        cur_start = i['end']

    # keep the audio remaining after the last dropped chunk
    chunks.append(wav[cur_start:])

    return torch.cat(chunks)
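

# Illustrative sketch, not part of the original file: the inverse of the example
# above, removing detected speech and keeping only the non-speech remainder.
def _example_drop_speech(wav: torch.Tensor, model, sampling_rate: int = 16000):
    speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=sampling_rate)
    return drop_chunks(speech_timestamps, wav)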


def _seconds_to_samples_tss(tss: List[dict], sampling_rate: int) -> List[dict]:
    """Convert coordinates expressed in seconds to sample coordinates.
    """
    # round after scaling so fractional-second coordinates are not lost
    return [{
        'start': round(crd['start'] * sampling_rate),
        'end': round(crd['end'] * sampling_rate)
    } for crd in tss]
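

# Illustrative sketch, not part of the original file: shows the seconds-to-samples
# conversion performed by _seconds_to_samples_tss at a 16 kHz sampling rate.
def _example_seconds_to_samples():
    tss = [{'start': 1.0, 'end': 3.0}]
    # at 16000 Hz this becomes [{'start': 16000, 'end': 48000}]
    return _seconds_to_samples_tss(tss, sampling_rate=16000)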