Path: blob/master/examples/rust-example/src/silero.rs
1171 views
use crate::utils;1use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};2use std::path::Path;34#[derive(Debug)]5pub struct Silero {6session: ort::Session,7sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,8state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,9}1011impl Silero {12pub fn new(13sample_rate: utils::SampleRate,14model_path: impl AsRef<Path>,15) -> Result<Self, ort::Error> {16let session = ort::Session::builder()?.commit_from_file(model_path)?;17let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());18let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();19Ok(Self {20session,21sample_rate,22state,23})24}2526pub fn reset(&mut self) {27self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());28}2930pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {31let data = audio_frame32.iter()33.map(|x| (*x as f32) / (i16::MAX as f32))34.collect::<Vec<_>>();35let mut frame = Array2::<f32>::from_shape_vec([1, data.len()], data).unwrap();36frame = frame.slice(s![.., ..480]).to_owned();37let inps = ort::inputs![38frame,39std::mem::take(&mut self.state),40self.sample_rate.clone(),41]?;42let res = self43.session44.run(ort::SessionInputs::ValueSlice::<3>(&inps))?;45self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();46Ok(*res["output"]47.try_extract_raw_tensor::<f32>()48.unwrap()49.150.first()51.unwrap())52}53}545556