GitHub Repository: snakers4/silero-vad
Path: blob/master/src/silero_vad/utils_vad.py
import torch
import torchaudio
from typing import Callable, List
import warnings
from packaging import version

languages = ['ru', 'en', 'de', 'es']


class OnnxWrapper():

    def __init__(self, path, force_onnx_cpu=False):
        import numpy as np
        global np
        import onnxruntime

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
        else:
            self.session = onnxruntime.InferenceSession(path, sess_options=opts)

        self.reset_states()
        if '16k' in path:
            warnings.warn('This model supports only a 16000 sampling rate!')
            self.sample_rates = [16000]
        else:
            self.sample_rates = [8000, 16000]

    def _validate_input(self, x, sr: int):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        if sr != 16000 and (sr % 16000 == 0):
            step = sr // 16000
            x = x[:, ::step]
            sr = 16000

        if sr not in self.sample_rates:
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or a multiple of 16000)")
        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        return x, sr

    def reset_states(self, batch_size=1):
        self._state = torch.zeros((2, batch_size, 128)).float()
        self._context = torch.zeros(0)
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):

        x, sr = self._validate_input(x, sr)
        num_samples = 512 if sr == 16000 else 256

        if x.shape[-1] != num_samples:
            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

        batch_size = x.shape[0]
        context_size = 64 if sr == 16000 else 32

        if not self._last_batch_size:
            self.reset_states(batch_size)
        if (self._last_sr) and (self._last_sr != sr):
            self.reset_states(batch_size)
        if (self._last_batch_size) and (self._last_batch_size != batch_size):
            self.reset_states(batch_size)

        if not len(self._context):
            self._context = torch.zeros(batch_size, context_size)

        x = torch.cat([self._context, x], dim=1)
        if sr in [8000, 16000]:
            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
            ort_outs = self.session.run(None, ort_inputs)
            out, state = ort_outs
            self._state = torch.from_numpy(state)
        else:
            raise ValueError()

        self._context = x[..., -context_size:]
        self._last_sr = sr
        self._last_batch_size = batch_size

        out = torch.from_numpy(out)
        return out

    def audio_forward(self, x, sr: int):
        outs = []
        x, sr = self._validate_input(x, sr)
        self.reset_states()
        num_samples = 512 if sr == 16000 else 256

        if x.shape[1] % num_samples:
            pad_num = num_samples - (x.shape[1] % num_samples)
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

        for i in range(0, x.shape[1], num_samples):
            wavs_batch = x[:, i:i+num_samples]
            out_chunk = self.__call__(wavs_batch, sr)
            outs.append(out_chunk)

        stacked = torch.cat(outs, dim=1)
        return stacked.cpu()
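

# Illustrative sketch, not part of the original API: how OnnxWrapper is driven chunk by
# chunk. The paths 'silero_vad.onnx' and 'example.wav' are placeholders for a locally
# available model and audio file. Each call must receive exactly 512 samples at 16 kHz
# (or 256 samples at 8 kHz); the wrapper keeps hidden state and a short context window
# between calls, which reset_states() clears.
def _example_onnx_streaming(model_path: str = 'silero_vad.onnx'):  # hypothetical helper
    vad = OnnxWrapper(model_path, force_onnx_cpu=True)
    audio = read_audio('example.wav', sampling_rate=16000)  # placeholder file name
    probs = []
    for start in range(0, len(audio) - 512 + 1, 512):
        chunk = audio[start:start + 512]
        probs.append(vad(chunk, 16000).item())  # speech probability for this chunk
    vad.reset_states()
    return probs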


class Validator():
    def __init__(self, url, force_onnx_cpu):
        self.onnx = url.endswith('.onnx')
        torch.hub.download_url_to_file(url, 'inf.model')
        if self.onnx:
            import onnxruntime
            if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
                self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
            else:
                self.model = onnxruntime.InferenceSession('inf.model')
        else:
            self.model = init_jit_model(model_path='inf.model')

    def __call__(self, inputs: torch.Tensor):
        with torch.no_grad():
            if self.onnx:
                ort_inputs = {'input': inputs.cpu().numpy()}
                outs = self.model.run(None, ort_inputs)
                outs = [torch.Tensor(x) for x in outs]
            else:
                outs = self.model(inputs)

        return outs


def read_audio(path: str, sampling_rate: int = 16000) -> torch.Tensor:
    ta_ver = version.parse(torchaudio.__version__)
    if ta_ver < version.parse("2.9"):
        try:
            effects = [['channels', '1'], ['rate', str(sampling_rate)]]
            wav, sr = torchaudio.sox_effects.apply_effects_file(path, effects=effects)
        except Exception:
            wav, sr = torchaudio.load(path)
    else:
        try:
            wav, sr = torchaudio.load(path)
        except Exception:
            try:
                from torchcodec.decoders import AudioDecoder
                samples = AudioDecoder(path).get_all_samples()
                wav = samples.data
                sr = samples.sample_rate
            except ImportError:
                raise RuntimeError(
                    f"torchaudio version {torchaudio.__version__} requires torchcodec for audio I/O. "
                    "Install torchcodec or pin torchaudio < 2.9"
                )

    if wav.ndim > 1 and wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sr != sampling_rate:
        wav = torchaudio.transforms.Resample(sr, sampling_rate)(wav)

    return wav.squeeze(0)


def save_audio(path: str, tensor: torch.Tensor, sampling_rate: int = 16000):
    tensor = tensor.detach().cpu()
    if tensor.ndim == 1:
        tensor = tensor.unsqueeze(0)

    ta_ver = version.parse(torchaudio.__version__)

    try:
        torchaudio.save(path, tensor, sampling_rate, bits_per_sample=16)
    except Exception:
        if ta_ver >= version.parse("2.9"):
            try:
                from torchcodec.encoders import AudioEncoder
                encoder = AudioEncoder(tensor, sample_rate=sampling_rate)
                encoder.to_file(path)
            except ImportError:
                raise RuntimeError(
                    f"torchaudio version {torchaudio.__version__} requires torchcodec for saving. "
                    "Install torchcodec or pin torchaudio < 2.9"
                )
        else:
            raise
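

# Illustrative sketch, not part of the original API: a read/resample/save round trip
# built on the two helpers above. 'input.wav' and 'mono_16k.wav' are placeholder file
# names; read_audio downmixes to mono and resamples to the requested rate, and
# save_audio writes a 16-bit file at that same rate.
def _example_audio_roundtrip(src: str = 'input.wav', dst: str = 'mono_16k.wav'):  # hypothetical helper
    wav = read_audio(src, sampling_rate=16000)   # 1-D float tensor at 16 kHz
    save_audio(dst, wav, sampling_rate=16000)
    return wav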


def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model


def make_visualization(probs, step):
    import pandas as pd
    pd.DataFrame({'probs': probs},
                 index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
                                                                   kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
                                                                   xlabel='seconds',
                                                                   ylabel='speech probability',
                                                                   colormap='tab20')


@torch.no_grad()
def get_speech_timestamps(audio: torch.Tensor,
                          model,
                          threshold: float = 0.5,
                          sampling_rate: int = 16000,
                          min_speech_duration_ms: int = 250,
                          max_speech_duration_s: float = float('inf'),
                          min_silence_duration_ms: int = 100,
                          speech_pad_ms: int = 30,
                          return_seconds: bool = False,
                          time_resolution: int = 1,
                          visualize_probs: bool = False,
                          progress_tracking_callback: Callable[[float], None] = None,
                          neg_threshold: float = None,
                          window_size_samples: int = 512,
                          min_silence_at_max_speech: int = 98,
                          use_max_poss_sil_at_max_speech: bool = True):

    """
    This method is used for splitting long audios into speech chunks using silero VAD

    Parameters
    ----------
    audio: torch.Tensor, one dimensional
        One dimensional float torch.Tensor; other types are cast to torch.Tensor if possible

    model: preloaded .jit/.onnx silero VAD model

    threshold: float (default - 0.5)
        Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered SPEECH.
        It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

    sampling_rate: int (default - 16000)
        Currently silero VAD models support 8000 and 16000 (or a multiple of 16000) sample rates

    min_speech_duration_ms: int (default - 250 milliseconds)
        Final speech chunks shorter than min_speech_duration_ms are thrown out

    max_speech_duration_s: float (default - inf)
        Maximum duration of speech chunks in seconds.
        Chunks longer than max_speech_duration_s are split at the timestamp of the last silence that lasts more than 100 ms (if any), to prevent aggressive cutting.
        Otherwise, they are split aggressively just before max_speech_duration_s.

    min_silence_duration_ms: int (default - 100 milliseconds)
        At the end of each speech chunk, wait for min_silence_duration_ms before separating it

    speech_pad_ms: int (default - 30 milliseconds)
        Final speech chunks are padded by speech_pad_ms on each side

    return_seconds: bool (default - False)
        whether to return timestamps in seconds (default - samples)

    time_resolution: int (default - 1)
        time resolution of speech coordinates when requested as seconds

    visualize_probs: bool (default - False)
        whether to plot the speech probabilities or not

    progress_tracking_callback: Callable[[float], None] (default - None)
        callback function taking progress in percent as an argument

    neg_threshold: float (default = threshold - 0.15)
        Negative threshold (noise or exit threshold). If the model's current state is SPEECH, values BELOW this value are considered NON-SPEECH.

    min_silence_at_max_speech: int (default - 98 ms)
        Minimum silence duration in ms used to avoid abrupt cuts when max_speech_duration_s is reached

    use_max_poss_sil_at_max_speech: bool (default - True)
        Whether to use the maximum possible silence at max_speech_duration_s. If False, the last silence is used instead.

    window_size_samples: int (default - 512 samples)
        !!! DEPRECATED, DOES NOTHING !!!

    Returns
    ----------
    speeches: list of dicts
        list containing beginnings and ends of speech chunks (in samples or seconds, depending on return_seconds)
    """
    if not torch.is_tensor(audio):
        try:
            audio = torch.Tensor(audio)
        except Exception:
            raise TypeError("Audio cannot be cast to tensor. Cast it manually")

    if len(audio.shape) > 1:
        for i in range(len(audio.shape)):  # trying to squeeze empty dimensions
            audio = audio.squeeze(0)
        if len(audio.shape) > 1:
            raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")

    if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
        step = sampling_rate // 16000
        sampling_rate = 16000
        audio = audio[::step]
        warnings.warn('Sampling rate is a multiple of 16000, casting to 16000 manually!')
    else:
        step = 1

    if sampling_rate not in [8000, 16000]:
        raise ValueError("Currently silero VAD models support 8000 and 16000 (or a multiple of 16000) sample rates")

    window_size_samples = 512 if sampling_rate == 16000 else 256

    model.reset_states()
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * min_silence_at_max_speech / 1000

    audio_length_samples = len(audio)

    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample: current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob = model(chunk, sampling_rate).item()
        speech_probs.append(speech_prob)
        # calculate progress and send it to callback function
        progress = current_start_sample + window_size_samples
        if progress > audio_length_samples:
            progress = audio_length_samples
        progress_percent = (progress / audio_length_samples) * 100
        if progress_tracking_callback:
            progress_tracking_callback(progress_percent)

    triggered = False
    speeches = []
    current_speech = {}

    if neg_threshold is None:
        neg_threshold = max(threshold - 0.15, 0.01)
    temp_end = 0  # to save potential segment end (and tolerate some silence)
    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
    possible_ends = []

    for i, speech_prob in enumerate(speech_probs):
        cur_sample = window_size_samples * i

        # If speech returns after a temp_end, record candidate silence if long enough and clear temp_end
        if (speech_prob >= threshold) and temp_end:
            sil_dur = cur_sample - temp_end
            if sil_dur > min_silence_samples_at_max_speech:
                possible_ends.append((temp_end, sil_dur))
            temp_end = 0
            if next_start < prev_end:
                next_start = cur_sample

        # Start of speech
        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech['start'] = cur_sample
            continue

        # Max speech length reached: decide where to cut
        if triggered and (cur_sample - current_speech['start'] > max_speech_samples):
            if use_max_poss_sil_at_max_speech and possible_ends:
                prev_end, dur = max(possible_ends, key=lambda x: x[1])  # use the longest possible silence segment in the current speech chunk
                current_speech['end'] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                next_start = prev_end + dur

                if next_start < prev_end + cur_sample:  # previously reached silence (< neg_thres) and is still not speech (< thres)
                    current_speech['start'] = next_start
                else:
                    triggered = False
                prev_end = next_start = temp_end = 0
                possible_ends = []
            else:
                # Legacy max-speech cut (use_max_poss_sil_at_max_speech=False): prefer last valid silence (prev_end) if available
                if prev_end:
                    current_speech['end'] = prev_end
                    speeches.append(current_speech)
                    current_speech = {}
                    if next_start < prev_end:
                        triggered = False
                    else:
                        current_speech['start'] = next_start
                    prev_end = next_start = temp_end = 0
                    possible_ends = []
                else:
                    # No prev_end -> fall back to cutting at the current sample
                    current_speech['end'] = cur_sample
                    speeches.append(current_speech)
                    current_speech = {}
                    prev_end = next_start = temp_end = 0
                    triggered = False
                    possible_ends = []
            continue

        # Silence detection while in speech
        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = cur_sample
            sil_dur_now = cur_sample - temp_end

            if not use_max_poss_sil_at_max_speech and sil_dur_now > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
                prev_end = temp_end

            if sil_dur_now < min_silence_samples:
                continue
            else:
                current_speech['end'] = temp_end
                if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                possible_ends = []
                continue

    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
        current_speech['end'] = audio_length_samples
        speeches.append(current_speech)

    for i, speech in enumerate(speeches):
        if i == 0:
            speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]['start'] - speech['end']
            if silence_duration < 2 * speech_pad_samples:
                speech['end'] += int(silence_duration // 2)
                speeches[i + 1]['start'] = int(max(0, speeches[i + 1]['start'] - silence_duration // 2))
            else:
                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
                speeches[i + 1]['start'] = int(max(0, speeches[i + 1]['start'] - speech_pad_samples))
        else:
            speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))

    if return_seconds:
        audio_length_seconds = audio_length_samples / sampling_rate
        for speech_dict in speeches:
            speech_dict['start'] = max(round(speech_dict['start'] / sampling_rate, time_resolution), 0)
            speech_dict['end'] = min(round(speech_dict['end'] / sampling_rate, time_resolution), audio_length_seconds)
    elif step > 1:
        for speech_dict in speeches:
            speech_dict['start'] *= step
            speech_dict['end'] *= step

    if visualize_probs:
        make_visualization(speech_probs, window_size_samples / sampling_rate)

    return speeches
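

# Illustrative sketch, not part of the original API: typical offline use of
# get_speech_timestamps. 'silero_vad.onnx' and 'speech.wav' are placeholders for a
# locally available model and audio file; any preloaded .jit/.onnx silero VAD model
# works the same way. Timestamps come back as sample indices unless
# return_seconds=True is passed.
def _example_get_speech_timestamps(model_path: str = 'silero_vad.onnx',
                                   audio_path: str = 'speech.wav'):  # hypothetical helper
    model = OnnxWrapper(model_path, force_onnx_cpu=True)
    wav = read_audio(audio_path, sampling_rate=16000)
    timestamps = get_speech_timestamps(wav,
                                       model,
                                       threshold=0.5,
                                       sampling_rate=16000,
                                       return_seconds=True)
    # e.g. [{'start': 0.5, 'end': 2.3}, ...]
    return timestamps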


class VADIterator:
    def __init__(self,
                 model,
                 threshold: float = 0.5,
                 sampling_rate: int = 16000,
                 min_silence_duration_ms: int = 100,
                 speech_pad_ms: int = 30
                 ):

        """
        Class for stream imitation

        Parameters
        ----------
        model: preloaded .jit/.onnx silero VAD model

        threshold: float (default - 0.5)
            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered SPEECH.
            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

        sampling_rate: int (default - 16000)
            Currently silero VAD models support 8000 and 16000 sample rates

        min_silence_duration_ms: int (default - 100 milliseconds)
            At the end of each speech chunk, wait for min_silence_duration_ms before separating it

        speech_pad_ms: int (default - 30 milliseconds)
            Final speech chunks are padded by speech_pad_ms on each side
        """

        self.model = model
        self.threshold = threshold
        self.sampling_rate = sampling_rate

        if sampling_rate not in [8000, 16000]:
            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')

        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
        self.reset_states()

    def reset_states(self):

        self.model.reset_states()
        self.triggered = False
        self.temp_end = 0
        self.current_sample = 0

    @torch.no_grad()
    def __call__(self, x, return_seconds=False, time_resolution: int = 1):
        """
        x: torch.Tensor
            audio chunk (see examples in repo)

        return_seconds: bool (default - False)
            whether to return timestamps in seconds (default - samples)

        time_resolution: int (default - 1)
            time resolution of speech coordinates when requested as seconds
        """

        if not torch.is_tensor(x):
            try:
                x = torch.Tensor(x)
            except Exception:
                raise TypeError("Audio cannot be cast to tensor. Cast it manually")

        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
        self.current_sample += window_size_samples

        speech_prob = self.model(x, self.sampling_rate).item()

        if (speech_prob >= self.threshold) and self.temp_end:
            self.temp_end = 0

        if (speech_prob >= self.threshold) and not self.triggered:
            self.triggered = True
            speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}

        if (speech_prob < self.threshold - 0.15) and self.triggered:
            if not self.temp_end:
                self.temp_end = self.current_sample
            if self.current_sample - self.temp_end < self.min_silence_samples:
                return None
            else:
                speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
                self.temp_end = 0
                self.triggered = False
                return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}

        return None
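

# Illustrative sketch, not part of the original API: feeding a long recording to
# VADIterator in fixed-size windows to imitate a live stream. 'silero_vad.onnx' and
# 'speech.wav' are placeholders; at 16 kHz the iterator expects 512-sample chunks
# (256 at 8 kHz) and returns a dict with 'start' or 'end' only when a boundary is
# detected, otherwise None.
def _example_vad_iterator(model_path: str = 'silero_vad.onnx',
                          audio_path: str = 'speech.wav'):  # hypothetical helper
    model = OnnxWrapper(model_path, force_onnx_cpu=True)
    vad_iterator = VADIterator(model, sampling_rate=16000)
    wav = read_audio(audio_path, sampling_rate=16000)
    events = []
    for start in range(0, len(wav) - 512 + 1, 512):
        event = vad_iterator(wav[start:start + 512], return_seconds=True)
        if event is not None:
            events.append(event)  # {'start': ...} or {'end': ...}
    vad_iterator.reset_states()
    return events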


def collect_chunks(tss: List[dict],
                   wav: torch.Tensor,
                   seconds: bool = False,
                   sampling_rate: int = None) -> torch.Tensor:
    """Collect audio chunks from a longer audio clip

    This method extracts audio chunks from an audio clip, using a list of
    provided coordinates, and concatenates them together. Coordinates can be
    passed either as sample numbers or in seconds, in which case the audio
    sampling rate is also needed.

    Parameters
    ----------
    tss: List[dict]
        Coordinate list of the clips to collect from the audio.
    wav: torch.Tensor, one dimensional
        One dimensional float torch.Tensor, containing the audio to clip.
    seconds: bool (default - False)
        Whether input coordinates are passed as seconds or samples.
    sampling_rate: int (default - None)
        Input audio sampling rate. Required if seconds is True.

    Returns
    -------
    torch.Tensor, one dimensional
        One dimensional float torch.Tensor of the concatenated clipped audio
        chunks.

    Raises
    ------
    ValueError
        Raised if sampling_rate is not provided when seconds is True.

    """
    if seconds and not sampling_rate:
        raise ValueError('sampling_rate must be provided when seconds is True')

    chunks = list()
    _tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss

    for i in _tss:
        chunks.append(wav[i['start']:i['end']])

    return torch.cat(chunks)
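

# Illustrative sketch, not part of the original API: keeping only the detected speech.
# Builds on the same placeholder model and audio paths as the examples above. Because
# the timestamps are requested in seconds, sampling_rate must be passed alongside
# seconds=True; the output path 'speech_only.wav' is also a placeholder.
def _example_collect_speech(model_path: str = 'silero_vad.onnx',
                            audio_path: str = 'speech.wav'):  # hypothetical helper
    model = OnnxWrapper(model_path, force_onnx_cpu=True)
    wav = read_audio(audio_path, sampling_rate=16000)
    tss = get_speech_timestamps(wav, model, sampling_rate=16000, return_seconds=True)
    speech_only = collect_chunks(tss, wav, seconds=True, sampling_rate=16000)
    save_audio('speech_only.wav', speech_only, sampling_rate=16000)
    return speech_only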


def drop_chunks(tss: List[dict],
                wav: torch.Tensor,
                seconds: bool = False,
                sampling_rate: int = None) -> torch.Tensor:
    """Drop audio chunks from a longer audio clip

    This method extracts audio chunks from an audio clip, using a list of
    provided coordinates, and drops them. Coordinates can be passed either as
    sample numbers or in seconds, in which case the audio sampling rate is also
    needed.

    Parameters
    ----------
    tss: List[dict]
        Coordinate list of the clips to drop from the audio.
    wav: torch.Tensor, one dimensional
        One dimensional float torch.Tensor, containing the audio to clip.
    seconds: bool (default - False)
        Whether input coordinates are passed as seconds or samples.
    sampling_rate: int (default - None)
        Input audio sampling rate. Required if seconds is True.

    Returns
    -------
    torch.Tensor, one dimensional
        One dimensional float torch.Tensor of the input audio minus the dropped
        chunks.

    Raises
    ------
    ValueError
        Raised if sampling_rate is not provided when seconds is True.

    """
    if seconds and not sampling_rate:
        raise ValueError('sampling_rate must be provided when seconds is True')

    chunks = list()
    cur_start = 0

    _tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss

    for i in _tss:
        chunks.append(wav[cur_start: i['start']])
        cur_start = i['end']

    chunks.append(wav[cur_start:])

    return torch.cat(chunks)
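

# Illustrative sketch, not part of the original API: the complement of the example
# above, removing detected speech and keeping everything else (noise, music, silence).
# Same placeholder model and audio paths; here the timestamps stay in samples, so
# seconds/sampling_rate are not needed.
def _example_drop_speech(model_path: str = 'silero_vad.onnx',
                         audio_path: str = 'speech.wav'):  # hypothetical helper
    model = OnnxWrapper(model_path, force_onnx_cpu=True)
    wav = read_audio(audio_path, sampling_rate=16000)
    tss = get_speech_timestamps(wav, model, sampling_rate=16000)  # sample coordinates
    return drop_chunks(tss, wav)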


def _seconds_to_samples_tss(tss: List[dict], sampling_rate: int) -> List[dict]:
    """Convert coordinates expressed in seconds to sample coordinates.
    """
    return [{
        'start': round(crd['start'] * sampling_rate),
        'end': round(crd['end'] * sampling_rate)
    } for crd in tss]