Path: blob/main/crates/wasi-nn/examples/classification-example/src/main.rs
use std::convert::TryInto;
use std::fs;

pub fn main() {
    let xml = fs::read_to_string("fixture/model.xml").unwrap();
    println!("Read graph XML, first 50 characters: {}", &xml[..50]);

    let weights = fs::read("fixture/model.bin").unwrap();
    println!("Read graph weights, size in bytes: {}", weights.len());

    let graph = unsafe {
        wasi_nn::load(
            &[&xml.into_bytes(), &weights],
            wasi_nn::GRAPH_ENCODING_OPENVINO,
            wasi_nn::EXECUTION_TARGET_CPU,
        )
        .unwrap()
    };
    println!("Loaded graph into wasi-nn with ID: {graph}");

    let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() };
    println!("Created wasi-nn execution context with ID: {context}");

    // Load a tensor that precisely matches the graph input tensor (see
    // `fixture/frozen_inference_graph.xml`).
    let tensor_data = fs::read("fixture/tensor.bgr").unwrap();
    println!("Read input tensor, size in bytes: {}", tensor_data.len());
    let tensor = wasi_nn::Tensor {
        dimensions: &[1, 3, 224, 224],
        r#type: wasi_nn::TENSOR_TYPE_F32,
        data: &tensor_data,
    };
    unsafe {
        wasi_nn::set_input(context, 0, tensor).unwrap();
    }

    // Execute the inference.
    unsafe {
        wasi_nn::compute(context).unwrap();
    }
    println!("Executed graph inference");

    // Retrieve the output.
    let mut output_buffer = vec![0f32; 1001];
    unsafe {
        wasi_nn::get_output(
            context,
            0,
            &mut output_buffer[..] as *mut [f32] as *mut u8,
            (output_buffer.len() * 4).try_into().unwrap(),
        )
        .unwrap();
    }
    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&output_buffer)[..5]
    )
}

// Sort the buffer of probabilities. The graph places the match probability for each class at the
// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output
// indices are "off by one" but the `.skip(1)` below seems necessary to get results that make sense
// (e.g. 763 = "revolver" vs 762 = "restaurant").
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
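
// A minimal sketch of how `sort_results` behaves, written as a unit test. This test is not part
// of the upstream example; the buffer values below are made-up probabilities chosen only to
// illustrate the `.skip(1)` offset and the descending sort.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sort_results_orders_by_probability_and_skips_index_zero() {
        // Index 0 of the buffer is skipped, so buffer[1] is reported as class 0,
        // buffer[2] as class 1, and so on; results come back highest probability first.
        let buffer = [0.9, 0.1, 0.7, 0.2];
        let results = sort_results(&buffer);
        assert_eq!(results[0], InferenceResult(1, 0.7));
        assert_eq!(results[1], InferenceResult(2, 0.2));
        assert_eq!(results[2], InferenceResult(0, 0.1));
    }
}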