// Path: blob/main/crates/wasi-nn/examples/classification-component-onnx/src/main.rs
// 1692 views
#![allow(unused_braces)]1use image::ImageReader;2use image::{DynamicImage, RgbImage};3use ndarray::{Array, Dim};4use std::fs;5use std::io::BufRead;67const IMG_PATH: &str = "fixture/images/dog.jpg";89wit_bindgen::generate!({10path: "../../wit",11world: "ml",12});1314use self::wasi::nn::{15graph::{Graph, GraphBuilder, load, ExecutionTarget, GraphEncoding},16tensor::{Tensor, TensorData, TensorDimensions, TensorType},17};1819fn main() {20// Load the ONNX model - SqueezeNet 1.1-721// Full details: https://github.com/onnx/models/tree/main/vision/classification/squeezenet22let model: GraphBuilder = fs::read("fixture/models/squeezenet1.1-7.onnx").unwrap();23println!("Read ONNX model, size in bytes: {}", model.len());2425let graph = load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).unwrap();26println!("Loaded graph into wasi-nn");2728let exec_context = Graph::init_execution_context(&graph).unwrap();29println!("Created wasi-nn execution context.");3031// Load SquezeNet 1000 labels used for classification32let labels = fs::read("fixture/labels/squeezenet1.1-7.txt").unwrap();33let class_labels: Vec<String> = labels.lines().map(|line| line.unwrap()).collect();34println!("Read ONNX Labels, # of labels: {}", class_labels.len());3536// Prepare WASI-NN tensor - Tensor data is always a bytes vector37let dimensions: TensorDimensions = vec![1, 3, 224, 224];38let data: TensorData = image_to_tensor(IMG_PATH.to_string(), 224, 224);39let tensor = Tensor::new(40&dimensions,41TensorType::Fp32,42&data,43);44let input_tensor = vec!(("data".to_string(), tensor));45// Execute the inferencing46let output_tensor_vec = exec_context.compute(input_tensor).unwrap();47println!("Executed graph inference");4849let output_tensor = output_tensor_vec.iter().find_map(|(tensor_name, tensor)| {50if tensor_name == "squeezenet0_flatten0_reshape0" {51Some(tensor)52} else {53None54}55});56let output_data = output_tensor.expect("No output tensor").data();5758println!("Retrieved output data with length: 
{}", output_data.len());59let output_f32 = bytes_to_f32_vec(output_data);6061let output_shape = [1, 1000, 1, 1];62let output_tensor = Array::from_shape_vec(output_shape, output_f32).unwrap();6364// Post-Processing requirement: compute softmax to inferencing output65let exp_output = output_tensor.mapv(|x| x.exp());66let sum_exp_output = exp_output.sum_axis(ndarray::Axis(1));67let softmax_output = exp_output / &sum_exp_output;6869let mut sorted = softmax_output70.axis_iter(ndarray::Axis(1))71.enumerate()72.into_iter()73.map(|(i, v)| (i, v[Dim([0, 0, 0])]))74.collect::<Vec<(_, _)>>();75sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());7677for (index, probability) in sorted.iter().take(3) {78println!(79"Index: {} - Probability: {}",80class_labels[*index], probability81);82}83}8485pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {86let chunks: Vec<&[u8]> = data.chunks(4).collect();87let v: Vec<f32> = chunks88.into_iter()89.map(|c| f32::from_le_bytes(c.try_into().unwrap()))90.collect();9192v.into_iter().collect()93}9495// Take the image located at 'path', open it, resize it to height x width, and then converts96// the pixel precision to FP32. 
The resulting BGR pixel vector is then returned.97fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> {98let pixels = ImageReader::open(path).unwrap().decode().unwrap();99let dyn_img: DynamicImage = pixels.resize_exact(width, height, image::imageops::Triangle);100let bgr_img: RgbImage = dyn_img.to_rgb8();101102// Get an array of the pixel values103let raw_u8_arr: &[u8] = &bgr_img.as_raw()[..];104105// Create an array to hold the f32 value of those pixels106let bytes_required = raw_u8_arr.len() * 4;107let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];108109// Normalizing values for the model110let mean = [0.485, 0.456, 0.406];111let std = [0.229, 0.224, 0.225];112113// Read the number as a f32 and break it into u8 bytes114for i in 0..raw_u8_arr.len() {115let u8_f32: f32 = raw_u8_arr[i] as f32;116let rgb_iter = i % 3;117118// Normalize the pixel119let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];120121// Convert it to u8 bytes and write it with new shape122let u8_bytes = norm_u8_f32.to_ne_bytes();123for j in 0..4 {124u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];125}126}127128return u8_f32_arr;129}130131132