Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
snakers4
GitHub Repository: snakers4/silero-vad
Path: blob/master/tuning/tune.py
1171 views
1
from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
2
from omegaconf import OmegaConf
3
import torch.nn as nn
4
import torch
5
6
7
if __name__ == '__main__':
    # Fine-tuning entry point: loads settings from config.yml, builds the
    # train/val loaders, loads a Silero VAD model, optimizes a copy of its
    # decoder and saves the model whenever validation ROC-AUC improves.
    config = OmegaConf.load('config.yml')

    # Both loaders use the same batch size, collate function and worker count.
    train_dataset = SileroVadDataset(config, mode='train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.batch_size,
                                               collate_fn=SileroVadPadder,
                                               num_workers=config.num_workers)

    val_dataset = SileroVadDataset(config, mode='val')
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.batch_size,
                                             collate_fn=SileroVadPadder,
                                             num_workers=config.num_workers)

    # Model source, in order of precedence: a local JIT file, torch.hub,
    # or the installed silero-vad package.
    if config.jit_model_path:
        print(f'Loading model from the local folder: {config.jit_model_path}')
        model = init_jit_model(config.jit_model_path, device=config.device)
    elif config.use_torchhub:
        print('Loading model using torch.hub')
        # torch.hub.load returns a (model, utils) pair here; only the model
        # is needed.
        model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                  model='silero_vad',
                                  onnx=False,
                                  force_reload=True)
    else:
        print('Loading model using silero-vad library')
        # Imported lazily so the package is only required on this path.
        from silero_vad import load_silero_vad
        model = load_silero_vad(onnx=False)

    print('Model loaded')
    model.to(config.device)

    # Only a standalone copy of the decoder is optimized. It is initialised
    # from `_model_8k` when config.tune_8k is set, otherwise from `_model`.
    decoder = VADDecoderRNNJIT().to(config.device)
    source_decoder = model._model_8k.decoder if config.tune_8k else model._model.decoder
    decoder.load_state_dict(source_decoder.state_dict())
    decoder.train()

    params = decoder.parameters()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params),
                                 lr=config.learning_rate)
    # reduction='none' keeps per-element losses; presumably train/validate
    # reduce or mask them themselves -- confirm in utils.
    criterion = nn.BCELoss(reduction='none')

    best_val_roc = 0
    for i in range(config.num_epochs):
        print(f'Starting epoch {i + 1}')
        train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device)
        val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device)
        # A single concatenated string: a comma between the pieces would make
        # print() insert its default ' ' separator in the middle of the report.
        print(f'Metrics after epoch {i + 1}:\n'
              f'\tTrain loss: {round(train_loss, 3)}\n'
              f'\tValidation loss: {round(val_loss, 3)}\n'
              f'\tValidation ROC-AUC: {round(val_roc, 3)}')

        if val_roc > best_val_roc:
            print('New best ROC-AUC, saving model')
            best_val_roc = val_roc
            # Write the tuned decoder weights back into the matching
            # sub-model before serializing the whole JIT model.
            if config.tune_8k:
                model._model_8k.decoder.load_state_dict(decoder.state_dict())
            else:
                model._model.decoder.load_state_dict(decoder.state_dict())
            torch.jit.save(model, config.model_save_path)
    print('Done')
66
67