Path: blob/main/crates/wasi-nn/examples/classification-example-named/src/main.rs
2459 views
use std::fs;1use wasi_nn::*;23pub fn main() {4let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)5.build_from_cache("mobilenet")6.unwrap();7println!("Loaded a graph: {:?}", graph);89let mut context = graph.init_execution_context().unwrap();10println!("Created an execution context: {:?}", context);1112// Load a tensor that precisely matches the graph input tensor (see13// `fixture/frozen_inference_graph.xml`).14let tensor_data = fs::read("fixture/tensor.bgr").unwrap();15println!("Read input tensor, size in bytes: {}", tensor_data.len());16context17.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)18.unwrap();1920// Execute the inference.21context.compute().unwrap();22println!("Executed graph inference");2324// Retrieve the output.25let mut output_buffer = vec![0f32; 1001];26context.get_output(0, &mut output_buffer[..]).unwrap();2728println!(29"Found results, sorted top 5: {:?}",30&sort_results(&output_buffer)[..5]31)32}3334// Sort the buffer of probabilities. The graph places the match probability for each class at the35// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert36// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output37// indices are "off by one" but the `.skip(1)` below seems necessary to get results that make sense38// (e.g. 763 = "revolver" vs 762 = "restaurant")39fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {40let mut results: Vec<InferenceResult> = buffer41.iter()42.skip(1)43.enumerate()44.map(|(c, p)| InferenceResult(c, *p))45.collect();46results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());47results48}4950// A wrapper for class ID and match probabilities.51#[derive(Debug, PartialEq)]52struct InferenceResult(usize, f32);535455