Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
snakers4
GitHub Repository: snakers4/silero-vad
Path: blob/master/examples/rust-example/src/silero.rs
1171 views
1
use crate::utils;
2
use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
3
use std::path::Path;
4
5
#[derive(Debug)]
6
pub struct Silero {
7
session: ort::Session,
8
sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
9
state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
10
}
11
12
impl Silero {
13
pub fn new(
14
sample_rate: utils::SampleRate,
15
model_path: impl AsRef<Path>,
16
) -> Result<Self, ort::Error> {
17
let session = ort::Session::builder()?.commit_from_file(model_path)?;
18
let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
19
let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
20
Ok(Self {
21
session,
22
sample_rate,
23
state,
24
})
25
}
26
27
pub fn reset(&mut self) {
28
self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
29
}
30
31
pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
32
let data = audio_frame
33
.iter()
34
.map(|x| (*x as f32) / (i16::MAX as f32))
35
.collect::<Vec<_>>();
36
let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();
37
frame = frame.slice(s![.., ..480]).to_owned();
38
let inps = ort::inputs![
39
frame,
40
std::mem::take(&mut self.state),
41
self.sample_rate.clone(),
42
]?;
43
let res = self
44
.session
45
.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
46
self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
47
Ok(*res["output"]
48
.try_extract_raw_tensor::<f32>()
49
.unwrap()
50
.1
51
.first()
52
.unwrap())
53
}
54
}
55
56