Path: blob/main/crates/wasi-nn/examples/classification-example-pytorch/src/main.rs
use anyhow::{Error, Result};
use image::{DynamicImage, RgbImage};
use std::fs;
use wasi_nn::{self, ExecutionTarget, GraphBuilder, GraphEncoding};

pub fn main() -> Result<(), Error> {
    // Read the model file (Resnet18).
    let model = fs::read("fixture/model.pt")?;
    let graph = GraphBuilder::new(GraphEncoding::Pytorch, ExecutionTarget::CPU)
        .build_from_bytes(&[&model])?;

    let mut context = graph.init_execution_context()?;

    let image = fs::read("fixture/kitten.png")?;
    // Preprocessing: normalize the data based on the model requirements, see
    // https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet#preprocessing
    let tensor_data = preprocess(
        image.as_slice(),
        224,
        224,
        &[0.485, 0.456, 0.406],
        &[0.229, 0.224, 0.225],
    );
    let precision = wasi_nn::TensorType::F32;
    // The Resnet18 model input is NCHW.
    let shape = &[1, 3, 224, 224];
    // Set the input tensor. PyTorch models do not use named ports, so the index is 0 here.
    // Tensors are passed to the model, and the model's forward method processes them.
    context.set_input(0, precision, shape, &tensor_data)?;
    context.compute()?;
    let mut output_buffer = vec![0f32; 1000];
    context.get_output(0, &mut output_buffer[..])?;
    let result = softmax(output_buffer);
    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&result)[..5]
    );
    Ok(())
}

// Resizes the image to `height` x `width`, converts the pixels to f32, and normalizes them
// with the given mean and std. The normalized RGB values are returned as raw f32 bytes in
// planar (CHW) order.
fn preprocess(image: &[u8], height: u32, width: u32, mean: &[f32], std: &[f32]) -> Vec<u8> {
    let dyn_img: DynamicImage = image::load_from_memory(image).unwrap().resize_exact(
        width,
        height,
        image::imageops::Triangle,
    );
    let rgb_img: RgbImage = dyn_img.to_rgb8();

    // Get an array of the pixel values (interleaved HWC order).
    let raw_u8_arr: &[u8] = &rgb_img.as_raw()[..];

    // Create an array to hold the f32 values of those pixels.
    let bytes_required = raw_u8_arr.len() * 4;
    let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];

    // Read each pixel as an f32 and break it into u8 bytes.
    for i in 0..raw_u8_arr.len() {
        let u8_f32: f32 = raw_u8_arr[i] as f32;
        let rgb_iter = i % 3;

        // Normalize the pixel.
        let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];

        // Convert it to u8 bytes and write it with the new shape: the index expression
        // places each value into its channel plane, turning interleaved HWC into planar CHW.
        let u8_bytes = norm_u8_f32.to_ne_bytes();
        for j in 0..4 {
            u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];
        }
    }
    u8_f32_arr
}

fn softmax(output_tensor: Vec<f32>) -> Vec<f32> {
    let max_val = output_tensor
        .iter()
        .cloned()
        .fold(f32::NEG_INFINITY, f32::max);

    // Compute the exponential of each element with max_val subtracted for numerical stability.
    let exps: Vec<f32> = output_tensor.iter().map(|&x| (x - max_val).exp()).collect();

    // Compute the sum of the exponentials.
    let sum_exps: f32 = exps.iter().sum();

    // Normalize each element to get the probabilities.
    exps.iter().map(|&exp| exp / sum_exps).collect()
}

fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// (class index, probability)
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
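
// For reference, the HWC -> CHW index arithmetic in `preprocess` can also be written with
// explicit plane offsets. This is a minimal equivalent sketch, not part of the original
// example; the name `preprocess_planar` and its signature are illustrative assumptions.
#[allow(dead_code)]
fn preprocess_planar(rgb_hwc: &[u8], mean: &[f32], std: &[f32]) -> Vec<u8> {
    let num_pixels = rgb_hwc.len() / 3;
    let mut out = vec![0u8; rgb_hwc.len() * 4];
    for (i, &px) in rgb_hwc.iter().enumerate() {
        let channel = i % 3; // which RGB channel this byte belongs to
        let pixel = i / 3; // index of the pixel within its channel plane
        let value = (px as f32 / 255.0 - mean[channel]) / std[channel];
        // Each f32 occupies 4 bytes at (channel plane offset + pixel offset).
        let offset = (channel * num_pixels + pixel) * 4;
        out[offset..offset + 4].copy_from_slice(&value.to_ne_bytes());
    }
    out
}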
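
// A hypothetical sanity check for the pure helpers above; it is not part of the original
// example, and only exercises `softmax` and `sort_results`, which need no wasi-nn runtime.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn softmax_is_a_probability_distribution() {
        let probs = softmax(vec![1.0, 2.0, 3.0]);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        // Larger logits map to larger probabilities.
        assert!(probs[2] > probs[1] && probs[1] > probs[0]);
    }

    #[test]
    fn sort_results_orders_by_descending_probability() {
        let sorted = sort_results(&[0.1, 0.7, 0.2]);
        assert_eq!(sorted[0], InferenceResult(1, 0.7));
        assert_eq!(sorted[1], InferenceResult(2, 0.2));
        assert_eq!(sorted[2], InferenceResult(0, 0.1));
    }
}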