Path: blob/main/crates/wasi-nn/examples/classification-example/src/main.rs
use std::convert::TryInto;
use std::fs;
use wasi_nn;

pub fn main() {
    let xml = fs::read_to_string("fixture/model.xml").unwrap();
    println!("Read graph XML, first 50 characters: {}", &xml[..50]);

    let weights = fs::read("fixture/model.bin").unwrap();
    println!("Read graph weights, size in bytes: {}", weights.len());

    // Load the graph from its OpenVINO IR (XML description plus binary weights), targeting the CPU.
    let graph = unsafe {
        wasi_nn::load(
            &[&xml.into_bytes(), &weights],
            wasi_nn::GRAPH_ENCODING_OPENVINO,
            wasi_nn::EXECUTION_TARGET_CPU,
        )
        .unwrap()
    };
    println!("Loaded graph into wasi-nn with ID: {}", graph);

    let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() };
    println!("Created wasi-nn execution context with ID: {}", context);

    // Load a tensor that precisely matches the graph input tensor (see
    // `fixture/frozen_inference_graph.xml`).
    let tensor_data = fs::read("fixture/tensor.bgr").unwrap();
    println!("Read input tensor, size in bytes: {}", tensor_data.len());
    let tensor = wasi_nn::Tensor {
        dimensions: &[1, 3, 224, 224],
        r#type: wasi_nn::TENSOR_TYPE_F32,
        data: &tensor_data,
    };
    unsafe {
        wasi_nn::set_input(context, 0, tensor).unwrap();
    }

    // Execute the inference.
    unsafe {
        wasi_nn::compute(context).unwrap();
    }
    println!("Executed graph inference");

    // Retrieve the output: the graph emits 1001 f32 probabilities. `get_output` fills a raw byte
    // buffer, so pass the buffer as `*mut u8` along with its length in bytes (4 bytes per f32).
    let mut output_buffer = vec![0f32; 1001];
    unsafe {
        wasi_nn::get_output(
            context,
            0,
            &mut output_buffer[..] as *mut [f32] as *mut u8,
            (output_buffer.len() * 4).try_into().unwrap(),
        )
        .unwrap();
    }
    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&output_buffer)[..5]
    )
}

// Sort the buffer of probabilities. The graph places the match probability for each class at the
// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output
// indices are "off by one" but the `.skip(1)` below seems necessary to get results that make sense
// (e.g. 763 = "revolver" vs 762 = "restaurant").
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probability.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);