// bytecodealliance/wasmtime: crates/wasi-nn/src/wit.rs
//! Implements the `wasi-nn` API for the WIT ("preview2") ABI.
//!
//! Note that `wasi-nn` is not yet included in an official "preview2" world
//! (though it could be), so by "preview2" here we mean that this can be called
//! with the component model's canonical ABI.
//!
//! This module exports its [`types`] for use throughout the crate and the
//! [`ML`] object, which exposes [`ML::add_to_linker`]. To implement all of
//! this, this module proceeds in steps:
//! 1. generate all of the WIT glue code into a `generated::*` namespace
//! 2. wire up the `generated::*` glue to the context state, delegating actual
//!    computation to a [`Backend`]
//! 3. convert some types
//!
//! [`Backend`]: crate::Backend
//! [`types`]: crate::wit::types
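//!
//! As a rough, illustrative sketch (the `MyState` struct and the engine setup
//! below are assumptions for the example, not items defined in this module),
//! an embedder might wire this API into a component linker like so:
//!
//! ```ignore
//! use wasmtime::Engine;
//! use wasmtime::component::{Linker, ResourceTable};
//!
//! // Hypothetical embedder store data holding the wasi-nn state.
//! struct MyState {
//!     ctx: WasiNnCtx,
//!     table: ResourceTable,
//! }
//!
//! let engine = Engine::default();
//! let mut linker: Linker<MyState> = Linker::new(&engine);
//! // `add_to_linker` and `WasiNnView::new` are defined below in this module.
//! add_to_linker(&mut linker, |state: &mut MyState| {
//!     WasiNnView::new(&mut state.table, &mut state.ctx)
//! })?;
//! ```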

use crate::{Backend, Registry};
use anyhow::anyhow;
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt, str::FromStr};
use wasmtime::component::{HasData, Resource, ResourceTable};

/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    pub(crate) registry: Registry,
}

impl WasiNnCtx {
    /// Make a new context from the default state.
    pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
        let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
        Self { backends, registry }
    }
}

/// A wrapper capturing the needed internal wasi-nn state.
///
/// Unlike other WASI proposals (see `wasmtime-wasi`, `wasmtime-wasi-http`),
/// this wrapper is not a `trait` but rather holds the references directly.
/// This removes one layer of abstraction for simplicity only; it could be
/// added back in the future if embedders need more control here.
pub struct WasiNnView<'a> {
    ctx: &'a mut WasiNnCtx,
    table: &'a mut ResourceTable,
}

impl<'a> WasiNnView<'a> {
    /// Create a new view into the wasi-nn state.
    pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
        Self { ctx, table }
    }
}

/// A wasi-nn error; this appears on the Wasm side as a component model
/// resource.
#[derive(Debug)]
pub struct Error {
    code: ErrorCode,
    data: anyhow::Error,
}

/// Construct an [`Error`] resource and immediately return it.
///
/// The WIT specification currently relies on "errors as resources"; this helper
/// macro hides some of that complexity. If [#75] is adopted ("errors as
/// records"), this macro will no longer be necessary.
///
/// [#75]: https://github.com/WebAssembly/wasi-nn/pull/75
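///
/// Within this module it is invoked as, for example,
/// `bail!(self, ErrorCode::RuntimeError, error)`: the error is logged, pushed
/// into the resource table, and returned to the guest as `Ok(Err(resource))`.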
macro_rules! bail {
    ($self:ident, $code:expr, $data:expr) => {
        let e = Error {
            code: $code,
            data: $data.into(),
        };
        tracing::error!("failure: {e:?}");
        let r = $self.table.push(e)?;
        return Ok(Err(r));
    };
}

// Resource-table failures are treated as traps rather than guest-visible
// wasi-nn error codes.
impl From<wasmtime::component::ResourceTableError> for Error {
    fn from(error: wasmtime::component::ResourceTableError) -> Self {
        Self {
            code: ErrorCode::Trap,
            data: error.into(),
        }
    }
}

/// The list of error codes available to the `wasi-nn` API; this should match
/// what is specified in WIT.
#[derive(Debug)]
pub enum ErrorCode {
    /// Caller module passed an invalid argument.
    InvalidArgument,
    /// Invalid encoding.
    InvalidEncoding,
    /// The operation timed out.
    Timeout,
    /// Runtime error.
    RuntimeError,
    /// Unsupported operation.
    UnsupportedOperation,
    /// Graph is too large.
    TooLarge,
    /// Graph not found.
    NotFound,
    /// A runtime error that Wasmtime should trap on; this will not appear in
    /// the WIT specification.
    Trap,
}

/// Generate the traits and types from the `wasi-nn` WIT specification.
pub(crate) mod generated_ {
    wasmtime::component::bindgen!({
        world: "ml",
        path: "wit/wasi-nn.wit",
        with: {
            // Configure all WIT resources to be types defined in this crate
            // so that they can use the `ResourceTable` helper methods.
            "wasi:nn/graph/graph": crate::Graph,
            "wasi:nn/tensor/tensor": crate::Tensor,
            "wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
            "wasi:nn/errors/error": super::Error,
        },
        imports: { default: trappable },
        trappable_error_type: {
            "wasi:nn/errors/error" => super::Error,
        },
    });
}
use generated_::wasi::nn::{self as generated}; // Shortcut to the module containing the types we need.

// Export the `types` used in this crate as well as `ML::add_to_linker`.
pub mod types {
    use super::generated;
    pub use generated::errors::Error;
    pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
    pub use generated::inference::GraphExecutionContext;
    pub use generated::tensor::{Tensor, TensorType};
}
pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
pub use generated::inference::{GraphExecutionContext, NamedTensor};
pub use generated::tensor::{Tensor, TensorData, TensorDimensions, TensorType};
pub use generated_::Ml as ML;

/// Add the WIT-based version of the `wasi-nn` API to a
/// [`wasmtime::component::Linker`].
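///
/// The `f` projection turns the embedder's store data `T` into a
/// [`WasiNnView`]; see the module-level documentation for an illustrative
/// sketch of how this is typically wired up.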
pub fn add_to_linker<T: 'static>(
    l: &mut wasmtime::component::Linker<T>,
    f: fn(&mut T) -> WasiNnView<'_>,
) -> anyhow::Result<()> {
    generated::graph::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::tensor::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::inference::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::errors::add_to_linker::<_, HasWasiNnView>(l, f)?;
    Ok(())
}

/// Marker type that names `WasiNnView<'_>` as the host state for the
/// generated `add_to_linker` functions above.
struct HasWasiNnView;

impl HasData for HasWasiNnView {
    type Data<'a> = WasiNnView<'a>;
}

impl generated::graph::Host for WasiNnView<'_> {
    fn load(
        &mut self,
        builders: Vec<GraphBuilder>,
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        tracing::debug!("load {encoding:?} {target:?}");
        if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
            let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
            match backend.load(&slices, target) {
                Ok(graph) => {
                    let graph = self.table.push(graph)?;
                    Ok(Ok(graph))
                }
                Err(error) => {
                    bail!(self, ErrorCode::RuntimeError, error);
                }
            }
        } else {
            bail!(
                self,
                ErrorCode::InvalidEncoding,
                anyhow!("unable to find a backend for this encoding")
            );
        }
    }

    fn load_by_name(
        &mut self,
        name: String,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        use core::result::Result::*;
        tracing::debug!("load by name {name:?}");
        let registry = &self.ctx.registry;
        if let Some(graph) = registry.get(&name) {
            let graph = graph.clone();
            let graph = self.table.push(graph)?;
            Ok(Ok(graph))
        } else {
            bail!(
                self,
                ErrorCode::NotFound,
                anyhow!("failed to find graph with name: {name}")
            );
        }
    }
}

impl generated::graph::HostGraph for WasiNnView<'_> {
    fn init_execution_context(
        &mut self,
        graph: Resource<Graph>,
    ) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
        use core::result::Result::*;
        tracing::debug!("initialize execution context");
        let graph = self.table.get(&graph)?;
        match graph.init_execution_context() {
            Ok(exec_context) => {
                let exec_context = self.table.push(exec_context)?;
                Ok(Ok(exec_context))
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }

    fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
        self.table.delete(graph)?;
        Ok(())
    }
}

impl generated::inference::HostGraphExecutionContext for WasiNnView<'_> {
    fn compute(
        &mut self,
        exec_context: Resource<GraphExecutionContext>,
        inputs: Vec<NamedTensor>,
    ) -> wasmtime::Result<Result<Vec<NamedTensor>, Resource<Error>>> {
        tracing::debug!("compute with {} inputs", inputs.len());

        // Resolve each input tensor resource from the table and pair it with
        // its name for the backend call.
        let mut named_tensors = Vec::new();
        for (name, tensor_resource) in inputs.iter() {
            let tensor = self.table.get(tensor_resource)?;
            named_tensors.push(crate::backend::NamedTensor {
                name: name.clone(),
                tensor: tensor.clone(),
            });
        }

        let exec_context = &mut self.table.get_mut(&exec_context)?;

        match exec_context.compute_with_io(named_tensors) {
            Ok(named_tensors) => {
                // Push each output tensor back into the table and return the
                // (name, resource) pairs to the guest.
                let result = named_tensors
                    .into_iter()
                    .map(|crate::backend::NamedTensor { name, tensor }| {
                        self.table.push(tensor).map(|resource| (name, resource))
                    })
                    .collect();

                match result {
                    Ok(tuples) => Ok(Ok(tuples)),
                    Err(error) => {
                        bail!(self, ErrorCode::RuntimeError, error);
                    }
                }
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }

    fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
        self.table.delete(exec_context)?;
        Ok(())
    }
}

impl generated::tensor::HostTensor for WasiNnView<'_> {
    fn new(
        &mut self,
        dimensions: TensorDimensions,
        ty: TensorType,
        data: TensorData,
    ) -> wasmtime::Result<Resource<Tensor>> {
        let tensor = Tensor {
            dimensions,
            ty,
            data,
        };
        let tensor = self.table.push(tensor)?;
        Ok(tensor)
    }

    fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.dimensions.clone())
    }

    fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.ty)
    }

    fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.data.clone())
    }

    fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
        self.table.delete(tensor)?;
        Ok(())
    }
}

impl generated::errors::HostError for WasiNnView<'_> {
    fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<generated::errors::ErrorCode> {
        let error = self.table.get(&error)?;
        match error.code {
            ErrorCode::InvalidArgument => Ok(generated::errors::ErrorCode::InvalidArgument),
            ErrorCode::InvalidEncoding => Ok(generated::errors::ErrorCode::InvalidEncoding),
            ErrorCode::Timeout => Ok(generated::errors::ErrorCode::Timeout),
            ErrorCode::RuntimeError => Ok(generated::errors::ErrorCode::RuntimeError),
            ErrorCode::UnsupportedOperation => {
                Ok(generated::errors::ErrorCode::UnsupportedOperation)
            }
            ErrorCode::TooLarge => Ok(generated::errors::ErrorCode::TooLarge),
            ErrorCode::NotFound => Ok(generated::errors::ErrorCode::NotFound),
            // Trap-level errors surface as a Wasmtime trap rather than a
            // guest-visible error code.
            ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
        }
    }

    fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
        let error = self.table.get(&error)?;
        Ok(error.data.to_string())
    }

    fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
        self.table.delete(error)?;
        Ok(())
    }
}

impl generated::errors::Host for WasiNnView<'_> {
    fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
        if matches!(err.code, ErrorCode::Trap) {
            Err(err.data)
        } else {
            Ok(err)
        }
    }
}

impl generated::tensor::Host for WasiNnView<'_> {}
impl generated::inference::Host for WasiNnView<'_> {}

impl Hash for generated::graph::GraphEncoding {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.to_string().hash(state)
    }
}

impl fmt::Display for generated::graph::GraphEncoding {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use generated::graph::GraphEncoding::*;
        match self {
            Openvino => write!(f, "openvino"),
            Onnx => write!(f, "onnx"),
            Pytorch => write!(f, "pytorch"),
            Tensorflow => write!(f, "tensorflow"),
            Tensorflowlite => write!(f, "tensorflowlite"),
            Autodetect => write!(f, "autodetect"),
            Ggml => write!(f, "ggml"),
        }
    }
}

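// `Display` above and `FromStr` below use the lowercase encoding names: for
// example, "pytorch".parse::<GraphEncoding>() yields
// Ok(GraphEncoding::Pytorch), while an unrecognized name produces a
// GraphEncodingParseError.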
impl FromStr for generated::graph::GraphEncoding {
    type Err = GraphEncodingParseError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "openvino" => Ok(generated::graph::GraphEncoding::Openvino),
            "onnx" => Ok(generated::graph::GraphEncoding::Onnx),
            "pytorch" => Ok(generated::graph::GraphEncoding::Pytorch),
            "tensorflow" => Ok(generated::graph::GraphEncoding::Tensorflow),
            "tensorflowlite" => Ok(generated::graph::GraphEncoding::Tensorflowlite),
            "autodetect" => Ok(generated::graph::GraphEncoding::Autodetect),
            "ggml" => Ok(generated::graph::GraphEncoding::Ggml),
            _ => Err(GraphEncodingParseError(s.into())),
        }
    }
}

#[derive(Debug)]
pub struct GraphEncodingParseError(String);
impl fmt::Display for GraphEncodingParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "unknown graph encoding: {}", self.0)
    }
}
impl std::error::Error for GraphEncodingParseError {}