Path: blob/main/crates/wasi-nn/src/backend/onnx.rs
2459 views
//! Implements a `wasi-nn` [`BackendInner`] using ONNX via the `ort` crate.12use super::{3BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, NamedTensor,4};5use crate::backend::{Id, read};6use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};7use crate::{ExecutionContext, Graph};8use anyhow::Context;9use ort::{GraphOptimizationLevel, Session, inputs};10use std::path::Path;11use std::sync::{Arc, Mutex};1213#[derive(Default)]14pub struct OnnxBackend();15unsafe impl Send for OnnxBackend {}16unsafe impl Sync for OnnxBackend {}1718impl BackendInner for OnnxBackend {19fn encoding(&self) -> GraphEncoding {20GraphEncoding::Onnx21}2223fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {24if builders.len() != 1 {25return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());26}2728let session = Session::builder()?29.with_optimization_level(GraphOptimizationLevel::Level3)?30.commit_from_memory(builders[0])?;3132let box_: Box<dyn BackendGraph> =33Box::new(OnnxGraph(Arc::new(Mutex::new(session)), target));34Ok(box_.into())35}3637fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {38Some(self)39}40}4142impl BackendFromDir for OnnxBackend {43fn load_from_dir(44&mut self,45path: &Path,46target: ExecutionTarget,47) -> Result<Graph, BackendError> {48let model = read(&path.join("model.onnx"))?;49self.load(&[&model], target)50}51}5253struct OnnxGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget);54unsafe impl Send for OnnxGraph {}55unsafe impl Sync for OnnxGraph {}5657impl BackendGraph for OnnxGraph {58fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {59let session = self.0.lock().unwrap();60// We need to hold on to the names of the inputs in order for61// `set_input` to work with both indexes and names. Having the62// dimensions and type around is useful for validation but could be63// retrieved from the session.64let mut inputs = vec![];65for input in &session.inputs {66let shape = Shape::from_onnx_input(input)?;67inputs.push(TensorSlot {68shape,69tensor: None,70});71}72// We need to keep track of the output shapes since they are used for73// creating the output tensor.74let mut outputs = vec![];75for output in &session.outputs {76let shape = Shape::from_onnx_output(output)?;77outputs.push(TensorSlot {78shape,79tensor: None,80});81}82let box_: Box<dyn BackendExecutionContext> = Box::new(OnnxExecutionContext {83session: self.0.clone(),84inputs,85outputs,86});87Ok(box_.into())88}89}9091struct OnnxExecutionContext {92session: Arc<Mutex<Session>>,93inputs: Vec<TensorSlot>,94outputs: Vec<TensorSlot>,95}9697unsafe impl Send for OnnxExecutionContext {}98unsafe impl Sync for OnnxExecutionContext {}99100impl OnnxExecutionContext {101/// Helper function for finding the internal index of a tensor by [`Id`].102fn find(&self, id: Id, list: &[TensorSlot]) -> Result<usize, BackendError> {103let index = match id {104Id::Index(i) => {105let i = i as usize;106if i < list.len() {107i108} else {109return Err(BackendError::BackendAccess(anyhow::anyhow!(110"incorrect tensor index: {i} >= {}",111list.len()112)));113}114}115Id::Name(n) => list.iter().position(|s| s.shape.name == n).ok_or_else(|| {116BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {n}"))117})?,118};119Ok(index)120}121}122123impl BackendExecutionContext for OnnxExecutionContext {124fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {125let index = self.find(id, &self.inputs)?;126let input = &mut self.inputs[index];127if let Err(e) = input.shape.matches(tensor) {128return Err(e.into());129}130// Hold the tensor data on the context until `compute` is called.131input.tensor.replace(tensor.clone());132Ok(())133}134135fn compute(136&mut self,137inputs: Option<Vec<NamedTensor>>,138) -> Result<Option<Vec<NamedTensor>>, BackendError> {139match inputs {140// WIT141Some(inputs) => {142for slot in &mut self.inputs {143slot.tensor = None;144}145for input in &inputs {146let index = self147.inputs148.iter()149.position(|slot| slot.shape.name == input.name);150let index = match index {151Some(idx) => idx,152None => {153// Try to convert name to index154if let Ok(idx) = input.name.parse::<usize>() {155if idx < self.inputs.len() {156idx157} else {158return Err(BackendError::BackendAccess(anyhow::anyhow!(159"Input index out of range: {}",160idx161)));162}163} else {164return Err(BackendError::BackendAccess(anyhow::anyhow!(165"Unknown input tensor name: {}",166input.name167)));168}169}170};171172let input_slot = &mut self.inputs[index];173if let Err(e) = input_slot.shape.matches(&input.tensor) {174return Err(e.into());175}176input_slot.tensor.replace(input.tensor.clone());177}178179let mut session_inputs: Vec<ort::SessionInputValue<'_>> = vec![];180for i in &self.inputs {181session_inputs.extend(to_input_value(i)?);182}183let session = self.session.lock().unwrap();184let session_outputs = session.run(session_inputs.as_slice())?;185186let mut output_tensors = Vec::new();187for i in 0..self.outputs.len() {188// TODO: fix preexisting gap--this only handles f32 tensors.189let raw: (Vec<i64>, &[f32]) = session_outputs[i].try_extract_raw_tensor()?;190let f32s = raw.1.to_vec();191let output = &mut self.outputs[i];192let tensor = Tensor {193dimensions: output.shape.dimensions_as_u32()?,194ty: output.shape.ty,195data: f32_vec_to_bytes(f32s),196};197output.tensor.replace(tensor.clone());198output_tensors.push(NamedTensor {199name: output.shape.name.clone(),200tensor,201});202}203Ok(Some(output_tensors))204}205206// WITX207None => {208let mut session_inputs: Vec<ort::SessionInputValue<'_>> = vec![];209for i in &self.inputs {210session_inputs.extend(to_input_value(i)?);211}212let session = self.session.lock().unwrap();213let session_outputs = session.run(session_inputs.as_slice())?;214for i in 0..self.outputs.len() {215// TODO: fix preexisting gap--this only handles f32 tensors.216let raw: (Vec<i64>, &[f32]) = session_outputs[i].try_extract_raw_tensor()?;217let f32s = raw.1.to_vec();218let output = &mut self.outputs[i];219output.tensor.replace(Tensor {220dimensions: output.shape.dimensions_as_u32()?,221ty: output.shape.ty,222data: f32_vec_to_bytes(f32s),223});224}225Ok(None)226}227}228}229230fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {231let index = self.find(id, &self.outputs)?;232let output = &self.outputs[index];233if let Some(tensor) = &output.tensor {234Ok(tensor.clone())235} else {236Err(BackendError::BackendAccess(anyhow::anyhow!(237"missing output tensor: {}; has `compute` been called?",238output.shape.name239)))240}241}242}243244impl From<ort::Error> for BackendError {245fn from(e: ort::Error) -> Self {246BackendError::BackendAccess(e.into())247}248}249250/// Holds a slot for ONNX session inputs and outputs.251///252/// TODO: it seems unfortunate that we have to "hold" some extra data per253/// session but in the input case, this is necessary for name-based indexing.254struct TensorSlot {255shape: Shape,256tensor: Option<Tensor>,257}258259/// Describes a tensor in ONNX terms.260struct Shape {261name: String,262dimensions: Vec<i64>,263ty: TensorType,264}265266impl Shape {267fn from_onnx_input(input: &ort::Input) -> Result<Self, BackendError> {268let name = input.name.clone();269let (dimensions, ty) = convert_value_type(&input.input_type)?;270Ok(Self {271name,272dimensions,273ty,274})275}276277fn from_onnx_output(output: &ort::Output) -> Result<Self, BackendError> {278let name = output.name.clone();279let (dimensions, ty) = convert_value_type(&output.output_type)?;280Ok(Self {281name,282dimensions,283ty,284})285}286287fn dimensions_as_u32(&self) -> Result<Vec<u32>, BackendError> {288self.dimensions289.iter()290.map(|d| if *d == -1 { Ok(1) } else { convert_i64(d) })291.collect()292}293294fn matches(&self, tensor: &Tensor) -> anyhow::Result<()> {295if self.dimensions.len() != tensor.dimensions.len() {296return Err(anyhow::anyhow!(297"input tensor cardinality does not match model: {:?} != {:?}",298self.dimensions,299tensor.dimensions300));301} else {302for (&shape_dim, &tensor_dim) in self.dimensions.iter().zip(tensor.dimensions.iter()) {303let tensor_dim = tensor_dim as i64;304if !is_dynamic_dimension(shape_dim) && shape_dim != tensor_dim {305return Err(anyhow::anyhow!(306"input tensor dimensions do not match model: {:?} != {:?}",307self.dimensions,308tensor.dimensions309));310}311}312}313if self.ty != tensor.ty {314return Err(anyhow::anyhow!(315"input tensor type does not match model: {:?} != {:?}",316self.ty,317tensor.ty318));319}320Ok(())321}322}323324fn convert_value_type(vt: &ort::ValueType) -> Result<(Vec<i64>, TensorType), BackendError> {325match vt {326ort::ValueType::Tensor { ty, dimensions } => {327let dims = dimensions.clone();328let ty = (*ty).try_into()?;329Ok((dims, ty))330}331_ => Err(BackendError::BackendAccess(anyhow::anyhow!(332"unsupported input type: {vt:?}"333))),334}335}336337fn convert_i64(i: &i64) -> Result<u32, BackendError> {338u32::try_from(*i).map_err(|d| -> BackendError {339anyhow::anyhow!("unable to convert dimension to u32: {d}").into()340})341}342343impl TryFrom<ort::TensorElementType> for TensorType {344type Error = BackendError;345fn try_from(ty: ort::TensorElementType) -> Result<Self, Self::Error> {346match ty {347ort::TensorElementType::Float32 => Ok(TensorType::Fp32),348ort::TensorElementType::Float64 => Ok(TensorType::Fp64),349ort::TensorElementType::Uint8 => Ok(TensorType::U8),350ort::TensorElementType::Int32 => Ok(TensorType::I32),351ort::TensorElementType::Int64 => Ok(TensorType::I64),352_ => Err(BackendError::BackendAccess(anyhow::anyhow!(353"unsupported tensor type: {ty:?}"354))),355}356}357}358359fn to_input_value(slot: &TensorSlot) -> Result<[ort::SessionInputValue<'_>; 1], BackendError> {360match &slot.tensor {361Some(tensor) => match tensor.ty {362TensorType::Fp32 => {363let data = bytes_to_f32_vec(tensor.data.to_vec());364let dimensions = tensor365.dimensions366.iter()367.map(|d| *d as i64) // TODO: fewer conversions368.collect::<Vec<i64>>();369Ok(inputs![(dimensions, Arc::new(data.into_boxed_slice()))]370.context("failed to create ONNX session input")?)371}372_ => {373unimplemented!("{:?} not supported by ONNX", tensor.ty);374}375},376None => {377return Err(BackendError::BackendAccess(anyhow::anyhow!(378"missing input tensor: {}",379slot.shape.name380)));381}382}383}384385pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {386let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect();387let result: Vec<u8> = chunks.iter().flatten().copied().collect();388result389}390391pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {392let chunks: Vec<&[u8]> = data.chunks(4).collect();393let v: Vec<f32> = chunks394.into_iter()395.map(|c| f32::from_le_bytes(c.try_into().unwrap()))396.collect();397398v.into_iter().collect()399}400401/// Returns whether the dimension is dynamic.402///403/// ONNX uses [dimensional variables] (i.e., name strings) to indicate that the404/// value of a tensor dimension is user-defined, not fixed by the model. This is405/// useful for batching up several inference requests, e.g. When `ort` returns a406/// dimension of this kind, though, it uses `-1` to indicate that the dimension407/// is dynamic.408///409/// [dimensional variables]:410/// https://onnx.ai/onnx/repo-docs/IR.html#static-tensor-shapes411fn is_dynamic_dimension(d: i64) -> bool {412d == -1413}414415416