GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/src/lib.rs
pub mod backend;
mod registry;
pub mod wit;
pub mod witx;

use crate::backend::{BackendError, Id, NamedTensor as BackendNamedTensor};
use crate::wit::generated_::wasi::nn::tensor::TensorType;
use anyhow::anyhow;
use core::fmt;
pub use registry::{GraphRegistry, InMemoryRegistry};
use std::path::Path;
use std::sync::Arc;

/// Construct an in-memory registry from the available backends and a list of
/// `(<backend name>, <graph directory>)`. This assumes graphs can be loaded
/// from a local directory, which is currently a safe assumption for the
/// supported model types.
pub fn preload(preload_graphs: &[(String, String)]) -> anyhow::Result<(Vec<Backend>, Registry)> {
    let mut backends = backend::list();
    let mut registry = InMemoryRegistry::new();
    for (kind, path) in preload_graphs {
        let kind_ = kind.parse()?;
        let backend = backends
            .iter_mut()
            .find(|b| b.encoding() == kind_)
            .ok_or(anyhow!("unsupported backend: {}", kind))?
            .as_dir_loadable()
            .ok_or(anyhow!("{} does not support directory loading", kind))?;
        registry.load(backend, Path::new(path))?;
    }
    Ok((backends, Registry::from(registry)))
}

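// A minimal usage sketch for `preload`, not part of the upstream crate: the
// backend name "openvino" and the model directory below are illustrative
// assumptions; any compiled-in backend that supports directory loading works
// the same way.
#[allow(dead_code)]
fn preload_usage_sketch() -> anyhow::Result<()> {
    let graphs = vec![("openvino".to_string(), "/opt/models/mobilenet".to_string())];
    let (backends, registry) = preload(&graphs)?;
    // `backends` lists every backend compiled into this build; `registry` now
    // holds the graphs loaded from the directory for later lookup.
    let _ = (backends, registry);
    Ok(())
}
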
/// A machine learning backend.
pub struct Backend(Box<dyn backend::BackendInner>);
impl std::ops::Deref for Backend {
    type Target = dyn backend::BackendInner;
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
impl std::ops::DerefMut for Backend {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.as_mut()
    }
}
impl<T: backend::BackendInner + 'static> From<T> for Backend {
    fn from(value: T) -> Self {
        Self(Box::new(value))
    }
}

/// A backend-defined graph (i.e., ML model).
#[derive(Clone)]
pub struct Graph(Arc<dyn backend::BackendGraph>);
impl From<Box<dyn backend::BackendGraph>> for Graph {
    fn from(value: Box<dyn backend::BackendGraph>) -> Self {
        Self(value.into())
    }
}
impl std::ops::Deref for Graph {
    type Target = dyn backend::BackendGraph;
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

/// A host-side tensor.
///
/// Eventually, this may be defined in each backend as they gain the ability to
/// hold tensors on various devices (TODO:
/// https://github.com/WebAssembly/wasi-nn/pull/70).
#[derive(Clone, PartialEq)]
pub struct Tensor {
    pub dimensions: Vec<u32>,
    pub ty: TensorType,
    pub data: Vec<u8>,
}
impl fmt::Debug for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Tensor")
            .field("dimensions", &self.dimensions)
            .field("ty", &self.ty)
            .field("data (bytes)", &self.data.len())
            .finish()
    }
}

/// A backend-defined execution context.
pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {
    fn from(value: Box<dyn backend::BackendExecutionContext>) -> Self {
        Self(value)
    }
}
impl std::ops::Deref for ExecutionContext {
    type Target = dyn backend::BackendExecutionContext;
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
impl std::ops::DerefMut for ExecutionContext {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.as_mut()
    }
}

/// A container for graphs.
pub struct Registry(Box<dyn GraphRegistry>);
impl std::ops::Deref for Registry {
    type Target = dyn GraphRegistry;
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
impl std::ops::DerefMut for Registry {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.as_mut()
    }
}
impl<T> From<T> for Registry
where
    T: GraphRegistry + 'static,
{
    fn from(value: T) -> Self {
        Self(Box::new(value))
    }
}

impl ExecutionContext {
    /// Bind a tensor to the input identified by `id`.
    pub fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
        self.0.set_input(id, tensor)
    }

    /// Run inference using the inputs previously bound with `set_input`.
    pub fn compute(&mut self) -> Result<(), BackendError> {
        self.0.compute(None).map(|_| ())
    }

    /// Retrieve the output tensor identified by `id`.
    pub fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
        self.0.get_output(id)
    }

    /// Run inference in a single call, passing named input tensors in and
    /// receiving named output tensors back.
    pub fn compute_with_io(
        &mut self,
        inputs: Vec<BackendNamedTensor>,
    ) -> Result<Vec<BackendNamedTensor>, BackendError> {
        match self.0.compute(Some(inputs))? {
            Some(outputs) => Ok(outputs),
            None => Ok(Vec::new()),
        }
    }
}

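// A minimal sketch of the expected call order on an `ExecutionContext`: bind
// inputs, call `compute`, then read outputs. The `Id` values, the context, and
// the input tensor are assumed to be supplied by the caller; nothing here is
// specific to a particular backend.
#[allow(dead_code)]
fn inference_sketch(
    context: &mut ExecutionContext,
    input_id: Id,
    output_id: Id,
    input: &Tensor,
) -> Result<Tensor, BackendError> {
    context.set_input(input_id, input)?;
    context.compute()?;
    context.get_output(output_id)
}
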
impl Tensor {
    pub fn new(dimensions: Vec<u32>, ty: TensorType, data: Vec<u8>) -> Self {
        Self {
            dimensions,
            ty,
            data,
        }
    }
}

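// A small sketch of constructing a host-side `Tensor`. It assumes the
// generated `TensorType` enum exposes an `Fp32` variant (mirroring the
// wasi-nn WIT `tensor-type`); the shape and values are arbitrary example data.
#[allow(dead_code)]
fn tensor_sketch() -> Tensor {
    // A 1x3 fp32 tensor: each f32 element is stored as four little-endian bytes.
    let data: Vec<u8> = [0.1f32, 0.2, 0.3]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    Tensor::new(vec![1, 3], TensorType::Fp32, data)
}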