Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/test-programs/src/nn.rs
1693 views
1
//! This module attempts to paper over the differences between the two
2
//! implementations of wasi-nn: the legacy WITX-based version (`mod witx`) and
3
//! the up-to-date WIT version (`mod wit`). Since the tests are mainly a simple
4
//! classifier, this exposes a high-level `classify` function to go along with
5
//! `load`, etc.
6
//!
7
//! This module exists solely for convenience--e.g., reduces test duplication.
8
//! In the future can be safely disposed of or altered as more tests are added.
9
10
/// Call `wasi-nn` functions from WebAssembly using the canonical ABI of the
11
/// component model via WIT-based tooling. Used by `bin/nn_wit_*.rs` tests.
12
pub mod wit {
13
use anyhow::{Result, anyhow};
14
use std::time::Instant;
15
16
// Generate the wasi-nn bindings based on the `*.wit` files.
17
wit_bindgen::generate!({
18
path: "../wasi-nn/wit",
19
world: "ml",
20
default_bindings_module: "test_programs::ml"
21
});
22
use self::wasi::nn::errors;
23
use self::wasi::nn::graph::{self, Graph};
24
pub use self::wasi::nn::graph::{ExecutionTarget, GraphEncoding}; // Used by tests.
25
use self::wasi::nn::tensor::{Tensor, TensorType};
26
27
/// Load a wasi-nn graph from a set of bytes.
28
pub fn load(
29
bytes: &[Vec<u8>],
30
encoding: GraphEncoding,
31
target: ExecutionTarget,
32
) -> Result<Graph> {
33
graph::load(bytes, encoding, target).map_err(err_as_anyhow)
34
}
35
36
/// Load a wasi-nn graph by name.
37
pub fn load_by_name(name: &str) -> Result<Graph> {
38
graph::load_by_name(name).map_err(err_as_anyhow)
39
}
40
41
/// Run a wasi-nn inference using a simple classifier model (single input,
42
/// single output).
43
pub fn classify(graph: Graph, input: (&str, Vec<u8>)) -> Result<Vec<f32>> {
44
let context = graph.init_execution_context().map_err(err_as_anyhow)?;
45
println!("[nn] created wasi-nn execution context with ID: {context:?}");
46
47
// Many classifiers have a single input; currently, this test suite also
48
// uses tensors of the same shape, though this is not usually the case.
49
let tensor = Tensor::new(&vec![1, 3, 224, 224], TensorType::Fp32, &input.1);
50
println!("[nn] input tensor: {} bytes", input.1.len());
51
52
let before = Instant::now();
53
let input_tuple = (input.0.to_string(), tensor);
54
let output_tensors = context.compute(vec![input_tuple]).unwrap();
55
println!(
56
"[nn] executed graph inference in {} ms",
57
before.elapsed().as_millis()
58
);
59
60
// Many classifiers emit probabilities as floating point values; here we
61
// convert the raw bytes to `f32` knowing all models used here use that
62
// type.
63
let output = &output_tensors[0].1;
64
println!(
65
"[nn] retrieved output tensor: {} bytes",
66
output.data().len()
67
);
68
let output: Vec<f32> = output
69
.data()
70
.chunks(4)
71
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
72
.collect();
73
Ok(output)
74
}
75
76
fn err_as_anyhow(e: errors::Error) -> anyhow::Error {
77
anyhow!("error: {e:?}")
78
}
79
}
80
81
/// Call `wasi-nn` functions from WebAssembly using the legacy WITX-based
82
/// tooling. This older API has been deprecated for the newer WIT-based API but
83
/// retained for backwards compatibility testing--i.e., `bin/nn_witx_*.rs`
84
/// tests.
85
pub mod witx {
86
use anyhow::Result;
87
use std::time::Instant;
88
pub use wasi_nn::{ExecutionTarget, GraphEncoding};
89
use wasi_nn::{Graph, GraphBuilder, TensorType};
90
91
/// Load a wasi-nn graph from a set of bytes.
92
pub fn load(
93
bytes: &[&[u8]],
94
encoding: GraphEncoding,
95
target: ExecutionTarget,
96
) -> Result<Graph> {
97
Ok(GraphBuilder::new(encoding, target).build_from_bytes(bytes)?)
98
}
99
100
/// Load a wasi-nn graph by name.
101
pub fn load_by_name(
102
name: &str,
103
encoding: GraphEncoding,
104
target: ExecutionTarget,
105
) -> Result<Graph> {
106
Ok(GraphBuilder::new(encoding, target).build_from_cache(name)?)
107
}
108
109
/// Run a wasi-nn inference using a simple classifier model (single input,
110
/// single output).
111
pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
112
let mut context = graph.init_execution_context()?;
113
println!("[nn] created wasi-nn execution context with ID: {context}");
114
115
// Many classifiers have a single input; currently, this test suite also
116
// uses tensors of the same shape, though this is not usually the case.
117
context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?;
118
println!("[nn] set input tensor: {} bytes", tensor.len());
119
120
let before = Instant::now();
121
context.compute()?;
122
println!(
123
"[nn] executed graph inference in {} ms",
124
before.elapsed().as_millis()
125
);
126
127
// Many classifiers emit probabilities as floating point values; here we
128
// convert the raw bytes to `f32` knowing all models used here use that
129
// type.
130
let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::<f32>()];
131
let num_bytes = context.get_output(0, &mut output_buffer)?;
132
println!("[nn] retrieved output tensor: {num_bytes} bytes");
133
let output: Vec<f32> = output_buffer[..num_bytes]
134
.chunks(4)
135
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
136
.collect();
137
Ok(output)
138
}
139
}
140
141
/// Sort some classification probabilities.
///
/// Many classification models output a buffer of probabilities for each class,
/// placing the match probability for each class at the index for that class
/// (the probability of class `N` is stored at `probabilities[N]`).
///
/// Returns `(class id, probability)` pairs ordered from most to least
/// probable. The sort uses `f32::total_cmp`, a total order over floats, so a
/// `NaN` probability can no longer panic the sort (the previous
/// `partial_cmp(..).unwrap()` would); positive `NaN`s sort first in this
/// descending order per IEEE 754 `totalOrder`. Ordering of non-NaN values is
/// unchanged.
pub fn sort_results(probabilities: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = probabilities
        .iter()
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
pub struct InferenceResult(usize, f32);
impl InferenceResult {
    /// The class index this result corresponds to.
    pub fn class_id(&self) -> usize {
        self.0
    }
    /// The model's match probability for this class.
    pub fn probability(&self) -> f32 {
        self.1
    }
}
167
168