from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import Dataset
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import torchaudio
import warnings
import random
import torch
import gc
# Install a blanket 'ignore' filter: every warning raised after this point is suppressed.
warnings.filterwarnings('ignore')
def read_audio(path: str,
               sampling_rate: int = 16000,
               normalize=False):
    """Load an audio file as a single-channel waveform tensor.

    Args:
        path: Location of the audio file, handed to ``torchaudio.load``.
        sampling_rate: Target sample rate. When truthy and different from
            the rate reported by the file, the audio is resampled to it; a
            falsy value keeps the file's own rate.
        normalize: When true, divide the waveform by its peak absolute
            value. Skipped when the peak is zero to avoid dividing by zero.

    Returns:
        The waveform after ``squeeze(0)``, i.e. without the leading channel
        dimension once the channels have been averaged together.
    """
    waveform, native_sr = torchaudio.load(path)

    # Down-mix multi-channel audio by averaging across the channel axis.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sampling_rate and native_sr != sampling_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=native_sr,
                                                   new_freq=sampling_rate)
        waveform = resampler(waveform)

    if normalize:
        peak = waveform.abs().max()
        if peak != 0:
            waveform = waveform / peak

    return waveform.squeeze(0)
def build_audiomentations_augs(p):
    """Assemble the audiomentations pipeline used to augment training audio.

    Args:
        p: Probability passed to the enclosing ``SomeOf`` transform.

    Returns:
        A ``SomeOf`` transform built with ``(1, 3)`` as its count argument
        over the effects listed below; each effect is created with ``p=1``.
    """
    # Imported here rather than at module level, so audiomentations is only
    # needed by code paths that actually build augmentations.
    from audiomentations import (
        AddGaussianNoise,
        AirAbsorption,
        Aliasing,
        BandPassFilter,
        BandStopFilter,
        ClippingDistortion,
        HighPassFilter,
        HighShelfFilter,
        LowPassFilter,
        LowShelfFilter,
        Mp3Compression,
        PeakingFilter,
        PitchShift,
        RoomSimulator,
        SevenBandParametricEQ,
        SomeOf,
    )

    effects = [
        Aliasing(p=1),
        AddGaussianNoise(p=1),
        AirAbsorption(p=1),
        BandPassFilter(p=1),
        BandStopFilter(p=1),
        ClippingDistortion(p=1),
        HighPassFilter(p=1),
        HighShelfFilter(p=1),
        LowPassFilter(p=1),
        LowShelfFilter(p=1),
        Mp3Compression(p=1),
        PeakingFilter(p=1),
        PitchShift(p=1),
        RoomSimulator(p=1, leave_length_unchanged=True),
        SevenBandParametricEQ(p=1),
    ]
    return SomeOf((1, 3), transforms=effects, p=p)
class SileroVadDataset(Dataset):
    """Audio dataset yielding waveforms with per-frame speech labels.

    The dataset index is a feather file (``config.train_dataset_path`` or
    ``config.val_dataset_path``) whose rows provide ``audio_path`` and
    ``speech_ts``. ``speech_ts`` is iterated as a sequence of mappings with
    ``start`` / ``end`` keys that are multiplied by the sample rate
    (presumably seconds -- confirm against the data preparation code).

    Each item is ``(wav, labels, mask)``:
      * ``wav``    -- float waveform, zero-padded to a multiple of
                      ``num_samples`` (512) before optional 8 kHz resampling;
      * ``labels`` -- one 0/1 value per 512-sample frame;
      * ``mask``   -- per-frame loss weight: 1 for speech frames,
                      ``config.noise_loss`` for the others.

    In ``'train'`` mode the requested index is ignored and a random row is
    drawn, augmentations are applied and the sample is cropped to
    ``config.max_train_length_sec``.
    """

    # Upper bound on augmentation retries before the clean waveform is used.
    max_aug_attempts = 10

    def __init__(self,
                 config,
                 mode='train'):
        self.num_samples = 512
        self.sr = 16000
        self.resample_to_8k = config.tune_8k
        self.noise_loss = config.noise_loss
        self.max_train_length_sec = config.max_train_length_sec

        max_length = config.max_train_length_sec * self.sr
        assert max_length % self.num_samples == 0, \
            'max_train_length_sec must cover a whole number of frames'
        # Stored as int so it can be used as a slice bound even when the
        # configured duration is a float such as 8.0.
        self.max_train_length_samples = int(max_length)

        assert mode in ['train', 'val'], "mode must be 'train' or 'val'"
        dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
        self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
        self.index_dict = self.dataframe.to_dict('index')
        self.mode = mode
        print(f'DATASET SIZE : {len(self.dataframe)}')

        self.augs = build_audiomentations_augs(p=config.aug_prob) if mode == 'train' else None

        # Built once here instead of on every __getitem__ call.
        self._resampler_8k = (
            torchaudio.transforms.Resample(orig_freq=self.sr, new_freq=8000)
            if self.resample_to_8k else None
        )

    def __getitem__(self, idx):
        """Return ``(wav, labels, mask)`` tensors for one sample."""
        # Training draws a random row regardless of the requested index.
        idx = None if self.mode == 'train' else idx
        wav, gt, mask = self.load_speech_sample(idx)

        if self.mode == 'train':
            wav = self.add_augs(wav)
            if len(wav) > self.max_train_length_samples:
                max_frames = self.max_train_length_samples // self.num_samples
                wav = wav[:self.max_train_length_samples]
                gt = gt[:max_frames]
                mask = mask[:max_frames]

        wav = torch.FloatTensor(wav)
        if self.resample_to_8k:
            wav = self._resampler_8k(wav)
        return wav, torch.FloatTensor(gt), torch.from_numpy(mask)

    def __len__(self):
        return len(self.index_dict)

    def load_speech_sample(self, idx=None):
        """Load one row's audio, pad it to whole frames and build its labels.

        Args:
            idx: Row index; ``None`` selects a row uniformly at random.

        Returns:
            ``(wav, labels, mask)`` as numpy arrays.
        """
        if idx is None:
            idx = random.randint(0, len(self.index_dict) - 1)
        entry = self.index_dict[idx]
        wav = read_audio(entry['audio_path'], self.sr).numpy()

        # Zero-pad the tail so the waveform splits into whole frames.
        remainder = len(wav) % self.num_samples
        if remainder != 0:
            wav = np.pad(wav, (0, self.num_samples - remainder), 'constant', constant_values=0)

        gt, mask = self.get_ground_truth_annotated(entry['speech_ts'], len(wav))
        assert len(gt) * self.num_samples == len(wav)
        return wav, gt, mask

    def get_ground_truth_annotated(self, annotation, audio_length_samples):
        """Convert speech segments into per-frame labels and loss weights.

        Args:
            annotation: Iterable of mappings with ``start`` and ``end``
                values, each multiplied by ``self.sr`` to get sample offsets.
            audio_length_samples: Waveform length; must be a multiple of
                ``self.num_samples`` for the frame reshape to succeed.

        Returns:
            ``(labels, mask)``: ``labels`` is 1 for frames where more than
            half of the samples lie inside a segment, else 0; ``mask`` is 1
            for speech frames and ``self.noise_loss`` for the rest.
        """
        sample_labels = np.zeros(audio_length_samples)
        for segment in annotation:
            sample_labels[int(segment['start'] * self.sr): int(segment['end'] * self.sr)] = 1

        speech_ratio = np.average(sample_labels.reshape(-1, self.num_samples), axis=1)
        frame_labels = (speech_ratio > 0.5).astype(int)

        mask = np.ones(len(frame_labels))
        mask[frame_labels == 0] = self.noise_loss
        return frame_labels, mask

    def add_augs(self, wav):
        """Apply the augmentation pipeline, falling back to the clean audio.

        The pipeline is retried when it raises, but only up to
        ``max_aug_attempts`` times, so an input that always fails cannot
        stall the data loader. The untouched ``wav`` is returned when the
        augmented signal contains NaNs or every attempt failed.
        """
        for _ in range(self.max_aug_attempts):
            try:
                wav_aug = self.augs(wav, self.sr)
                if np.isnan(wav_aug).any():
                    return wav
                return wav_aug
            except Exception:
                # Best effort: an effect failing on this input should not
                # abort training, so try the pipeline again.
                continue
        return wav
def SileroVadPadder(batch):
    """Collate ``(wav, labels, mask)`` samples into zero-padded batches.

    Args:
        batch: Sequence of samples, each indexable at positions 0, 1 and 2
            with tensors of possibly different lengths.

    Returns:
        ``(wavs, labels, masks)``, each stacked batch-first and right-padded
        with zeros to the longest entry in the batch.
    """
    def _collect(position):
        # Gather one field from every sample in the batch.
        return [sample[position] for sample in batch]

    wav_list, label_list, mask_list = _collect(0), _collect(1), _collect(2)

    pad = torch.nn.utils.rnn.pad_sequence
    wavs = pad(wav_list, batch_first=True, padding_value=0)
    labels = pad(label_list, batch_first=True, padding_value=0)
    masks = pad(mask_list, batch_first=True, padding_value=0)
    return wavs, labels, masks
class VADDecoderRNNJIT(nn.Module):
    """Single-step recurrent decoder: LSTM cell followed by a sigmoid head.

    The head is Dropout -> ReLU -> 1x1 Conv1d (128 -> 1 channel) -> Sigmoid,
    so the output lies in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTMCell(128, 128)
        self.decoder = nn.Sequential(
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Conv1d(128, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x, state=torch.zeros(0)):
        """Run one decoding step.

        Args:
            x: Input whose last dimension is squeezed away before the LSTM
                cell (assumed ``(batch, 128, 1)`` -- TODO confirm with the
                encoder output).
            state: Stacked ``[h, c]`` from the previous step. An empty
                tensor (the default) lets the cell start from its default
                initial state.

        Returns:
            ``(out, state)``: the head's output for this step and the new
            stacked ``[h, c]`` to feed into the next call.
        """
        features = x.squeeze(-1)
        if len(state) > 0:
            hidden, cell = self.rnn(features, (state[0], state[1]))
        else:
            hidden, cell = self.rnn(features)

        next_state = torch.stack([hidden, cell])
        # Conv1d needs a trailing length axis, so restore it.
        out = self.decoder(hidden.unsqueeze(-1).float())
        return out, next_state
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes are updated by ``update``:
      * ``val``   -- the most recent value;
      * ``sum``   -- running sum of ``val * n``;
      * ``count`` -- running sum of ``n``;
      * ``avg``   -- ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` (e.g. a batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch) must not divide by zero.
        self.avg = self.sum / self.count if self.count else 0
def train(config,
          loader,
          jit_model,
          decoder,
          criterion,
          optimizer,
          device):
    """Run one training pass over ``loader`` and return the average loss.

    Each batch is cut into fixed-size chunks (with ``context_size`` samples
    of left context), passed through the STFT and encoder layers taken from
    ``jit_model`` and then through ``decoder`` one chunk at a time, carrying
    the recurrent state across chunks.

    Args:
        config: Object with a ``tune_8k`` flag selecting the 8 kHz layers
            and chunk sizes (256 / 32) instead of the 16 kHz ones (512 / 64).
        loader: Iterable yielding ``(x, targets, masks)`` batches.
        jit_model: Model exposing ``_model`` / ``_model_8k`` sub-modules
            with ``stft`` and ``encoder`` attributes.
        decoder: Module called as ``decoder(features, state)`` returning
            ``(out, state)``; put into train mode here.
        criterion: Loss whose output is multiplied by ``masks`` before the
            mean (presumably element-wise, i.e. ``reduction='none'`` --
            confirm against the caller).
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        The loss averaged over the pass, weighted by ``masks.numel()``.
    """
    losses = AverageMeter()
    decoder.train()

    context_size = 32 if config.tune_8k else 64
    num_samples = 256 if config.tune_8k else 512
    base_model = jit_model._model_8k if config.tune_8k else jit_model._model
    stft_layer = base_model.stft
    encoder_layer = base_model.encoder

    # Gradients are enabled explicitly because init_jit_model in this file
    # switches autograd off globally.
    with torch.enable_grad():
        for x, targets, masks in tqdm(loader, total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            # Left-pad so the first chunk also has a full context window.
            x = torch.nn.functional.pad(x, (context_size, 0))

            # Clear gradients from the previous batch; without this they
            # accumulate and every step uses the sum of all past gradients.
            optimizer.zero_grad()

            outs = []
            state = torch.zeros(0)
            for start in range(context_size, x.shape[1], num_samples):
                chunk = x[:, start - context_size:start + num_samples]
                features = encoder_layer(stft_layer(chunk))
                out, state = decoder(features, state)
                outs.append(out)
            stacked = torch.cat(outs, dim=2).squeeze(1)

            loss = criterion(stacked, targets)
            loss = (loss * masks).mean()
            loss.backward()
            optimizer.step()
            losses.update(loss.item(), masks.numel())

    torch.cuda.empty_cache()
    gc.collect()
    return losses.avg
def validate(config,
             loader,
             jit_model,
             decoder,
             criterion,
             device):
    """Evaluate ``decoder`` on ``loader`` without gradient tracking.

    Batches are processed chunk by chunk through the STFT and encoder layers
    of ``jit_model`` and then ``decoder``, exactly as during training, with
    ``decoder`` switched to eval mode.

    Args:
        config: Object with a ``tune_8k`` flag selecting the 8 kHz layers
            and chunk sizes (256 / 32) instead of the 16 kHz ones (512 / 64).
        loader: Iterable yielding ``(x, targets, masks)`` batches.
        jit_model: Model exposing ``_model`` / ``_model_8k`` sub-modules
            with ``stft`` and ``encoder`` attributes.
        decoder: Module called as ``decoder(features, state)``.
        criterion: Loss whose output is multiplied by ``masks`` before the
            mean.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, roc_auc)``: the loss averaged with ``masks.numel()``
        weights, and the ROC AUC over all positions whose mask is non-zero,
        rounded to three decimals.
    """
    loss_meter = AverageMeter()
    decoder.eval()

    if config.tune_8k:
        context_size, num_samples = 32, 256
        stft_layer = jit_model._model_8k.stft
        encoder_layer = jit_model._model_8k.encoder
    else:
        context_size, num_samples = 64, 512
        stft_layer = jit_model._model.stft
        encoder_layer = jit_model._model.encoder

    predicts, gts = [], []
    with torch.no_grad():
        for x, targets, masks in tqdm(loader, total=len(loader)):
            targets = targets.to(device)
            x = x.to(device)
            masks = masks.to(device)
            # Left-pad so the first chunk also has a full context window.
            x = torch.nn.functional.pad(x, (context_size, 0))

            state = torch.zeros(0)
            chunk_outputs = []
            for start in range(context_size, x.shape[1], num_samples):
                window = x[:, start - context_size:start + num_samples]
                encoded = encoder_layer(stft_layer(window))
                step_out, state = decoder(encoded, state)
                chunk_outputs.append(step_out)
            stacked = torch.cat(chunk_outputs, dim=2).squeeze(1)

            # Only positions with a non-zero mask are scored.
            predicts.extend(stacked[masks != 0].tolist())
            gts.extend(targets[masks != 0].tolist())

            batch_loss = (criterion(stacked, targets) * masks).mean()
            loss_meter.update(batch_loss.item(), masks.numel())

    score = roc_auc_score(gts, predicts)
    torch.cuda.empty_cache()
    gc.collect()
    return loss_meter.avg, round(score, 3)
def init_jit_model(model_path: str,
                   device=torch.device('cpu')):
    """Load a TorchScript model from disk and put it in eval mode.

    Args:
        model_path: Path handed to ``torch.jit.load``.
        device: Passed as ``map_location`` when loading.

    Returns:
        The loaded model after ``eval()`` has been called on it.
    """
    # NOTE: this disables autograd globally, not only for the loaded model;
    # code that needs gradients afterwards has to re-enable them itself.
    torch.set_grad_enabled(False)
    jit_model = torch.jit.load(model_path, map_location=device)
    jit_model.eval()
    return jit_model
def predict(model, loader, device, sr):
    """Collect per-sample predictions and targets at unmasked positions.

    Args:
        model: Object exposing ``audio_forward(x, sr=sr)``.
        loader: Iterable yielding ``(x, targets, masks)`` batches.
        device: Device the input batch is moved to.
        sr: Sample rate forwarded to ``audio_forward``.

    Returns:
        ``(all_predicts, all_gts)``: two lists with one entry per sample,
        each entry holding the values at positions whose mask is non-zero.
    """
    all_predicts, all_gts = [], []
    with torch.no_grad():
        for x, targets, masks in tqdm(loader, total=len(loader)):
            out = model.audio_forward(x.to(device), sr=sr)
            for row, out_chunk in enumerate(out):
                sample_predicts = out_chunk[masks[row] != 0].cpu().tolist()
                sample_gts = targets[row, masks[row] != 0].cpu().tolist()
                all_predicts.append(sample_predicts)
                all_gts.append(sample_gts)
    return all_predicts, all_gts
def _apply_hysteresis(probs, ths_enter, ths_exit):
    """Turn probabilities into 0/1 labels using enter/exit thresholds."""
    labels = []
    is_speech = False
    for prob in probs:
        if prob >= ths_enter:
            is_speech = True
        elif prob <= ths_exit:
            is_speech = False
        # Between the two thresholds the previous decision is kept.
        labels.append(1 if is_speech else 0)
    return labels


def calculate_best_thresholds(all_predicts, all_gts):
    """Grid-search the enter/exit thresholds giving the best mean accuracy.

    Both thresholds are taken from ``np.linspace(0, 1, 20)`` and only pairs
    with ``ths_exit < ths_enter`` are evaluated. For each pair every
    prediction sequence is binarised with hysteresis (a frame switches to
    speech at ``>= ths_enter``, back to non-speech at ``<= ths_exit`` and
    otherwise keeps the previous state) and scored with ``accuracy_score``
    against the matching entry of ``all_gts``.

    Args:
        all_predicts: Sequence of per-sample probability sequences.
        all_gts: Sequence of per-sample ground-truth label sequences, in
            the same order as ``all_predicts``.

    Returns:
        ``(best_ths_enter, best_ths_exit, best_acc)`` with the thresholds
        rounded to two decimals and the accuracy to three. On ties the
        first pair reaching the best accuracy wins.

    Raises:
        ValueError: If ``all_predicts`` is empty.
    """
    if len(all_predicts) == 0:
        raise ValueError('all_predicts is empty: no predictions to tune thresholds on')

    # Start below any reachable accuracy so that a pair is always selected,
    # even when every candidate scores 0 (previously an UnboundLocalError).
    best_acc = -1.0
    best_ths_enter = None
    best_ths_exit = None

    grid = np.linspace(0, 1, 20)
    for ths_enter in tqdm(grid):
        for ths_exit in grid:
            if ths_exit >= ths_enter:
                continue
            accs = [
                round(accuracy_score(gt, _apply_hysteresis(probs, ths_enter, ths_exit)), 4)
                for probs, gt in zip(all_predicts, all_gts)
            ]
            mean_acc = round(np.mean(accs), 3)
            if mean_acc > best_acc:
                best_acc = mean_acc
                best_ths_enter = round(ths_enter, 2)
                best_ths_exit = round(ths_exit, 2)
    return best_ths_enter, best_ths_exit, best_acc