Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/examples/classification-example-named/src/main.rs
2459 views
1
use std::fs;
2
use wasi_nn::*;
3
4
pub fn main() {
5
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
6
.build_from_cache("mobilenet")
7
.unwrap();
8
println!("Loaded a graph: {:?}", graph);
9
10
let mut context = graph.init_execution_context().unwrap();
11
println!("Created an execution context: {:?}", context);
12
13
// Load a tensor that precisely matches the graph input tensor (see
14
// `fixture/frozen_inference_graph.xml`).
15
let tensor_data = fs::read("fixture/tensor.bgr").unwrap();
16
println!("Read input tensor, size in bytes: {}", tensor_data.len());
17
context
18
.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)
19
.unwrap();
20
21
// Execute the inference.
22
context.compute().unwrap();
23
println!("Executed graph inference");
24
25
// Retrieve the output.
26
let mut output_buffer = vec![0f32; 1001];
27
context.get_output(0, &mut output_buffer[..]).unwrap();
28
29
println!(
30
"Found results, sorted top 5: {:?}",
31
&sort_results(&output_buffer)[..5]
32
)
33
}
34
35
/// Pairs each class score in `buffer` with its class ID and sorts the pairs by descending score.
///
/// The graph places the match probability for each class at the index for that class (e.g. the
/// probability of class 42 is placed at `buffer[42]`). The first element of `buffer` is skipped
/// and the remaining class IDs are shifted down by one. This `.skip(1)` is needed to get results
/// that make sense (e.g. 763 = "revolver" vs 762 = "restaurant"). NOTE(review): the likely reason
/// is that 1001-output MobileNet models reserve index 0 for a "background" class — confirm against
/// the model's label file.
///
/// Ties keep their original (ascending class ID) order because the sort is stable. NaN scores
/// are placed after all other results instead of causing a panic. An empty or single-element
/// `buffer` yields an empty vector.
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(class, &probability)| InferenceResult(class, probability))
        .collect();
    // Order NaNs last, then compare the remaining scores in descending order. Two non-NaN floats
    // are always comparable, so the `unwrap_or` fallback is only reached when both are NaN. That
    // keeps this a consistent total order, which `sort_by` requires.
    results.sort_by(|a, b| {
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
    });
    results
}

/// A wrapper pairing a class ID (`.0`) with its match probability (`.1`).
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
54
55