Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/src/backend/mod.rs
2459 views
1
//! Define the Rust interface a backend must implement in order to be used by
2
//! this crate. The `Box<dyn ...>` types returned by these interfaces allow
3
//! implementations to maintain backend-specific state between calls.
4
5
#[cfg(feature = "onnx")]
6
pub mod onnx;
7
#[cfg(all(feature = "openvino", target_pointer_width = "64"))]
8
pub mod openvino;
9
#[cfg(feature = "pytorch")]
10
pub mod pytorch;
11
#[cfg(all(feature = "winml", target_os = "windows"))]
12
pub mod winml;
13
14
#[cfg(feature = "onnx")]
15
use self::onnx::OnnxBackend;
16
#[cfg(all(feature = "openvino", target_pointer_width = "64"))]
17
use self::openvino::OpenvinoBackend;
18
#[cfg(feature = "pytorch")]
19
use self::pytorch::PytorchBackend;
20
#[cfg(all(feature = "winml", target_os = "windows"))]
21
use self::winml::WinMLBackend;
22
23
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor};
24
use crate::{Backend, ExecutionContext, Graph};
25
use std::fs::File;
26
use std::io::Read;
27
use std::path::Path;
28
use thiserror::Error;
29
use wiggle::GuestError;
30
31
/// Return a list of all available backend frameworks.
32
pub fn list() -> Vec<Backend> {
33
let mut backends = vec![];
34
let _ = &mut backends; // silence warnings if none are enabled
35
#[cfg(all(feature = "openvino", target_pointer_width = "64"))]
36
{
37
backends.push(Backend::from(OpenvinoBackend::default()));
38
}
39
#[cfg(all(feature = "winml", target_os = "windows"))]
40
{
41
backends.push(Backend::from(WinMLBackend::default()));
42
}
43
#[cfg(feature = "onnx")]
44
{
45
backends.push(Backend::from(OnnxBackend::default()));
46
}
47
#[cfg(feature = "pytorch")]
48
{
49
backends.push(Backend::from(PytorchBackend::default()));
50
}
51
backends
52
}
53
54
/// A [Backend] contains the necessary state to load [Graph]s.
55
pub trait BackendInner: Send + Sync {
56
fn encoding(&self) -> GraphEncoding;
57
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
58
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>;
59
}
60
61
/// Some [Backend]s support loading a [Graph] from a directory on the
62
/// filesystem; this is not a general requirement for backends but is useful for
63
/// the Wasmtime CLI.
64
pub trait BackendFromDir: BackendInner {
65
fn load_from_dir(
66
&mut self,
67
builders: &Path,
68
target: ExecutionTarget,
69
) -> Result<Graph, BackendError>;
70
}
71
72
/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
73
/// implementation for the user-facing graph.
74
pub trait BackendGraph: Send + Sync {
75
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError>;
76
}
77
78
/// A [BackendExecutionContext] performs the actual inference; this is the
79
/// backing implementation for a user-facing execution context.
80
pub trait BackendExecutionContext: Send + Sync {
81
// WITX functions
82
fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError>;
83
fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError>;
84
85
// Functions which work for both WIT and WITX
86
fn compute(
87
&mut self,
88
inputs: Option<Vec<NamedTensor>>,
89
) -> Result<Option<Vec<NamedTensor>>, BackendError>;
90
}
91
92
/// An identifier for a tensor in a [Graph]: either a positional index or a
/// name.
// `Clone`/`PartialEq`/`Eq`/`Hash` are derived so callers can copy, compare
// and key maps on identifiers; all are additive and backward compatible.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Id {
    /// Refer to the tensor by its position.
    Index(u32),
    /// Refer to the tensor by its name.
    Name(String),
}

impl Id {
    /// The index, if this identifier is an [Id::Index]; otherwise `None`.
    pub fn index(&self) -> Option<u32> {
        match self {
            Id::Index(i) => Some(*i),
            Id::Name(_) => None,
        }
    }

    /// The name, if this identifier is an [Id::Name]; otherwise `None`.
    pub fn name(&self) -> Option<&str> {
        match self {
            Id::Index(_) => None,
            Id::Name(n) => Some(n),
        }
    }
}
112
113
/// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all
114
/// for failures interacting with the ML library.
115
#[derive(Debug, Error)]
116
pub enum BackendError {
117
#[error("Failed while accessing backend")]
118
BackendAccess(#[from] anyhow::Error),
119
#[error("Failed while accessing guest module")]
120
GuestAccess(#[from] GuestError),
121
#[error("The backend expects {0} buffers, passed {1}")]
122
InvalidNumberOfBuilders(usize, usize),
123
#[error("Not enough memory to copy tensor data of size: {0}")]
124
NotEnoughMemory(usize),
125
#[error("Unsupported tensor type: {0}")]
126
UnsupportedTensorType(String),
127
}
128
129
/// Read a file into a byte vector.
130
#[allow(dead_code, reason = "not used on all platforms")]
131
fn read(path: &Path) -> anyhow::Result<Vec<u8>> {
132
let mut file = File::open(path)?;
133
let mut buffer = vec![];
134
file.read_to_end(&mut buffer)?;
135
Ok(buffer)
136
}
137
138
pub struct NamedTensor {
139
pub name: String,
140
pub tensor: Tensor,
141
}
142
143