Path: blob/master/examples/rust-example/src/silero.rs
1903 views
use crate::utils;1use ndarray::{Array, Array1, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};2use ort::session::Session;3use ort::value::Value;4use std::mem::take;5use std::path::Path;67#[derive(Debug)]8pub struct Silero {9session: Session,10sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,11state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,12context: Array1<f32>,13context_size: usize,14}1516impl Silero {17pub fn new(18sample_rate: utils::SampleRate,19model_path: impl AsRef<Path>,20) -> Result<Self, ort::Error> {21let session = Session::builder()?.commit_from_file(model_path)?;22let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());23let sample_rate_val: i64 = sample_rate.into();24let context_size = if sample_rate_val == 16000 { 64 } else { 32 };25let context = Array1::<f32>::zeros(context_size);26let sample_rate = Array::from_shape_vec([1], vec![sample_rate_val]).unwrap();27Ok(Self {28session,29sample_rate,30state,31context,32context_size,33})34}3536pub fn reset(&mut self) {37self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());38self.context = Array1::<f32>::zeros(self.context_size);39}4041pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {42let data = audio_frame43.iter()44.map(|x| (*x as f32) / (i16::MAX as f32))45.collect::<Vec<_>>();4647// Concatenate context with input48let mut input_with_context = Vec::with_capacity(self.context_size + data.len());49input_with_context.extend_from_slice(self.context.as_slice().unwrap());50input_with_context.extend_from_slice(&data);5152let frame =53Array2::<f32>::from_shape_vec([1, input_with_context.len()], input_with_context)54.unwrap();5556let frame_value = Value::from_array(frame)?;57let state_value = Value::from_array(take(&mut self.state))?;58let sr_value = Value::from_array(self.sample_rate.clone())?;5960let res = self.session.run([61(&frame_value).into(),62(&state_value).into(),63(&sr_value).into(),64])?;6566let (shape, state_data) = res["stateN"].try_extract_tensor::<f32>()?;67let shape_usize: Vec<usize> = shape.as_ref().iter().map(|&d| d as usize).collect();68self.state = ArrayD::from_shape_vec(shape_usize.as_slice(), state_data.to_vec()).unwrap();6970// Update context with last context_size samples from the input71if data.len() >= self.context_size {72self.context = Array1::from_vec(data[data.len() - self.context_size..].to_vec());73}7475let prob = *res["output"]76.try_extract_tensor::<f32>()77.unwrap()78.179.first()80.unwrap();81Ok(prob)82}83}848586