Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
snakers4
GitHub Repository: snakers4/silero-vad
Path: blob/master/tuning/search_thresholds.py
1171 views
1
from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
from omegaconf import OmegaConf
import torch

# Restrict torch's CPU intra-op parallelism to a single thread for this process.
torch.set_num_threads(1)


def _load_model(cfg):
    """Return the VAD model selected by the tuning config.

    Selection order:
      1. ``cfg.jit_model_path`` is set -> load that local JIT checkpoint via
         ``init_jit_model`` (``cfg.device`` is passed through to it);
      2. ``cfg.use_torchhub`` is truthy -> fetch ``silero_vad`` through
         ``torch.hub`` from the ``snakers4/silero-vad`` repo;
      3. otherwise -> load it with the ``silero_vad`` pip package.

    Each branch prints which source is being used. The caller is responsible
    for moving the returned model to the target device.
    """
    if cfg.jit_model_path:
        print(f'Loading model from the local folder: {cfg.jit_model_path}')
        return init_jit_model(cfg.jit_model_path, device=cfg.device)

    if not cfg.use_torchhub:
        print('Loading model using silero-vad library')
        # Imported lazily so the package is only required when this branch runs.
        from silero_vad import load_silero_vad
        return load_silero_vad(onnx=False)

    print('Loading model using torch.hub')
    # torch.hub.load yields a (model, utils) pair here; only the model is kept.
    hub_model, _unused_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                              model='silero_vad',
                                              onnx=False,
                                              force_reload=True)
    return hub_model


if __name__ == '__main__':
    # Settings are read from config.yml in the current working directory.
    config = OmegaConf.load('config.yml')

    # Validation split, batched with the project's padding collate function.
    val_dataset = SileroVadDataset(config, mode='val')
    loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        collate_fn=SileroVadPadder,
        num_workers=config.num_workers,
    )

    model = _load_model(config)
    print('Model loaded')
    model.to(config.device)

    print('Making predicts...')
    # tune_8k selects the sample rate handed to predict: 8 kHz if set, else 16 kHz.
    all_predicts, all_gts = predict(model, loader, config.device,
                                    sr=8000 if config.tune_8k else 16000)

    print('Calculating thresholds...')
    enter_threshold, exit_threshold, accuracy = calculate_best_thresholds(all_predicts, all_gts)
    print(f'Best threshold: {enter_threshold}\nBest exit threshold: {exit_threshold}\nBest accuracy: {accuracy}')