// Path: blob/main/crates/wasi-nn/examples/classification-example-winml/src/main.rs
// 2459 views
use image::{DynamicImage, RgbImage};1use ndarray::Array;2use std::{fs, time::Instant};34pub fn main() {5// Load model from a file.6let graph =7wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Onnx, wasi_nn::ExecutionTarget::CPU)8.build_from_files(["fixture/mobilenet.onnx"])9.unwrap();1011let mut context = graph.init_execution_context().unwrap();12println!("Created an execution context.");1314// Read image from file and convert it to tensor data.15let image_data = fs::read("fixture/kitten.png").unwrap();1617// Preprocessing. Normalize data based on model requirements https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet#preprocessing18let tensor_data = preprocess(19image_data.as_slice(),20224,21224,22&[0.485, 0.456, 0.406],23&[0.229, 0.224, 0.225],24);25println!("Read input tensor, size in bytes: {}", tensor_data.len());2627context28.set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &tensor_data)29.unwrap();3031// Execute the inference.32let before_compute = Instant::now();33context.compute().unwrap();34println!(35"Executed graph inference, took {} ms.",36before_compute.elapsed().as_millis()37);3839// Retrieve the output.40let mut output_buffer = vec![0f32; 1000];41context.get_output(0, &mut output_buffer[..]).unwrap();4243// Postprocessing. Calculating the softmax probability scores.44let result = postprocess(output_buffer);4546// Load labels for classification47let labels_file = fs::read("fixture/synset.txt").unwrap();48let labels_str = String::from_utf8(labels_file).unwrap();49let labels: Vec<String> = labels_str50.lines()51.map(|line| {52let words: Vec<&str> = line.split_whitespace().collect();53words[1..].join(" ")54})55.collect();5657println!(58"Found results, sorted top 5: {:?}",59&sort_results(&result, &labels)[..5]60)61}6263// Sort the buffer of probabilities. The graph places the match probability for each class at the64// index for that class (e.g. the probability of class 42 is placed at buffer[42]). 
Here we convert65// to a wrapping InferenceResult and sort the results.66fn sort_results(buffer: &[f32], labels: &Vec<String>) -> Vec<InferenceResult> {67let mut results: Vec<InferenceResult> = buffer68.iter()69.enumerate()70.map(|(c, p)| InferenceResult(labels[c].clone(), *p))71.collect();72results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());73results74}7576// Resize image to height x width, and then converts the pixel precision to FP32, normalize with77// given mean and std. The resulting RGB pixel vector is then returned.78fn preprocess(image: &[u8], height: u32, width: u32, mean: &[f32], std: &[f32]) -> Vec<u8> {79let dyn_img: DynamicImage = image::load_from_memory(image).unwrap().resize_exact(80width,81height,82image::imageops::Triangle,83);84let rgb_img: RgbImage = dyn_img.to_rgb8();8586// Get an array of the pixel values87let raw_u8_arr: &[u8] = &rgb_img.as_raw()[..];8889// Create an array to hold the f32 value of those pixels90let bytes_required = raw_u8_arr.len() * 4;91let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];9293// Read the number as a f32 and break it into u8 bytes94for i in 0..raw_u8_arr.len() {95let u8_f32: f32 = raw_u8_arr[i] as f32;96let rgb_iter = i % 3;9798// Normalize the pixel99let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];100101// Convert it to u8 bytes and write it with new shape102let u8_bytes = norm_u8_f32.to_ne_bytes();103for j in 0..4 {104u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];105}106}107108return u8_f32_arr;109}110111fn postprocess(output_tensor: Vec<f32>) -> Vec<f32> {112// Post-Processing requirement: compute softmax to inferencing output113let output_shape = [1, 1000, 1, 1];114let exp_output = Array::from_shape_vec(output_shape, output_tensor)115.unwrap()116.mapv(|x| x.exp());117let sum_exp_output = exp_output.sum_axis(ndarray::Axis(1));118let softmax_output = exp_output / &sum_exp_output;119softmax_output.into_raw_vec()120}121122pub fn 
bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {123let chunks: Vec<&[u8]> = data.chunks(4).collect();124let v: Vec<f32> = chunks125.into_iter()126.map(|c| f32::from_ne_bytes(c.try_into().unwrap()))127.collect();128129v.into_iter().collect()130}131132// A wrapper for class ID and match probabilities.133#[derive(Debug, PartialEq)]134struct InferenceResult(String, f32);135136137