// Path: blob/main/crates/wasi-nn/src/backend/mod.rs
// 2459 views
//! Define the Rust interface a backend must implement in order to be used by1//! this crate. The `Box<dyn ...>` types returned by these interfaces allow2//! implementations to maintain backend-specific state between calls.34#[cfg(feature = "onnx")]5pub mod onnx;6#[cfg(all(feature = "openvino", target_pointer_width = "64"))]7pub mod openvino;8#[cfg(feature = "pytorch")]9pub mod pytorch;10#[cfg(all(feature = "winml", target_os = "windows"))]11pub mod winml;1213#[cfg(feature = "onnx")]14use self::onnx::OnnxBackend;15#[cfg(all(feature = "openvino", target_pointer_width = "64"))]16use self::openvino::OpenvinoBackend;17#[cfg(feature = "pytorch")]18use self::pytorch::PytorchBackend;19#[cfg(all(feature = "winml", target_os = "windows"))]20use self::winml::WinMLBackend;2122use crate::wit::{ExecutionTarget, GraphEncoding, Tensor};23use crate::{Backend, ExecutionContext, Graph};24use std::fs::File;25use std::io::Read;26use std::path::Path;27use thiserror::Error;28use wiggle::GuestError;2930/// Return a list of all available backend frameworks.31pub fn list() -> Vec<Backend> {32let mut backends = vec![];33let _ = &mut backends; // silence warnings if none are enabled34#[cfg(all(feature = "openvino", target_pointer_width = "64"))]35{36backends.push(Backend::from(OpenvinoBackend::default()));37}38#[cfg(all(feature = "winml", target_os = "windows"))]39{40backends.push(Backend::from(WinMLBackend::default()));41}42#[cfg(feature = "onnx")]43{44backends.push(Backend::from(OnnxBackend::default()));45}46#[cfg(feature = "pytorch")]47{48backends.push(Backend::from(PytorchBackend::default()));49}50backends51}5253/// A [Backend] contains the necessary state to load [Graph]s.54pub trait BackendInner: Send + Sync {55fn encoding(&self) -> GraphEncoding;56fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;57fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>;58}5960/// Some [Backend]s support loading a [Graph] from a directory 
on the61/// filesystem; this is not a general requirement for backends but is useful for62/// the Wasmtime CLI.63pub trait BackendFromDir: BackendInner {64fn load_from_dir(65&mut self,66builders: &Path,67target: ExecutionTarget,68) -> Result<Graph, BackendError>;69}7071/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing72/// implementation for the user-facing graph.73pub trait BackendGraph: Send + Sync {74fn init_execution_context(&self) -> Result<ExecutionContext, BackendError>;75}7677/// A [BackendExecutionContext] performs the actual inference; this is the78/// backing implementation for a user-facing execution context.79pub trait BackendExecutionContext: Send + Sync {80// WITX functions81fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError>;82fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError>;8384// Functions which work for both WIT and WITX85fn compute(86&mut self,87inputs: Option<Vec<NamedTensor>>,88) -> Result<Option<Vec<NamedTensor>>, BackendError>;89}9091/// An identifier for a tensor in a [Graph].92#[derive(Debug)]93pub enum Id {94Index(u32),95Name(String),96}97impl Id {98pub fn index(&self) -> Option<u32> {99match self {100Id::Index(i) => Some(*i),101Id::Name(_) => None,102}103}104pub fn name(&self) -> Option<&str> {105match self {106Id::Index(_) => None,107Id::Name(n) => Some(n),108}109}110}111112/// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all113/// for failures interacting with the ML library.114#[derive(Debug, Error)]115pub enum BackendError {116#[error("Failed while accessing backend")]117BackendAccess(#[from] anyhow::Error),118#[error("Failed while accessing guest module")]119GuestAccess(#[from] GuestError),120#[error("The backend expects {0} buffers, passed {1}")]121InvalidNumberOfBuilders(usize, usize),122#[error("Not enough memory to copy tensor data of size: {0}")]123NotEnoughMemory(usize),124#[error("Unsupported tensor type: 
{0}")]125UnsupportedTensorType(String),126}127128/// Read a file into a byte vector.129#[allow(dead_code, reason = "not used on all platforms")]130fn read(path: &Path) -> anyhow::Result<Vec<u8>> {131let mut file = File::open(path)?;132let mut buffer = vec![];133file.read_to_end(&mut buffer)?;134Ok(buffer)135}136137pub struct NamedTensor {138pub name: String,139pub tensor: Tensor,140}141142143