Path: blob/main/crates/wasi-nn/src/backend/onnx.rs
//! Implements a `wasi-nn` [`BackendInner`] using ONNX via the `ort` crate.

use super::{
    BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, NamedTensor,
};
use crate::backend::{Id, read};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use ort::{
    execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
    inputs,
    session::{Input, Output},
    session::{Session, SessionInputValue, builder::GraphOptimizationLevel},
    tensor::TensorElementType,
    value::{Tensor as OrtTensor, ValueType},
};

#[cfg(feature = "onnx-cuda")]
use ort::execution_providers::CUDAExecutionProvider;

use std::path::Path;
use std::sync::{Arc, Mutex};

#[derive(Default)]
pub struct OnnxBackend();
unsafe impl Send for OnnxBackend {}
unsafe impl Sync for OnnxBackend {}

impl BackendInner for OnnxBackend {
    fn encoding(&self) -> GraphEncoding {
        GraphEncoding::Onnx
    }

    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
        if builders.len() != 1 {
            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
        }

        // Configure execution providers based on the target.
        let execution_providers = configure_execution_providers(target)?;

        let session = Session::builder()?
            .with_execution_providers(execution_providers)?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .commit_from_memory(builders[0])?;

        let box_: Box<dyn BackendGraph> =
            Box::new(OnnxGraph(Arc::new(Mutex::new(session)), target));
        Ok(box_.into())
    }

    fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {
        Some(self)
    }
}

/// Configure execution providers based on the target.
fn configure_execution_providers(
    target: ExecutionTarget,
) -> Result<Vec<ExecutionProviderDispatch>, BackendError> {
    match target {
        ExecutionTarget::Cpu => {
            // Use the CPU execution provider with its default configuration.
            tracing::debug!("Using CPU execution provider");
            Ok(vec![CPUExecutionProvider::default().build()])
        }
        ExecutionTarget::Gpu => {
            #[cfg(feature = "onnx-cuda")]
            {
                // Use the CUDA execution provider for GPU acceleration.
                tracing::debug!("Using Nvidia GPU CUDA execution provider");
                Ok(vec![CUDAExecutionProvider::default().build()])
            }
            #[cfg(not(feature = "onnx-cuda"))]
            {
                tracing::warn!("GPU CUDA execution provider is not enabled, falling back to CPU");
                Ok(vec![CPUExecutionProvider::default().build()])
            }
        }
        ExecutionTarget::Tpu => {
            tracing::warn!(
                "TPU execution target is not supported for ONNX backend yet, falling back to CPU"
            );
            Ok(vec![CPUExecutionProvider::default().build()])
        }
    }
}

impl BackendFromDir for OnnxBackend {
    fn load_from_dir(
        &mut self,
        path: &Path,
        target: ExecutionTarget,
    ) -> Result<Graph, BackendError> {
        let model = read(&path.join("model.onnx"))?;
        self.load(&[&model], target)
    }
}
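// A small sanity check for the provider selection above: every target should
// resolve to exactly one execution provider. This is a minimal sketch that
// only exercises the CPU and TPU-fallback paths, since the CUDA path depends
// on the `onnx-cuda` feature being enabled.
#[cfg(test)]
mod provider_selection_tests {
    use super::*;

    #[test]
    fn each_target_yields_one_provider() {
        for target in [ExecutionTarget::Cpu, ExecutionTarget::Tpu] {
            match configure_execution_providers(target) {
                Ok(providers) => assert_eq!(providers.len(), 1),
                Err(_) => panic!("provider configuration should not fail"),
            }
        }
    }
}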
struct OnnxGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget);
unsafe impl Send for OnnxGraph {}
unsafe impl Sync for OnnxGraph {}

impl BackendGraph for OnnxGraph {
    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
        let session = self.0.lock().unwrap();
        // We need to hold on to the names of the inputs in order for
        // `set_input` to work with both indexes and names. Having the
        // dimensions and type around is useful for validation but could be
        // retrieved from the session.
        let mut inputs = vec![];
        for input in &session.inputs {
            let shape = Shape::from_onnx_input(input)?;
            inputs.push(TensorSlot {
                shape,
                tensor: None,
            });
        }
        // We need to keep track of the output shapes since they are used for
        // creating the output tensor.
        let mut outputs = vec![];
        for output in &session.outputs {
            let shape = Shape::from_onnx_output(output)?;
            outputs.push(TensorSlot {
                shape,
                tensor: None,
            });
        }
        let box_: Box<dyn BackendExecutionContext> = Box::new(OnnxExecutionContext {
            session: self.0.clone(),
            inputs,
            outputs,
        });
        Ok(box_.into())
    }
}

struct OnnxExecutionContext {
    session: Arc<Mutex<Session>>,
    inputs: Vec<TensorSlot>,
    outputs: Vec<TensorSlot>,
}

unsafe impl Send for OnnxExecutionContext {}
unsafe impl Sync for OnnxExecutionContext {}

impl OnnxExecutionContext {
    /// Helper function for finding the internal index of a tensor by [`Id`].
    fn find(&self, id: Id, list: &[TensorSlot]) -> Result<usize, BackendError> {
        let index = match id {
            Id::Index(i) => {
                let i = i as usize;
                if i < list.len() {
                    i
                } else {
                    return Err(BackendError::BackendAccess(wasmtime::format_err!(
                        "incorrect tensor index: {i} >= {}",
                        list.len()
                    )));
                }
            }
            Id::Name(n) => list.iter().position(|s| s.shape.name == n).ok_or_else(|| {
                BackendError::BackendAccess(wasmtime::format_err!("unknown tensor name: {n}"))
            })?,
        };
        Ok(index)
    }
}

impl BackendExecutionContext for OnnxExecutionContext {
    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
        let index = self.find(id, &self.inputs)?;
        let input = &mut self.inputs[index];
        if let Err(e) = input.shape.matches(tensor) {
            return Err(e.into());
        }
        // Hold the tensor data on the context until `compute` is called.
        input.tensor.replace(tensor.clone());
        Ok(())
    }

    fn compute(
        &mut self,
        inputs: Option<Vec<NamedTensor>>,
    ) -> Result<Option<Vec<NamedTensor>>, BackendError> {
        fn dimensions_as_u32(shape: &ort::tensor::Shape) -> Result<Vec<u32>, BackendError> {
            // `ort` reports dynamic dimensions as `-1`; collapse those to `1`
            // when converting to the `u32` dimensions `wasi-nn` expects.
            (*shape)
                .iter()
                .map(|d| if *d == -1 { Ok(1) } else { convert_i64(d) })
                .collect()
        }

        match inputs {
            // The WIT path: inputs are passed directly to `compute` by name.
            Some(inputs) => {
                for slot in &mut self.inputs {
                    slot.tensor = None;
                }
                for input in &inputs {
                    let index = self
                        .inputs
                        .iter()
                        .position(|slot| slot.shape.name == input.name);
                    let index = match index {
                        Some(idx) => idx,
                        None => {
                            // Fall back to interpreting the name as an index.
                            if let Ok(idx) = input.name.parse::<usize>() {
                                if idx < self.inputs.len() {
                                    idx
                                } else {
                                    return Err(BackendError::BackendAccess(
                                        wasmtime::format_err!("input index out of range: {idx}"),
                                    ));
                                }
                            } else {
                                return Err(BackendError::BackendAccess(wasmtime::format_err!(
                                    "unknown input tensor name: {}",
                                    input.name
                                )));
                            }
                        }
                    };

                    let input_slot = &mut self.inputs[index];
                    if let Err(e) = input_slot.shape.matches(&input.tensor) {
                        return Err(e.into());
                    }
                    input_slot.tensor.replace(input.tensor.clone());
                }

                let mut session_inputs: Vec<SessionInputValue<'_>> = vec![];
                for i in &self.inputs {
                    session_inputs.extend(to_input_value(i)?);
                }
                let mut session = self.session.lock().unwrap();
                let session_outputs = session.run(session_inputs.as_slice())?;

                let mut output_tensors = Vec::new();
                for i in 0..self.outputs.len() {
                    // TODO: fix preexisting gap--this only handles f32 tensors.
                    let (shape, data): (&ort::tensor::Shape, &[f32]) =
                        session_outputs[i].try_extract_tensor()?;
                    let f32s = data.to_vec();
                    let output = &mut self.outputs[i];
                    let dimensions: Vec<u32> = dimensions_as_u32(shape)?;
                    let tensor = Tensor {
                        dimensions,
                        ty: output.shape.ty,
                        data: f32_vec_to_bytes(f32s),
                    };
                    output.tensor.replace(tensor.clone());
                    output_tensors.push(NamedTensor {
                        name: output.shape.name.clone(),
                        tensor,
                    });
                }
                Ok(Some(output_tensors))
            }

            // The WITX path: inputs were staged earlier via `set_input`.
            None => {
                let mut session_inputs: Vec<SessionInputValue<'_>> = vec![];
                for i in &self.inputs {
                    session_inputs.extend(to_input_value(i)?);
                }
                let mut session = self.session.lock().unwrap();
                let session_outputs = session.run(session_inputs.as_slice())?;
                for i in 0..self.outputs.len() {
                    // TODO: fix preexisting gap--this only handles f32 tensors.
                    let (shape, data): (&ort::tensor::Shape, &[f32]) =
                        session_outputs[i].try_extract_tensor()?;
                    let f32s = data.to_vec();
                    let output = &mut self.outputs[i];
                    let dimensions: Vec<u32> = dimensions_as_u32(shape)?;
                    output.tensor.replace(Tensor {
                        dimensions,
                        ty: output.shape.ty,
                        data: f32_vec_to_bytes(f32s),
                    });
                }
                Ok(None)
            }
        }
    }

    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
        let index = self.find(id, &self.outputs)?;
        let output = &self.outputs[index];
        if let Some(tensor) = &output.tensor {
            Ok(tensor.clone())
        } else {
            Err(BackendError::BackendAccess(wasmtime::format_err!(
                "missing output tensor: {}; has `compute` been called?",
                output.shape.name
            )))
        }
    }
}

impl From<ort::Error> for BackendError {
    fn from(e: ort::Error) -> Self {
        BackendError::BackendAccess(wasmtime::format_err!("{e}"))
    }
}
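// How a host typically drives the context above (an illustrative sketch, not
// code from this crate; `graph` and `tensor` are hypothetical values, with
// `graph` coming from `OnnxBackend::load`):
//
//     let mut ctx = graph.init_execution_context()?;
//     ctx.set_input(Id::Index(0), &tensor)?; // stage the input (WITX path)
//     ctx.compute(None)?;                    // run the ONNX session
//     let output = ctx.get_output(Id::Index(0))?;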
/// Holds a slot for ONNX session inputs and outputs.
///
/// TODO: it seems unfortunate that we have to "hold" some extra data per
/// session but in the input case, this is necessary for name-based indexing.
struct TensorSlot {
    shape: Shape,
    tensor: Option<Tensor>,
}

/// Describes a tensor in ONNX terms.
struct Shape {
    name: String,
    dimensions: Vec<i64>,
    ty: TensorType,
}

impl Shape {
    fn from_onnx_input(input: &Input) -> Result<Self, BackendError> {
        let name = input.name.clone();
        let (dimensions, ty) = convert_value_type(&input.input_type)?;
        Ok(Self {
            name,
            dimensions,
            ty,
        })
    }

    fn from_onnx_output(output: &Output) -> Result<Self, BackendError> {
        let name = output.name.clone();
        let (dimensions, ty) = convert_value_type(&output.output_type)?;
        Ok(Self {
            name,
            dimensions,
            ty,
        })
    }

    fn matches(&self, tensor: &Tensor) -> wasmtime::Result<()> {
        if self.dimensions.len() != tensor.dimensions.len() {
            return Err(wasmtime::format_err!(
                "input tensor cardinality does not match model: {:?} != {:?}",
                self.dimensions,
                tensor.dimensions
            ));
        }
        for (&shape_dim, &tensor_dim) in self.dimensions.iter().zip(tensor.dimensions.iter()) {
            let tensor_dim = tensor_dim as i64;
            if !is_dynamic_dimension(shape_dim) && shape_dim != tensor_dim {
                return Err(wasmtime::format_err!(
                    "input tensor dimensions do not match model: {:?} != {:?}",
                    self.dimensions,
                    tensor.dimensions
                ));
            }
        }
        if self.ty != tensor.ty {
            return Err(wasmtime::format_err!(
                "input tensor type does not match model: {:?} != {:?}",
                self.ty,
                tensor.ty
            ));
        }
        Ok(())
    }
}
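// A minimal sketch of how `Shape::matches` treats dynamic dimensions: a `-1`
// model dimension accepts any concrete size, while fixed dimensions must match
// exactly. The shape and tensors below are made up for illustration.
#[cfg(test)]
mod shape_matching_tests {
    use super::*;

    #[test]
    fn dynamic_dimensions_match_any_size() {
        let shape = Shape {
            name: "input".to_string(),
            dimensions: vec![-1, 3],
            ty: TensorType::Fp32,
        };
        // The dynamic first dimension accepts any size (here, 8).
        let ok = Tensor {
            dimensions: vec![8, 3],
            ty: TensorType::Fp32,
            data: f32_vec_to_bytes(vec![0.0; 24]),
        };
        assert!(shape.matches(&ok).is_ok());

        // A mismatch in a fixed dimension is rejected.
        let bad = Tensor {
            dimensions: vec![8, 4],
            ty: TensorType::Fp32,
            data: f32_vec_to_bytes(vec![0.0; 32]),
        };
        assert!(shape.matches(&bad).is_err());
    }
}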
fn convert_value_type(vt: &ValueType) -> Result<(Vec<i64>, TensorType), BackendError> {
    match vt {
        ValueType::Tensor { ty, shape, .. } => {
            let dimensions = shape.to_vec();
            let ty = (*ty).try_into()?;
            Ok((dimensions, ty))
        }
        _ => Err(BackendError::BackendAccess(wasmtime::format_err!(
            "unsupported input type: {vt:?}"
        ))),
    }
}

fn convert_i64(i: &i64) -> Result<u32, BackendError> {
    u32::try_from(*i).map_err(|d| -> BackendError {
        wasmtime::format_err!("unable to convert dimension to u32: {d}").into()
    })
}

impl TryFrom<TensorElementType> for TensorType {
    type Error = BackendError;
    fn try_from(ty: TensorElementType) -> Result<Self, Self::Error> {
        match ty {
            TensorElementType::Float32 => Ok(TensorType::Fp32),
            TensorElementType::Float64 => Ok(TensorType::Fp64),
            TensorElementType::Uint8 => Ok(TensorType::U8),
            TensorElementType::Int32 => Ok(TensorType::I32),
            TensorElementType::Int64 => Ok(TensorType::I64),
            _ => Err(BackendError::BackendAccess(wasmtime::format_err!(
                "unsupported tensor type: {ty:?}"
            ))),
        }
    }
}

fn to_input_value(slot: &TensorSlot) -> Result<[SessionInputValue<'_>; 1], BackendError> {
    match &slot.tensor {
        Some(tensor) => match tensor.ty {
            TensorType::Fp32 => {
                let data = bytes_to_f32_vec(tensor.data.to_vec());
                let dimensions: Vec<i64> = tensor
                    .dimensions
                    .iter()
                    .map(|d| *d as i64) // TODO: fewer conversions
                    .collect();
                let ort_tensor = OrtTensor::<f32>::from_array((dimensions, data)).map_err(|e| {
                    BackendError::BackendAccess(wasmtime::format_err!(
                        "failed to create ONNX session input: {e}"
                    ))
                })?;
                Ok(inputs![ort_tensor])
            }
            _ => {
                unimplemented!("{:?} not supported by ONNX", tensor.ty);
            }
        },
        None => Err(BackendError::BackendAccess(wasmtime::format_err!(
            "missing input tensor: {}",
            slot.shape.name
        ))),
    }
}

/// Converts a vector of `f32`s to its little-endian byte representation.
pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {
    data.into_iter().flat_map(|f| f.to_le_bytes()).collect()
}

/// Reinterprets little-endian bytes as a vector of `f32`s.
///
/// Panics if `data.len()` is not a multiple of four.
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
    data.chunks(4)
        .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
        .collect()
}

/// Returns whether the dimension is dynamic.
///
/// ONNX uses [dimensional variables] (i.e., name strings) to indicate that the
/// value of a tensor dimension is user-defined, not fixed by the model. This
/// is useful, e.g., for batching up several inference requests. When `ort`
/// returns a dimension of this kind, though, it uses `-1` to indicate that the
/// dimension is dynamic.
///
/// [dimensional variables]:
/// https://onnx.ai/onnx/repo-docs/IR.html#static-tensor-shapes
fn is_dynamic_dimension(d: i64) -> bool {
    d == -1
}
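// A round-trip check for the byte-conversion helpers above: `f32` data encoded
// with `f32_vec_to_bytes` should decode losslessly with `bytes_to_f32_vec`.
#[cfg(test)]
mod byte_conversion_tests {
    use super::*;

    #[test]
    fn f32_bytes_round_trip() {
        let original = vec![0.0_f32, 1.5, -2.25, f32::MAX];
        let bytes = f32_vec_to_bytes(original.clone());
        assert_eq!(bytes.len(), original.len() * 4);
        assert_eq!(bytes_to_f32_vec(bytes), original);
    }
}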