Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/src/witx.rs
1692 views
1
//! Implements the `wasi-nn` API for the WITX ("preview1") ABI.
2
//!
3
//! `wasi-nn` was never included in the official "preview1" snapshot, but this
4
//! module implements the ABI that is compatible with "preview1".
5
//!
6
//! The only export from this module is [`add_to_linker`]. To implement it, this
7
//! module proceeds in steps:
8
//! 1. generate all of the Wiggle glue code into a `generated::*` namespace
9
//! 2. wire up the `generated::*` glue to the context state, delegating actual
10
//! computation to a `Backend`
11
//! 3. wrap up with some conversions, i.e., from `generated::*` types to this crate's
12
//! [`types`].
13
//!
14
//! [`types`]: crate::wit::types
15
16
use crate::backend::BackendError;
17
use crate::backend::Id;
18
use crate::wit::GraphEncoding;
19
use crate::{Backend, ExecutionContext, Graph, Registry};
20
use std::collections::HashMap;
21
use std::hash::Hash;
22
use thiserror::Error;
23
use wiggle::{GuestError, GuestMemory, GuestPtr};
24
25
pub use generated::wasi_ephemeral_nn::add_to_linker;

/// Result alias used throughout this module; the error side is always
/// [`WasiNnError`], which Wiggle converts to a guest-visible errno.
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
// Local shorthand so the trait impl below can write bare `Result<T>`.
type Result<T> = WasiNnResult<T>;
// Guest-visible handle values; `Table` hands these out sequentially.
type GraphId = u32;
type GraphExecutionContextId = u32;
31
32
/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
    /// Available backends, keyed by the graph encoding each one handles.
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    /// Preloaded graphs that `load_by_name` can look up.
    pub(crate) registry: Registry,
    /// Graphs loaded by the guest, keyed by the handle returned to it.
    pub(crate) graphs: Table<GraphId, Graph>,
    /// Execution contexts created by the guest, keyed by handle.
    pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}
39
40
impl WasiNnCtx {
41
/// Make a new context from the default state.
42
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
43
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
44
Self {
45
backends,
46
registry,
47
graphs: Table::default(),
48
executions: Table::default(),
49
}
50
}
51
}
52
53
/// Record handle entries in a table.
///
/// Keys are drawn from a monotonically increasing `u32` counter, so every
/// inserted value receives a fresh handle.
pub struct Table<K, V> {
    entries: HashMap<K, V>,
    next_key: u32,
}

impl<K, V> Default for Table<K, V> {
    fn default() -> Self {
        Self { entries: HashMap::new(), next_key: 0 }
    }
}

impl<K, V> Table<K, V>
where
    K: Eq + Hash + From<u32> + Copy,
{
    /// Insert a value and return the key assigned to it.
    pub fn insert(&mut self, value: V) -> K {
        let id = self.use_next_key();
        self.entries.insert(id, value);
        id
    }

    /// Shared lookup of a previously inserted value.
    pub fn get(&self, key: K) -> Option<&V> {
        self.entries.get(&key)
    }

    /// Mutable lookup of a previously inserted value.
    pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
        self.entries.get_mut(&key)
    }

    /// Hand out the next sequential key, advancing the counter.
    fn use_next_key(&mut self) -> K {
        let id = self.next_key;
        self.next_key += 1;
        id.into()
    }
}
92
93
/// Generate the traits and types from the `wasi-nn` WITX specification.
mod generated {
    use super::*;
    // Wiggle emits the `types::*` definitions and the `wasi_ephemeral_nn`
    // trait/linker glue from the WITX file; host errors are surfaced to the
    // guest through the `nn_errno` mapping declared here.
    wiggle::from_witx!({
        witx: ["witx/wasi-nn.witx"],
        errors: { nn_errno => WasiNnError }
    });

    /// Additionally, we must let Wiggle know which of our error codes
    /// represents a successful operation.
    impl wiggle::GuestErrorType for types::NnErrno {
        fn success() -> Self {
            Self::Success
        }
    }

    /// Convert the host errors to their WITX-generated type.
    impl types::UserErrorConversion for WasiNnCtx {
        fn nn_errno_from_wasi_nn_error(
            &mut self,
            e: WasiNnError,
        ) -> anyhow::Result<types::NnErrno> {
            // Log the full host-side error before it is flattened to an errno.
            tracing::debug!("host error: {:?}", e);
            match e {
                WasiNnError::BackendError(_) => Ok(types::NnErrno::RuntimeError),
                // NOTE(review): guest-memory errors currently panic here via
                // `unimplemented!`; consider mapping them to an errno instead
                // of aborting — confirm intended behavior.
                WasiNnError::GuestError(_) => unimplemented!("guest error conversion"),
                WasiNnError::UsageError(_) => Ok(types::NnErrno::UnsupportedOperation),
                WasiNnError::NotEnoughMemory(_) => Ok(types::NnErrno::TooLarge),
            }
        }
    }
}
125
126
/// Wire up the WITX-generated trait to the `wasi-nn` host state.
127
impl generated::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
128
fn load(
129
&mut self,
130
memory: &mut GuestMemory<'_>,
131
builders: generated::types::GraphBuilderArray,
132
encoding: generated::types::GraphEncoding,
133
target: generated::types::ExecutionTarget,
134
) -> Result<generated::types::Graph> {
135
let graph = if let Some(backend) = self.backends.get_mut(&encoding.into()) {
136
// Retrieve all of the "builder lists" from the Wasm memory (see
137
// $graph_builder_array) as slices for a backend to operate on.
138
let mut slices = vec![];
139
for builder in builders.iter() {
140
let builder = memory.read(builder?)?;
141
let slice = memory.as_slice(builder)?.expect(
142
"cannot use with shared memories; \
143
see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)",
144
);
145
slices.push(slice);
146
}
147
let slice_refs = slices.iter().map(|s| s.as_ref()).collect::<Vec<_>>();
148
backend.load(&slice_refs, target.into())?
149
} else {
150
return Err(UsageError::InvalidEncoding(encoding.into()).into());
151
};
152
let graph_id = self.graphs.insert(graph);
153
Ok(graph_id.into())
154
}
155
156
fn load_by_name(
157
&mut self,
158
memory: &mut GuestMemory<'_>,
159
name: wiggle::GuestPtr<str>,
160
) -> Result<generated::types::Graph> {
161
let name = memory.as_str(name)?.unwrap();
162
if let Some(graph) = self.registry.get_mut(&name) {
163
let graph_id = self.graphs.insert(graph.clone());
164
Ok(graph_id.into())
165
} else {
166
return Err(UsageError::NotFound(name.to_string()).into());
167
}
168
}
169
170
fn init_execution_context(
171
&mut self,
172
_memory: &mut GuestMemory<'_>,
173
graph_id: generated::types::Graph,
174
) -> Result<generated::types::GraphExecutionContext> {
175
let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id.into()) {
176
graph.init_execution_context()?
177
} else {
178
return Err(UsageError::InvalidGraphHandle.into());
179
};
180
181
let exec_context_id = self.executions.insert(exec_context);
182
Ok(exec_context_id.into())
183
}
184
185
fn set_input(
186
&mut self,
187
memory: &mut GuestMemory<'_>,
188
exec_context_id: generated::types::GraphExecutionContext,
189
index: u32,
190
tensor: &generated::types::Tensor,
191
) -> Result<()> {
192
if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
193
let tensor = crate::wit::types::Tensor {
194
dimensions: memory.to_vec(tensor.dimensions)?,
195
ty: tensor.type_.into(),
196
data: memory.to_vec(tensor.data)?,
197
};
198
Ok(exec_context.set_input(Id::Index(index), &tensor)?)
199
} else {
200
Err(UsageError::InvalidGraphHandle.into())
201
}
202
}
203
204
fn compute(
205
&mut self,
206
_memory: &mut GuestMemory<'_>,
207
exec_context_id: generated::types::GraphExecutionContext,
208
) -> Result<()> {
209
if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
210
Ok(exec_context.compute()?)
211
} else {
212
Err(UsageError::InvalidExecutionContextHandle.into())
213
}
214
}
215
216
fn get_output(
217
&mut self,
218
memory: &mut GuestMemory<'_>,
219
exec_context_id: generated::types::GraphExecutionContext,
220
index: u32,
221
out_buffer: GuestPtr<u8>,
222
out_buffer_max_size: u32,
223
) -> Result<u32> {
224
if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
225
let tensor = exec_context.get_output(Id::Index(index))?;
226
let destination = memory
227
.as_slice_mut(out_buffer.as_array(out_buffer_max_size))?
228
.expect(
229
"cannot use with shared memories; \
230
see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)",
231
);
232
if tensor.data.len() > destination.len() {
233
Err(WasiNnError::NotEnoughMemory(tensor.data.len()))
234
} else {
235
destination[..tensor.data.len()].copy_from_slice(&tensor.data);
236
Ok(tensor.data.len() as u32)
237
}
238
} else {
239
Err(UsageError::InvalidGraphHandle.into())
240
}
241
}
242
}
243
244
// Implement some conversion from `witx::types::*` to this crate's version.
245
246
impl From<generated::types::ExecutionTarget> for crate::wit::types::ExecutionTarget {
    /// Map the WITX-generated execution target onto this crate's WIT type;
    /// the two enums have identical variants.
    fn from(value: generated::types::ExecutionTarget) -> Self {
        use generated::types::ExecutionTarget as Witx;
        match value {
            Witx::Cpu => Self::Cpu,
            Witx::Gpu => Self::Gpu,
            Witx::Tpu => Self::Tpu,
        }
    }
}
255
impl From<generated::types::GraphEncoding> for crate::wit::types::GraphEncoding {
    /// Map the WITX-generated graph encoding onto this crate's WIT type;
    /// the two enums have identical variants.
    fn from(value: generated::types::GraphEncoding) -> Self {
        use generated::types::GraphEncoding as Witx;
        match value {
            Witx::Openvino => Self::Openvino,
            Witx::Onnx => Self::Onnx,
            Witx::Tensorflow => Self::Tensorflow,
            Witx::Pytorch => Self::Pytorch,
            Witx::Tensorflowlite => Self::Tensorflowlite,
            Witx::Autodetect => Self::Autodetect,
        }
    }
}
273
impl From<generated::types::TensorType> for crate::wit::types::TensorType {
    /// Map the WITX tensor element type onto this crate's WIT type; note the
    /// naming difference (WITX `F16`/`F32`/`F64` vs. WIT `Fp16`/`Fp32`/`Fp64`).
    fn from(value: generated::types::TensorType) -> Self {
        use generated::types::TensorType as Witx;
        match value {
            Witx::F16 => Self::Fp16,
            Witx::F32 => Self::Fp32,
            Witx::F64 => Self::Fp64,
            Witx::U8 => Self::U8,
            Witx::I32 => Self::I32,
            Witx::I64 => Self::I64,
        }
    }
}
285
286
/// Possible errors while interacting with [WasiNnCtx].
#[derive(Debug, Error)]
pub enum WasiNnError {
    /// An error raised by the backend ML library.
    #[error("backend error")]
    BackendError(#[from] BackendError),
    /// A guest-memory access error reported by Wiggle.
    #[error("guest error")]
    GuestError(#[from] GuestError),
    /// Incorrect API usage by the guest; see [UsageError].
    #[error("usage error")]
    UsageError(#[from] UsageError),
    /// An output buffer was too small; carries the required size in bytes.
    #[error("not enough memory: requested {0} bytes")]
    NotEnoughMemory(usize),
}
298
299
#[derive(Debug, Error)]
300
pub enum UsageError {
301
#[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
302
InvalidEncoding(GraphEncoding),
303
#[error("Invalid graph handle; has it been loaded?")]
304
InvalidGraphHandle,
305
#[error("Invalid execution context handle; has it been initialized?")]
306
InvalidExecutionContextHandle,
307
#[error("No graph found with name: {0}")]
308
NotFound(String),
309
}
310
311