Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/examples/classification-component-onnx/src/main.rs
1692 views
1
#![allow(unused_braces)]
2
use image::ImageReader;
3
use image::{DynamicImage, RgbImage};
4
use ndarray::{Array, Dim};
5
use std::fs;
6
use std::io::BufRead;
7
8
const IMG_PATH: &str = "fixture/images/dog.jpg";
9
10
// Generate Rust bindings for the `ml` WIT world found under `../../wit`;
// this is what brings the `wasi::nn::{graph, tensor}` modules used below
// into scope.
wit_bindgen::generate!({
    path: "../../wit",
    world: "ml",
});
14
15
use self::wasi::nn::{
16
graph::{Graph, GraphBuilder, load, ExecutionTarget, GraphEncoding},
17
tensor::{Tensor, TensorData, TensorDimensions, TensorType},
18
};
19
20
fn main() {
21
// Load the ONNX model - SqueezeNet 1.1-7
22
// Full details: https://github.com/onnx/models/tree/main/vision/classification/squeezenet
23
let model: GraphBuilder = fs::read("fixture/models/squeezenet1.1-7.onnx").unwrap();
24
println!("Read ONNX model, size in bytes: {}", model.len());
25
26
let graph = load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).unwrap();
27
println!("Loaded graph into wasi-nn");
28
29
let exec_context = Graph::init_execution_context(&graph).unwrap();
30
println!("Created wasi-nn execution context.");
31
32
// Load SquezeNet 1000 labels used for classification
33
let labels = fs::read("fixture/labels/squeezenet1.1-7.txt").unwrap();
34
let class_labels: Vec<String> = labels.lines().map(|line| line.unwrap()).collect();
35
println!("Read ONNX Labels, # of labels: {}", class_labels.len());
36
37
// Prepare WASI-NN tensor - Tensor data is always a bytes vector
38
let dimensions: TensorDimensions = vec![1, 3, 224, 224];
39
let data: TensorData = image_to_tensor(IMG_PATH.to_string(), 224, 224);
40
let tensor = Tensor::new(
41
&dimensions,
42
TensorType::Fp32,
43
&data,
44
);
45
let input_tensor = vec!(("data".to_string(), tensor));
46
// Execute the inferencing
47
let output_tensor_vec = exec_context.compute(input_tensor).unwrap();
48
println!("Executed graph inference");
49
50
let output_tensor = output_tensor_vec.iter().find_map(|(tensor_name, tensor)| {
51
if tensor_name == "squeezenet0_flatten0_reshape0" {
52
Some(tensor)
53
} else {
54
None
55
}
56
});
57
let output_data = output_tensor.expect("No output tensor").data();
58
59
println!("Retrieved output data with length: {}", output_data.len());
60
let output_f32 = bytes_to_f32_vec(output_data);
61
62
let output_shape = [1, 1000, 1, 1];
63
let output_tensor = Array::from_shape_vec(output_shape, output_f32).unwrap();
64
65
// Post-Processing requirement: compute softmax to inferencing output
66
let exp_output = output_tensor.mapv(|x| x.exp());
67
let sum_exp_output = exp_output.sum_axis(ndarray::Axis(1));
68
let softmax_output = exp_output / &sum_exp_output;
69
70
let mut sorted = softmax_output
71
.axis_iter(ndarray::Axis(1))
72
.enumerate()
73
.into_iter()
74
.map(|(i, v)| (i, v[Dim([0, 0, 0])]))
75
.collect::<Vec<(_, _)>>();
76
sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
77
78
for (index, probability) in sorted.iter().take(3) {
79
println!(
80
"Index: {} - Probability: {}",
81
class_labels[*index], probability
82
);
83
}
84
}
85
86
/// Reinterpret a byte buffer as a vector of `f32` values.
///
/// Each consecutive 4-byte chunk is decoded as a little-endian `f32` — the
/// byte order wasi-nn tensor data uses on wasm32.
///
/// # Panics
///
/// Panics if `data.len()` is not a multiple of 4, since a trailing partial
/// chunk cannot form a valid `f32`.
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
    assert!(
        data.len() % 4 == 0,
        "byte length {} is not a multiple of 4",
        data.len()
    );
    // `chunks_exact` lets the compiler prove every chunk is 4 bytes long,
    // avoiding the intermediate `Vec<&[u8]>` the previous version built.
    data.chunks_exact(4)
        .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
        .collect()
}
95
96
// Take the image located at 'path', open it, resize it to height x width, and then converts
97
// the pixel precision to FP32. The resulting BGR pixel vector is then returned.
98
fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> {
99
let pixels = ImageReader::open(path).unwrap().decode().unwrap();
100
let dyn_img: DynamicImage = pixels.resize_exact(width, height, image::imageops::Triangle);
101
let bgr_img: RgbImage = dyn_img.to_rgb8();
102
103
// Get an array of the pixel values
104
let raw_u8_arr: &[u8] = &bgr_img.as_raw()[..];
105
106
// Create an array to hold the f32 value of those pixels
107
let bytes_required = raw_u8_arr.len() * 4;
108
let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];
109
110
// Normalizing values for the model
111
let mean = [0.485, 0.456, 0.406];
112
let std = [0.229, 0.224, 0.225];
113
114
// Read the number as a f32 and break it into u8 bytes
115
for i in 0..raw_u8_arr.len() {
116
let u8_f32: f32 = raw_u8_arr[i] as f32;
117
let rgb_iter = i % 3;
118
119
// Normalize the pixel
120
let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];
121
122
// Convert it to u8 bytes and write it with new shape
123
let u8_bytes = norm_u8_f32.to_ne_bytes();
124
for j in 0..4 {
125
u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];
126
}
127
}
128
129
return u8_f32_arr;
130
}
131
132