Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/examples/classification-component-onnx/src/main.rs
3101 views
1
#![allow(unused_braces)]
2
use image::ImageReader;
3
use image::{DynamicImage, RgbImage};
4
use ndarray::{Array, Dim};
5
use std::fs;
6
use std::io::BufRead;
7
8
const IMG_PATH: &str = "fixture/images/dog.jpg";
9
10
wit_bindgen::generate!({
11
path: "../../wit",
12
world: "ml",
13
});
14
15
use self::wasi::nn::{
16
graph::{Graph, GraphBuilder, load, ExecutionTarget, GraphEncoding},
17
tensor::{Tensor, TensorData, TensorDimensions, TensorType},
18
};
19
20
/// Determine execution target from command-line argument
21
/// Usage: wasm_module [cpu|gpu|cuda]
22
fn get_execution_target() -> ExecutionTarget {
23
let args: Vec<String> = std::env::args().collect();
24
25
// First argument (index 0) is the program name, second (index 1) is the target
26
// Ignore any arguments after index 1
27
if args.len() >= 2 {
28
match args[1].to_lowercase().as_str() {
29
"gpu" | "cuda" => {
30
println!("Using GPU (CUDA) execution target from argument");
31
return ExecutionTarget::Gpu;
32
}
33
"cpu" => {
34
println!("Using CPU execution target from argument");
35
return ExecutionTarget::Cpu;
36
}
37
_ => {
38
println!("Unknown execution target '{}', defaulting to CPU", args[1]);
39
}
40
}
41
} else {
42
println!("No execution target specified, defaulting to CPU");
43
println!("Usage: <program> [cpu|gpu|cuda]");
44
}
45
46
ExecutionTarget::Cpu
47
}
48
49
fn main() {
50
// Load the ONNX model - SqueezeNet 1.1-7
51
// Full details: https://github.com/onnx/models/tree/main/vision/classification/squeezenet
52
let model: GraphBuilder = fs::read("fixture/models/squeezenet1.1-7.onnx").unwrap();
53
println!("Read ONNX model, size in bytes: {}", model.len());
54
55
// Determine execution target
56
let execution_target = get_execution_target();
57
58
let graph = load(&[model], GraphEncoding::Onnx, execution_target).unwrap();
59
println!("Loaded graph into wasi-nn with {:?} target", execution_target);
60
61
let exec_context = Graph::init_execution_context(&graph).unwrap();
62
println!("Created wasi-nn execution context.");
63
64
// Load SquezeNet 1000 labels used for classification
65
let labels = fs::read("fixture/labels/squeezenet1.1-7.txt").unwrap();
66
let class_labels: Vec<String> = labels.lines().map(|line| line.unwrap()).collect();
67
println!("Read ONNX Labels, # of labels: {}", class_labels.len());
68
69
// Prepare WASI-NN tensor - Tensor data is always a bytes vector
70
let dimensions: TensorDimensions = vec![1, 3, 224, 224];
71
let data: TensorData = image_to_tensor(IMG_PATH.to_string(), 224, 224);
72
let tensor = Tensor::new(
73
&dimensions,
74
TensorType::Fp32,
75
&data,
76
);
77
let input_tensor = vec!(("data".to_string(), tensor));
78
// Execute the inferencing
79
let output_tensor_vec = exec_context.compute(input_tensor).unwrap();
80
println!("Executed graph inference");
81
82
let output_tensor = output_tensor_vec.iter().find_map(|(tensor_name, tensor)| {
83
if tensor_name == "squeezenet0_flatten0_reshape0" {
84
Some(tensor)
85
} else {
86
None
87
}
88
});
89
let output_data = output_tensor.expect("No output tensor").data();
90
91
println!("Retrieved output data with length: {}", output_data.len());
92
let output_f32 = bytes_to_f32_vec(output_data);
93
94
let output_shape = [1, 1000, 1, 1];
95
let output_tensor = Array::from_shape_vec(output_shape, output_f32).unwrap();
96
97
// Post-Processing requirement: compute softmax to inferencing output
98
let exp_output = output_tensor.mapv(|x| x.exp());
99
let sum_exp_output = exp_output.sum_axis(ndarray::Axis(1));
100
let softmax_output = exp_output / &sum_exp_output;
101
102
let mut sorted = softmax_output
103
.axis_iter(ndarray::Axis(1))
104
.enumerate()
105
.into_iter()
106
.map(|(i, v)| (i, v[Dim([0, 0, 0])]))
107
.collect::<Vec<(_, _)>>();
108
sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
109
110
for (index, probability) in sorted.iter().take(3) {
111
println!(
112
"Index: {} - Probability: {}",
113
class_labels[*index], probability
114
);
115
}
116
}
117
118
/// Reinterpret a byte buffer as a vector of little-endian `f32` values.
///
/// The wasi-nn output tensor arrives as raw bytes; each group of 4 bytes is
/// one IEEE-754 `f32` in little-endian order.
///
/// # Panics
/// Panics if `data.len()` is not a multiple of 4, since a trailing partial
/// chunk cannot form a complete `f32`.
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
    assert!(
        data.len() % 4 == 0,
        "byte length {} is not a multiple of 4",
        data.len()
    );
    // chunks_exact guarantees every chunk is exactly 4 bytes, so the
    // try_into cannot fail; it also avoids the original's intermediate
    // Vec<&[u8]> and the redundant into_iter().collect() round-trip.
    data.chunks_exact(4)
        .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
        .collect()
}
// Take the image located at 'path', open it, resize it to height x width, and then converts
129
// the pixel precision to FP32. The resulting BGR pixel vector is then returned.
130
fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> {
131
let pixels = ImageReader::open(path).unwrap().decode().unwrap();
132
let dyn_img: DynamicImage = pixels.resize_exact(width, height, image::imageops::Triangle);
133
let bgr_img: RgbImage = dyn_img.to_rgb8();
134
135
// Get an array of the pixel values
136
let raw_u8_arr: &[u8] = &bgr_img.as_raw()[..];
137
138
// Create an array to hold the f32 value of those pixels
139
let bytes_required = raw_u8_arr.len() * 4;
140
let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];
141
142
// Normalizing values for the model
143
let mean = [0.485, 0.456, 0.406];
144
let std = [0.229, 0.224, 0.225];
145
146
// Read the number as a f32 and break it into u8 bytes
147
for i in 0..raw_u8_arr.len() {
148
let u8_f32: f32 = raw_u8_arr[i] as f32;
149
let rgb_iter = i % 3;
150
151
// Normalize the pixel
152
let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];
153
154
// Convert it to u8 bytes and write it with new shape
155
let u8_bytes = norm_u8_f32.to_ne_bytes();
156
for j in 0..4 {
157
u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];
158
}
159
}
160
161
return u8_f32_arr;
162
}
163
164