Path: blob/main/crates/wasi-nn/examples/classification-component-onnx/src/main.rs
#![allow(unused_braces)]
use image::ImageReader;
use image::{DynamicImage, RgbImage};
use ndarray::{Array, Dim};
use std::fs;
use std::io::BufRead;

const IMG_PATH: &str = "fixture/images/dog.jpg";

wit_bindgen::generate!({
    path: "../../wit",
    world: "ml",
});

use self::wasi::nn::{
    graph::{Graph, GraphBuilder, load, ExecutionTarget, GraphEncoding},
    tensor::{Tensor, TensorData, TensorDimensions, TensorType},
};

/// Determine the execution target from a command-line argument.
/// Usage: wasm_module [cpu|gpu|cuda]
fn get_execution_target() -> ExecutionTarget {
    let args: Vec<String> = std::env::args().collect();

    // The first argument (index 0) is the program name, the second (index 1)
    // is the target; any arguments after index 1 are ignored.
    if args.len() >= 2 {
        match args[1].to_lowercase().as_str() {
            "gpu" | "cuda" => {
                println!("Using GPU (CUDA) execution target from argument");
                return ExecutionTarget::Gpu;
            }
            "cpu" => {
                println!("Using CPU execution target from argument");
                return ExecutionTarget::Cpu;
            }
            _ => {
                println!("Unknown execution target '{}', defaulting to CPU", args[1]);
            }
        }
    } else {
        println!("No execution target specified, defaulting to CPU");
        println!("Usage: <program> [cpu|gpu|cuda]");
    }

    ExecutionTarget::Cpu
}

fn main() {
    // Load the ONNX model - SqueezeNet 1.1-7.
    // Full details: https://github.com/onnx/models/tree/main/vision/classification/squeezenet
    let model: GraphBuilder = fs::read("fixture/models/squeezenet1.1-7.onnx").unwrap();
    println!("Read ONNX model, size in bytes: {}", model.len());

    // Determine the execution target.
    let execution_target = get_execution_target();

    let graph = load(&[model], GraphEncoding::Onnx, execution_target).unwrap();
    println!("Loaded graph into wasi-nn with {:?} target", execution_target);

    let exec_context = Graph::init_execution_context(&graph).unwrap();
    println!("Created wasi-nn execution context.");

    // Load the SqueezeNet 1000 labels used for classification.
    let labels = fs::read("fixture/labels/squeezenet1.1-7.txt").unwrap();
    let class_labels: Vec<String> = labels.lines().map(|line| line.unwrap()).collect();
    println!("Read ONNX labels, # of labels: {}", class_labels.len());

    // Prepare the WASI-NN tensor - tensor data is always a byte vector.
    let dimensions: TensorDimensions = vec![1, 3, 224, 224];
    let data: TensorData = image_to_tensor(IMG_PATH.to_string(), 224, 224);
    let tensor = Tensor::new(&dimensions, TensorType::Fp32, &data);
    let input_tensor = vec![("data".to_string(), tensor)];

    // Execute the inference.
    let output_tensor_vec = exec_context.compute(input_tensor).unwrap();
    println!("Executed graph inference");

    let output_tensor = output_tensor_vec.iter().find_map(|(tensor_name, tensor)| {
        if tensor_name == "squeezenet0_flatten0_reshape0" {
            Some(tensor)
        } else {
            None
        }
    });
    let output_data = output_tensor.expect("No output tensor").data();

    println!("Retrieved output data with length: {}", output_data.len());
    let output_f32 = bytes_to_f32_vec(output_data);

    let output_shape = [1, 1000, 1, 1];
    let output_tensor = Array::from_shape_vec(output_shape, output_f32).unwrap();

    // Post-processing: apply softmax to the inference output to turn the raw
    // scores into probabilities.
    let exp_output = output_tensor.mapv(|x| x.exp());
    let sum_exp_output = exp_output.sum_axis(ndarray::Axis(1));
    let softmax_output = exp_output / &sum_exp_output;

    // Pair each class index with its probability, then sort descending.
    let mut sorted = softmax_output
        .axis_iter(ndarray::Axis(1))
        .enumerate()
        .map(|(i, v)| (i, v[Dim([0, 0, 0])]))
        .collect::<Vec<(_, _)>>();
    sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    for (index, probability) in sorted.iter().take(3) {
        println!(
            "Label: {} - Probability: {}",
            class_labels[*index], probability
        );
    }
}
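
// The ndarray softmax above relies on broadcasting: `sum_exp_output` has
// shape [1, 1, 1] after summing over axis 1, and dividing the [1, 1000, 1, 1]
// array by it normalizes every class score. Below is a minimal standalone
// sketch of the same computation over a flat slice, using only std. The
// `softmax_flat` helper is an illustrative addition, not part of the original
// example, and is not called by `main`.
#[allow(dead_code)]
fn softmax_flat(scores: &[f32]) -> Vec<f32> {
    // Subtracting the maximum first is the numerically stable formulation;
    // the ndarray version above exponentiates the raw scores directly.
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}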
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
    // Reassemble each 4-byte chunk into a little-endian f32.
    data.chunks(4)
        .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
        .collect()
}

// Take the image located at `path`, open it, resize it to `height` x `width`,
// and convert the pixel precision to FP32. The resulting RGB pixel vector is
// returned in planar (NCHW) byte layout, normalized with the model's mean and
// standard deviation.
fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> {
    let pixels = ImageReader::open(path).unwrap().decode().unwrap();
    let dyn_img: DynamicImage = pixels.resize_exact(width, height, image::imageops::Triangle);
    let rgb_img: RgbImage = dyn_img.to_rgb8();

    // Get an array of the pixel values.
    let raw_u8_arr: &[u8] = &rgb_img.as_raw()[..];

    // Create an array to hold the f32 bytes of those pixels.
    let bytes_required = raw_u8_arr.len() * 4;
    let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];

    // Per-channel normalization values for the model.
    let mean = [0.485, 0.456, 0.406];
    let std = [0.229, 0.224, 0.225];

    // Read each pixel value as an f32, normalize it, and write its bytes out
    // in channel-planar order (all R values, then all G, then all B).
    for i in 0..raw_u8_arr.len() {
        let u8_f32: f32 = raw_u8_arr[i] as f32;
        let rgb_iter = i % 3;

        // Normalize the pixel value.
        let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];

        // Convert it to little-endian bytes (matching `bytes_to_f32_vec`
        // above) and write it at its transposed, channel-planar position.
        let u8_bytes = norm_u8_f32.to_le_bytes();
        for j in 0..4 {
            u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];
        }
    }

    u8_f32_arr
}
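
// A small test module sketching how the helpers above can be checked in
// isolation. These tests are illustrative additions, not part of the original
// example, and assume the crate compiles for a test target; they exercise
// only the byte conversion and the `softmax_flat` sketch, so no wasi-nn
// runtime is required.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bytes_round_trip_to_f32() {
        // Encode known f32 values as little-endian bytes and decode them back.
        let values = [0.0f32, 1.5, -2.25];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        assert_eq!(bytes_to_f32_vec(bytes), values);
    }

    #[test]
    fn softmax_probabilities_sum_to_one() {
        // Softmax output must be a probability distribution.
        let probs = softmax_flat(&[1.0, 2.0, 3.0]);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }
}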