GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/examples/classification-example-winml/src/main.rs
use image::{DynamicImage, RgbImage};
use ndarray::Array;
use std::{fs, time::Instant};

pub fn main() {
    // Load model from a file.
    let graph =
        wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Onnx, wasi_nn::ExecutionTarget::CPU)
            .build_from_files(["fixture/mobilenet.onnx"])
            .unwrap();

    let mut context = graph.init_execution_context().unwrap();
    println!("Created an execution context.");

    // Read image from file and convert it to tensor data.
    let image_data = fs::read("fixture/kitten.png").unwrap();

    // Preprocessing: normalize the data according to the model's requirements, see
    // https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet#preprocessing
    let tensor_data = preprocess(
        image_data.as_slice(),
        224,
        224,
        &[0.485, 0.456, 0.406],
        &[0.229, 0.224, 0.225],
    );
    println!("Read input tensor, size in bytes: {}", tensor_data.len());

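    // The input tensor is f32 data in NCHW layout: [batch, channels, height, width] = [1, 3, 224, 224].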
    context
        .set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &tensor_data)
        .unwrap();

    // Execute the inference.
    let before_compute = Instant::now();
    context.compute().unwrap();
    println!(
        "Executed graph inference, took {} ms.",
        before_compute.elapsed().as_millis()
    );

    // Retrieve the output.
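    // The model returns one raw score per class, so allocate a 1,000-element buffer to hold them.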
    let mut output_buffer = vec![0f32; 1000];
    context.get_output(0, &mut output_buffer[..]).unwrap();

    // Postprocessing: compute the softmax probability scores.
    let result = postprocess(output_buffer);

    // Load labels for classification
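    // Each line of synset.txt starts with a class identifier followed by the human-readable label,
    // so keep everything after the first whitespace-separated word.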
    let labels_file = fs::read("fixture/synset.txt").unwrap();
    let labels_str = String::from_utf8(labels_file).unwrap();
    let labels: Vec<String> = labels_str
        .lines()
        .map(|line| {
            let words: Vec<&str> = line.split_whitespace().collect();
            words[1..].join(" ")
        })
        .collect();

    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&result, &labels)[..5]
    )
}

// Sort the buffer of probabilities. The graph places the match probability for each class at the
// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
// to a wrapping InferenceResult and sort the results.
fn sort_results(buffer: &[f32], labels: &[String]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .enumerate()
        .map(|(c, p)| InferenceResult(labels[c].clone(), *p))
        .collect();
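    // Sort in descending order of probability (highest score first).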
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// Resize the image to height x width, convert the pixel values to f32, and normalize them with the
// given mean and std. The resulting planar RGB data is returned as a byte vector of f32 values.
fn preprocess(image: &[u8], height: u32, width: u32, mean: &[f32], std: &[f32]) -> Vec<u8> {
    let dyn_img: DynamicImage = image::load_from_memory(image).unwrap().resize_exact(
        width,
        height,
        image::imageops::Triangle,
    );
    let rgb_img: RgbImage = dyn_img.to_rgb8();

    // Get an array of the pixel values
    let raw_u8_arr: &[u8] = &rgb_img.as_raw()[..];

    // Create an array to hold the f32 value of those pixels
    let bytes_required = raw_u8_arr.len() * 4;
    let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];

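    // The RGB pixels are stored interleaved (R, G, B, R, G, B, ...); the loop below writes each
    // normalized value into a separate per-channel plane so the output tensor is in CHW order.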
    // Read the number as a f32 and break it into u8 bytes
    for i in 0..raw_u8_arr.len() {
        let u8_f32: f32 = raw_u8_arr[i] as f32;
        let rgb_iter = i % 3;

        // Normalize the pixel
        let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];

        // Convert it to u8 bytes and write it with new shape
        let u8_bytes = norm_u8_f32.to_ne_bytes();
        for j in 0..4 {
            u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];
        }
    }

    u8_f32_arr
}

fn postprocess(output_tensor: Vec<f32>) -> Vec<f32> {
    // Post-processing requirement: apply softmax to the raw inference output.
    let output_shape = [1, 1000, 1, 1];
    let exp_output = Array::from_shape_vec(output_shape, output_tensor)
        .unwrap()
        .mapv(|x| x.exp());
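    // Sum the exponentials over the class axis (Axis(1)) and divide via broadcasting to turn the
    // raw scores into a probability distribution.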
    let sum_exp_output = exp_output.sum_axis(ndarray::Axis(1));
    let softmax_output = exp_output / &sum_exp_output;
    softmax_output.into_raw_vec()
}

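// Reinterpret a raw byte buffer as f32 values (four native-endian bytes per value).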
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
    data.chunks(4)
        .map(|c| f32::from_ne_bytes(c.try_into().unwrap()))
        .collect()
}

// A wrapper for a class label and its match probability.
#[derive(Debug, PartialEq)]
struct InferenceResult(String, f32);