GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/src/wit.rs
//! Implements the `wasi-nn` API for the WIT ("preview2") ABI.
//!
//! Note that `wasi-nn` is not yet included in an official "preview2" world
//! (though it could be) so by "preview2" here we mean that this can be called
//! with the component model's canonical ABI.
//!
//! This module exports its [`types`] for use throughout the crate and the
//! [`ML`] object, which exposes [`ML::add_to_linker`]. To implement all of
//! this, this module proceeds in steps:
//! 1. generate all of the WIT glue code into a `generated::*` namespace
//! 2. wire up the `generated::*` glue to the context state, delegating actual
//!    computation to a [`Backend`]
//! 3. convert some types
//!
//! [`Backend`]: crate::Backend
//! [`types`]: crate::wit::types
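//!
//! # Example
//!
//! A minimal sketch of how an embedder might wire this up; the `MyState`
//! type, its fields, and the `wasmtime_wasi_nn::wit` import path are
//! illustrative assumptions, not something this module prescribes:
//!
//! ```ignore
//! use wasmtime::component::{Linker, ResourceTable};
//! use wasmtime_wasi_nn::wit::{add_to_linker, WasiNnCtx, WasiNnView};
//!
//! struct MyState {
//!     wasi_nn: WasiNnCtx,
//!     table: ResourceTable,
//! }
//!
//! fn link(engine: &wasmtime::Engine) -> wasmtime::Result<Linker<MyState>> {
//!     let mut linker = Linker::<MyState>::new(engine);
//!     // Hand the generated bindings a way to borrow the wasi-nn state out
//!     // of the store's data.
//!     add_to_linker(&mut linker, |state: &mut MyState| {
//!         WasiNnView::new(&mut state.table, &mut state.wasi_nn)
//!     })?;
//!     Ok(linker)
//! }
//! ```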

use crate::{Backend, Registry};
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt, str::FromStr};
use wasmtime::component::{HasData, Resource, ResourceTable};
use wasmtime::format_err;

/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    pub(crate) registry: Registry,
}

impl WasiNnCtx {
    /// Create a new context from the given backends and registry.
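    ///
    /// A hedged sketch of constructing a context; how `Backend` and `Registry`
    /// values are obtained is feature-dependent, and the `my_backend()` and
    /// `my_registry()` helpers are hypothetical stand-ins:
    ///
    /// ```ignore
    /// let ctx = WasiNnCtx::new(vec![my_backend()], my_registry());
    /// ```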
    pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
        let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
        Self { backends, registry }
    }
}

/// A wrapper capturing the needed internal wasi-nn state.
///
/// Unlike other WASI proposals (see `wasmtime-wasi`, `wasmtime-wasi-http`),
/// this wrapper is not a `trait` but rather holds the references directly. This
/// removes one layer of abstraction for simplicity only, and it could be added
/// back in the future if embedders need more control here.
pub struct WasiNnView<'a> {
    ctx: &'a mut WasiNnCtx,
    table: &'a mut ResourceTable,
}

impl<'a> WasiNnView<'a> {
    /// Create a new view into the wasi-nn state.
    pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
        Self { ctx, table }
    }
}

/// A wasi-nn error; this appears on the Wasm side as a component model
/// resource.
#[derive(Debug)]
pub struct Error {
    code: ErrorCode,
    data: wasmtime::Error,
}

/// Construct an [`Error`] resource and immediately return it.
///
/// The WIT specification currently relies on "errors as resources;" this helper
/// macro hides some of that complexity. If [#75] is adopted ("errors as
/// records"), this macro will no longer be necessary.
///
/// [#75]: https://github.com/WebAssembly/wasi-nn/pull/75
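///
/// A sketch of a typical invocation, mirroring the call sites below; since
/// `bail!` expands to statements, it is used inside a block:
///
/// ```ignore
/// match backend.load(&slices, target) {
///     Ok(graph) => { /* ... */ }
///     Err(error) => {
///         bail!(self, ErrorCode::RuntimeError, error);
///     }
/// }
/// ```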
macro_rules! bail {
    ($self:ident, $code:expr, $data:expr) => {
        let e = Error {
            code: $code,
            data: $data.into(),
        };
        tracing::error!("failure: {e:?}");
        let r = $self.table.push(e)?;
        return Ok(Err(r));
    };
}

impl From<wasmtime::component::ResourceTableError> for Error {
    fn from(error: wasmtime::component::ResourceTableError) -> Self {
        Self {
            code: ErrorCode::Trap,
            data: error.into(),
        }
    }
}

/// The list of error codes available to the `wasi-nn` API; this should match
/// what is specified in WIT.
#[derive(Debug)]
pub enum ErrorCode {
    /// Caller module passed an invalid argument.
    InvalidArgument,
    /// Invalid encoding.
    InvalidEncoding,
    /// The operation timed out.
    Timeout,
    /// Runtime error.
    RuntimeError,
    /// Unsupported operation.
    UnsupportedOperation,
    /// Graph is too large.
    TooLarge,
    /// Graph not found.
    NotFound,
    /// A runtime error that Wasmtime should trap on; this will not appear in
    /// the WIT specification.
    Trap,
}

/// Generate the traits and types from the `wasi-nn` WIT specification.
pub(crate) mod generated_ {
    wasmtime::component::bindgen!({
        world: "ml",
        path: "wit/wasi-nn.wit",
        with: {
            // Configure all WIT resources to be types defined in this crate so
            // that we can use the `ResourceTable` helper methods.
            "wasi:nn/graph.graph": crate::Graph,
            "wasi:nn/tensor.tensor": crate::Tensor,
            "wasi:nn/inference.graph-execution-context": crate::ExecutionContext,
            "wasi:nn/errors.error": super::Error,
        },
        imports: { default: trappable },
        trappable_error_type: {
            "wasi:nn/errors.error" => super::Error,
        },
    });
}
use generated_::wasi::nn::{self as generated}; // Shortcut to the module containing the types we need.

// Export the `types` used in this crate as well as `ML::add_to_linker`.
pub mod types {
    use super::generated;
    pub use generated::errors::Error;
    pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
    pub use generated::inference::GraphExecutionContext;
    pub use generated::tensor::{Tensor, TensorType};
}
pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
pub use generated::inference::{GraphExecutionContext, NamedTensor};
pub use generated::tensor::{Tensor, TensorData, TensorDimensions, TensorType};
pub use generated_::Ml as ML;

/// Add the WIT-based version of the `wasi-nn` API to a
/// [`wasmtime::component::Linker`].
pub fn add_to_linker<T: 'static>(
    l: &mut wasmtime::component::Linker<T>,
    f: fn(&mut T) -> WasiNnView<'_>,
) -> wasmtime::Result<()> {
    generated::graph::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::tensor::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::inference::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::errors::add_to_linker::<_, HasWasiNnView>(l, f)?;
    Ok(())
}

/// Marker type used with [`HasData`] to tell the generated `add_to_linker`
/// functions how to borrow a [`WasiNnView`] out of the embedder's store data.
struct HasWasiNnView;

impl HasData for HasWasiNnView {
    type Data<'a> = WasiNnView<'a>;
}

impl generated::graph::Host for WasiNnView<'_> {
    fn load(
        &mut self,
        builders: Vec<GraphBuilder>,
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        tracing::debug!("load {encoding:?} {target:?}");
        if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
            let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
            match backend.load(&slices, target) {
                Ok(graph) => {
                    let graph = self.table.push(graph)?;
                    Ok(Ok(graph))
                }
                Err(error) => {
                    bail!(self, ErrorCode::RuntimeError, error);
                }
            }
        } else {
            bail!(
                self,
                ErrorCode::InvalidEncoding,
                format_err!("unable to find a backend for this encoding")
            );
        }
    }

    fn load_by_name(
        &mut self,
        name: String,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        use core::result::Result::*;
        tracing::debug!("load by name {name:?}");
        let registry = &self.ctx.registry;
        if let Some(graph) = registry.get(&name) {
            let graph = graph.clone();
            let graph = self.table.push(graph)?;
            Ok(Ok(graph))
        } else {
            bail!(
                self,
                ErrorCode::NotFound,
                format_err!("failed to find graph with name: {name}")
            );
        }
    }
}

impl generated::graph::HostGraph for WasiNnView<'_> {
    fn init_execution_context(
        &mut self,
        graph: Resource<Graph>,
    ) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
        use core::result::Result::*;
        tracing::debug!("initialize execution context");
        let graph = self.table.get(&graph)?;
        match graph.init_execution_context() {
            Ok(exec_context) => {
                let exec_context = self.table.push(exec_context)?;
                Ok(Ok(exec_context))
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }

    fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
        self.table.delete(graph)?;
        Ok(())
    }
}

impl generated::inference::HostGraphExecutionContext for WasiNnView<'_> {
    fn compute(
        &mut self,
        exec_context: Resource<GraphExecutionContext>,
        inputs: Vec<NamedTensor>,
    ) -> wasmtime::Result<Result<Vec<NamedTensor>, Resource<Error>>> {
        tracing::debug!("compute with {} inputs", inputs.len());

        // Take ownership of each input tensor by removing it from the
        // resource table.
        let mut named_tensors = Vec::new();
        for (name, tensor_resource) in inputs.into_iter() {
            let tensor = self.table.delete(tensor_resource)?;
            named_tensors.push(crate::backend::NamedTensor { name, tensor });
        }

        let exec_context = &mut self.table.get_mut(&exec_context)?;

        match exec_context.compute_with_io(named_tensors) {
            Ok(named_tensors) => {
                // Hand the output tensors back to the guest as new resources.
                let result = named_tensors
                    .into_iter()
                    .map(|crate::backend::NamedTensor { name, tensor }| {
                        self.table.push(tensor).map(|resource| (name, resource))
                    })
                    .collect();

                match result {
                    Ok(tuples) => Ok(Ok(tuples)),
                    Err(error) => {
                        bail!(self, ErrorCode::RuntimeError, error);
                    }
                }
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }

    fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
        self.table.delete(exec_context)?;
        Ok(())
    }
}

impl generated::tensor::HostTensor for WasiNnView<'_> {
    fn new(
        &mut self,
        dimensions: TensorDimensions,
        ty: TensorType,
        data: TensorData,
    ) -> wasmtime::Result<Resource<Tensor>> {
        let tensor = Tensor {
            dimensions,
            ty,
            data,
        };
        let tensor = self.table.push(tensor)?;
        Ok(tensor)
    }

    fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.dimensions.clone())
    }

    fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.ty)
    }

    fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.data.clone())
    }

    fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
        self.table.delete(tensor)?;
        Ok(())
    }
}

impl generated::errors::HostError for WasiNnView<'_> {
    fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<generated::errors::ErrorCode> {
        let error = self.table.get(&error)?;
        match error.code {
            ErrorCode::InvalidArgument => Ok(generated::errors::ErrorCode::InvalidArgument),
            ErrorCode::InvalidEncoding => Ok(generated::errors::ErrorCode::InvalidEncoding),
            ErrorCode::Timeout => Ok(generated::errors::ErrorCode::Timeout),
            ErrorCode::RuntimeError => Ok(generated::errors::ErrorCode::RuntimeError),
            ErrorCode::UnsupportedOperation => {
                Ok(generated::errors::ErrorCode::UnsupportedOperation)
            }
            ErrorCode::TooLarge => Ok(generated::errors::ErrorCode::TooLarge),
            ErrorCode::NotFound => Ok(generated::errors::ErrorCode::NotFound),
            ErrorCode::Trap => Err(format_err!(error.data.to_string())),
        }
    }

    fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
        let error = self.table.get(&error)?;
        Ok(error.data.to_string())
    }

    fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
        self.table.delete(error)?;
        Ok(())
    }
}

impl generated::errors::Host for WasiNnView<'_> {
    // Required by the `trappable_error_type` configuration above: errors
    // marked `ErrorCode::Trap` become host-level traps, while all other
    // errors are surfaced to the guest as `error` resources.
    fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
        if matches!(err.code, ErrorCode::Trap) {
            Err(err.data)
        } else {
            Ok(err)
        }
    }
}

impl generated::tensor::Host for WasiNnView<'_> {}
impl generated::inference::Host for WasiNnView<'_> {}

// `GraphEncoding` is used as a `HashMap` key in `WasiNnCtx::backends`, so give
// it a `Hash` implementation based on its `Display` string.
impl Hash for generated::graph::GraphEncoding {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.to_string().hash(state)
    }
}

impl fmt::Display for generated::graph::GraphEncoding {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use generated::graph::GraphEncoding::*;
        match self {
            Openvino => write!(f, "openvino"),
            Onnx => write!(f, "onnx"),
            Pytorch => write!(f, "pytorch"),
            Tensorflow => write!(f, "tensorflow"),
            Tensorflowlite => write!(f, "tensorflowlite"),
            Autodetect => write!(f, "autodetect"),
            Ggml => write!(f, "ggml"),
        }
    }
}

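// A hedged sketch of parsing an encoding name with the `FromStr`
// implementation below (matching is case-insensitive); it uses the
// `GraphEncoding` re-export from earlier in this module:
//
//     let encoding: GraphEncoding = "OpenVINO".parse().unwrap();
//     assert_eq!(encoding.to_string(), "openvino");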
impl FromStr for generated::graph::GraphEncoding {
    type Err = GraphEncodingParseError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "openvino" => Ok(generated::graph::GraphEncoding::Openvino),
            "onnx" => Ok(generated::graph::GraphEncoding::Onnx),
            "pytorch" => Ok(generated::graph::GraphEncoding::Pytorch),
            "tensorflow" => Ok(generated::graph::GraphEncoding::Tensorflow),
            "tensorflowlite" => Ok(generated::graph::GraphEncoding::Tensorflowlite),
            "autodetect" => Ok(generated::graph::GraphEncoding::Autodetect),
            _ => Err(GraphEncodingParseError(s.into())),
        }
    }
}
#[derive(Debug)]
pub struct GraphEncodingParseError(String);
impl fmt::Display for GraphEncodingParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "unknown graph encoding: {}", self.0)
    }
}
impl std::error::Error for GraphEncodingParseError {}