Path: crates/test-programs/src/nn.rs
//! This module attempts to paper over the differences between the two
//! implementations of wasi-nn: the legacy WITX-based version (`mod witx`) and
//! the up-to-date WIT version (`mod wit`). Since the tests mainly exercise a
//! simple classifier, this exposes a high-level `classify` function to go
//! along with `load`, etc.
//!
//! This module exists solely for convenience--e.g., it reduces test
//! duplication. In the future it can be safely disposed of or altered as more
//! tests are added.

/// Call `wasi-nn` functions from WebAssembly using the canonical ABI of the
/// component model via WIT-based tooling. Used by `bin/nn_wit_*.rs` tests.
pub mod wit {
    use anyhow::{Result, anyhow};
    use std::time::Instant;

    // Generate the wasi-nn bindings based on the `*.wit` files.
    wit_bindgen::generate!({
        path: "../wasi-nn/wit",
        world: "ml",
        default_bindings_module: "test_programs::ml"
    });
    use self::wasi::nn::errors;
    use self::wasi::nn::graph::{self, Graph};
    pub use self::wasi::nn::graph::{ExecutionTarget, GraphEncoding}; // Used by tests.
    use self::wasi::nn::tensor::{Tensor, TensorType};

    /// Load a wasi-nn graph from a set of bytes.
    pub fn load(
        bytes: &[Vec<u8>],
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> Result<Graph> {
        graph::load(bytes, encoding, target).map_err(err_as_anyhow)
    }

    /// Load a wasi-nn graph by name.
    pub fn load_by_name(name: &str) -> Result<Graph> {
        graph::load_by_name(name).map_err(err_as_anyhow)
    }

    /// Run a wasi-nn inference using a simple classifier model (single input,
    /// single output).
    pub fn classify(graph: Graph, input: (&str, Vec<u8>)) -> Result<Vec<f32>> {
        let context = graph.init_execution_context().map_err(err_as_anyhow)?;
        println!("[nn] created wasi-nn execution context with ID: {context:?}");

        // Many classifiers have a single input; currently, this test suite
        // also uses tensors of the same shape, though this is not usually the
        // case.
        let tensor = Tensor::new(&vec![1, 3, 224, 224], TensorType::Fp32, &input.1);
        println!("[nn] input tensor: {} bytes", input.1.len());

        let before = Instant::now();
        let input_tuple = (input.0.to_string(), tensor);
        let output_tensors = context
            .compute(vec![input_tuple])
            .map_err(err_as_anyhow)?;
        println!(
            "[nn] executed graph inference in {} ms",
            before.elapsed().as_millis()
        );

        // Many classifiers emit probabilities as floating point values; here
        // we convert the raw bytes to `f32`, knowing all models used here emit
        // that type.
        let output = &output_tensors[0].1;
        println!(
            "[nn] retrieved output tensor: {} bytes",
            output.data().len()
        );
        let output: Vec<f32> = output
            .data()
            .chunks(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();
        Ok(output)
    }

    fn err_as_anyhow(e: errors::Error) -> anyhow::Error {
        anyhow!("error: {e:?}")
    }
}

/// Call `wasi-nn` functions from WebAssembly using the legacy WITX-based
/// tooling. This older API has been deprecated in favor of the newer WIT-based
/// API but is retained for backwards-compatibility testing--i.e., the
/// `bin/nn_witx_*.rs` tests.
pub mod witx {
    use anyhow::Result;
    use std::time::Instant;
    pub use wasi_nn::{ExecutionTarget, GraphEncoding};
    use wasi_nn::{Graph, GraphBuilder, TensorType};

    /// Load a wasi-nn graph from a set of bytes.
    pub fn load(
        bytes: &[&[u8]],
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> Result<Graph> {
        Ok(GraphBuilder::new(encoding, target).build_from_bytes(bytes)?)
    }

    /// Load a wasi-nn graph by name.
    pub fn load_by_name(
        name: &str,
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> Result<Graph> {
        Ok(GraphBuilder::new(encoding, target).build_from_cache(name)?)
    }

    /// Run a wasi-nn inference using a simple classifier model (single input,
    /// single output).
    pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
        let mut context = graph.init_execution_context()?;
        println!("[nn] created wasi-nn execution context with ID: {context}");

        // Many classifiers have a single input; currently, this test suite
        // also uses tensors of the same shape, though this is not usually the
        // case.
        context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?;
        println!("[nn] set input tensor: {} bytes", tensor.len());

        let before = Instant::now();
        context.compute()?;
        println!(
            "[nn] executed graph inference in {} ms",
            before.elapsed().as_millis()
        );

        // Many classifiers emit probabilities as floating point values; here
        // we convert the raw bytes to `f32`, knowing all models used here emit
        // that type.
        let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::<f32>()];
        let num_bytes = context.get_output(0, &mut output_buffer)?;
        println!("[nn] retrieved output tensor: {num_bytes} bytes");
        let output: Vec<f32> = output_buffer[..num_bytes]
            .chunks(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();
        Ok(output)
    }
}

/// Sort some classification probabilities.
///
/// Many classification models output a buffer of probabilities for each class,
/// placing the match probability for each class at the index for that class
/// (the probability of class `N` is stored at `probabilities[N]`).
pub fn sort_results(probabilities: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = probabilities
        .iter()
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

/// A wrapper for a class ID and its match probability.
#[derive(Debug, PartialEq)]
pub struct InferenceResult(usize, f32);
impl InferenceResult {
    pub fn class_id(&self) -> usize {
        self.0
    }
    pub fn probability(&self) -> f32 {
        self.1
    }
}
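
// --- Editorial examples below: sketches of how the helpers above fit
// together, not part of the upstream test suite. ---

// A hypothetical `bin/nn_wit_*.rs`-style test using the WIT API. The fixture
// paths, the tensor name "input", and the OpenVINO encoding / CPU target are
// illustrative assumptions, not fixed by this module.
#[allow(dead_code)]
fn example_wit_classification() -> anyhow::Result<()> {
    // Hypothetical fixture locations; real tests supply their own model bytes.
    let model = std::fs::read("fixture/model.xml")?;
    let weights = std::fs::read("fixture/model.bin")?;
    let graph = wit::load(
        &[model, weights],
        wit::GraphEncoding::Openvino,
        wit::ExecutionTarget::Cpu,
    )?;

    // Classify a raw input tensor and report the most probable class.
    let tensor = std::fs::read("fixture/tensor.bgr")?;
    let probabilities = wit::classify(graph, ("input", tensor))?;
    let results = sort_results(&probabilities);
    let top = &results[0];
    println!("[nn] top class: {} ({:.2})", top.class_id(), top.probability());
    Ok(())
}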
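
// The legacy WITX path has the same overall shape, but takes borrowed byte
// slices and addresses the single input tensor by index rather than by name;
// again a sketch with hypothetical fixture paths.
#[allow(dead_code)]
fn example_witx_classification() -> anyhow::Result<()> {
    let model = std::fs::read("fixture/model.xml")?;
    let weights = std::fs::read("fixture/model.bin")?;
    let graph = witx::load(
        &[model.as_slice(), weights.as_slice()],
        witx::GraphEncoding::Openvino,
        witx::ExecutionTarget::CPU,
    )?;
    let tensor = std::fs::read("fixture/tensor.bgr")?;
    let probabilities = witx::classify(graph, tensor)?;
    let results = sort_results(&probabilities);
    let top = &results[0];
    println!("[nn] top class: {} ({:.2})", top.class_id(), top.probability());
    Ok(())
}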
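
// A minimal sanity check for `sort_results` (an editorial addition, not an
// upstream test): class `N`'s probability is read from `probabilities[N]`, and
// results come back ordered from most to least probable.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sort_results_orders_classes_by_descending_probability() {
        // Class 2 is most probable, then class 0, then class 1.
        let probabilities = [0.3, 0.1, 0.6];
        let results = sort_results(&probabilities);
        assert_eq!(results[0], InferenceResult(2, 0.6));
        assert_eq!(results[1], InferenceResult(0, 0.3));
        assert_eq!(results[2], InferenceResult(1, 0.1));
    }
}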