Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
snakers4
GitHub Repository: snakers4/silero-vad
Path: blob/master/examples/rust-example/src/silero.rs
1903 views
1
use crate::utils;
2
use ndarray::{Array, Array1, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
3
use ort::session::Session;
4
use ort::value::Value;
5
use std::mem::take;
6
use std::path::Path;
7
8
#[derive(Debug)]
9
pub struct Silero {
10
session: Session,
11
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
12
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
13
context: Array1<f32>,
14
context_size: usize,
15
}
16
17
impl Silero {
18
pub fn new(
19
sample_rate: utils::SampleRate,
20
model_path: impl AsRef<Path>,
21
) -> Result<Self, ort::Error> {
22
let session = Session::builder()?.commit_from_file(model_path)?;
23
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
24
let sample_rate_val: i64 = sample_rate.into();
25
let context_size = if sample_rate_val == 16000 { 64 } else { 32 };
26
let context = Array1::<f32>::zeros(context_size);
27
let sample_rate = Array::from_shape_vec([1], vec![sample_rate_val]).unwrap();
28
Ok(Self {
29
session,
30
sample_rate,
31
state,
32
context,
33
context_size,
34
})
35
}
36
37
pub fn reset(&mut self) {
38
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
39
self.context = Array1::<f32>::zeros(self.context_size);
40
}
41
42
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
43
let data = audio_frame
44
.iter()
45
.map(|x| (*x as f32) / (i16::MAX as f32))
46
.collect::<Vec<_>>();
47
48
// Concatenate context with input
49
let mut input_with_context = Vec::with_capacity(self.context_size + data.len());
50
input_with_context.extend_from_slice(self.context.as_slice().unwrap());
51
input_with_context.extend_from_slice(&data);
52
53
let frame =
54
Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context)
55
.unwrap();
56
57
let frame_value = Value::from_array(frame)?;
58
let state_value = Value::from_array(take(&mut self.state))?;
59
let sr_value = Value::from_array(self.sample_rate.clone())?;
60
61
let res = self.session.run([
62
(&frame_value).into(),
63
(&state_value).into(),
64
(&sr_value).into(),
65
])?;
66
67
let (shape, state_data) = res["stateN"].try_extract_tensor::<f32>()?;
68
let shape_usize: Vec<usize> = shape.as_ref().iter().map(|&d| d as usize).collect();
69
self.state = ArrayD::from_shape_vec(shape_usize.as_slice(), state_data.to_vec()).unwrap();
70
71
// Update context with last context_size samples from the input
72
if data.len() >= self.context_size {
73
self.context = Array1::from_vec(data[data.len() - self.context_size..].to_vec());
74
}
75
76
let prob = *res["output"]
77
.try_extract_tensor::<f32>()
78
.unwrap()
79
.1
80
.first()
81
.unwrap();
82
Ok(prob)
83
}
84
}
85
86