Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/examples/classification-example-pytorch/src/main.rs
1692 views
1
use anyhow::{Error, Result};
2
use image::{DynamicImage, RgbImage};
3
use std::fs;
4
use wasi_nn::{self, ExecutionTarget, GraphBuilder, GraphEncoding};
5
6
pub fn main() -> Result<(), Error> {
7
// Read the model file (Resnet18)
8
let model = fs::read("fixture/model.pt")?;
9
let graph = GraphBuilder::new(GraphEncoding::Pytorch, ExecutionTarget::CPU)
10
.build_from_bytes(&[&model])?;
11
12
let mut context = graph.init_execution_context()?;
13
14
let image = fs::read("fixture/kitten.png")?;
15
// Preprocessing. Normalize data based on model requirements https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet#preprocessing
16
let tensor_data = preprocess(
17
image.as_slice(),
18
224,
19
224,
20
&[0.485, 0.456, 0.406],
21
&[0.229, 0.224, 0.225],
22
);
23
let precision = wasi_nn::TensorType::F32;
24
// Resnet18 model input is NCHW
25
let shape = &[1, 3, 224, 224];
26
// Set the input tensor. PyTorch models do not use ports, so it is set to 0 here.
27
// Tensors are passed to the model, and the model's forward method processes these tensors.
28
context.set_input(0, precision, shape, &tensor_data)?;
29
context.compute()?;
30
let mut output_buffer = vec![0f32; 1000];
31
context.get_output(0, &mut output_buffer[..])?;
32
let result = softmax(output_buffer);
33
println!(
34
"Found results, sorted top 5: {:?}",
35
&sort_results(&result)[..5]
36
);
37
Ok(())
38
}
39
40
// Resize image to height x width, and then converts the pixel precision to FP32, normalize with
41
// given mean and std. The resulting RGB pixel vector is then returned.
42
fn preprocess(image: &[u8], height: u32, width: u32, mean: &[f32], std: &[f32]) -> Vec<u8> {
43
let dyn_img: DynamicImage = image::load_from_memory(image).unwrap().resize_exact(
44
width,
45
height,
46
image::imageops::Triangle,
47
);
48
let rgb_img: RgbImage = dyn_img.to_rgb8();
49
50
// Get an array of the pixel values
51
let raw_u8_arr: &[u8] = &rgb_img.as_raw()[..];
52
53
// Create an array to hold the f32 value of those pixels
54
let bytes_required = raw_u8_arr.len() * 4;
55
let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];
56
57
// Read the number as a f32 and break it into u8 bytes
58
for i in 0..raw_u8_arr.len() {
59
let u8_f32: f32 = raw_u8_arr[i] as f32;
60
let rgb_iter = i % 3;
61
62
// Normalize the pixel
63
let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];
64
65
// Convert it to u8 bytes and write it with new shape
66
let u8_bytes = norm_u8_f32.to_ne_bytes();
67
for j in 0..4 {
68
u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];
69
}
70
}
71
u8_f32_arr
72
}
73
74
// Converts raw logits into a probability distribution using the softmax
// function. The maximum logit is subtracted before exponentiation for
// numerical stability (avoids overflow in `exp`).
fn softmax(output_tensor: Vec<f32>) -> Vec<f32> {
    // Largest logit (NEG_INFINITY for an empty input).
    let mut max_val = f32::NEG_INFINITY;
    for &v in &output_tensor {
        max_val = max_val.max(v);
    }

    // Exponentiate each shifted logit.
    let mut probs: Vec<f32> = Vec::with_capacity(output_tensor.len());
    for &v in &output_tensor {
        probs.push((v - max_val).exp());
    }

    // Normalize in place so the entries sum to 1.
    let total: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= total;
    }
    probs
}
89
90
// Pairs every class index with its probability and returns the pairs
// sorted by probability, highest first.
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .enumerate()
        .map(|(class, prob)| InferenceResult(class, *prob))
        .collect();
    // `total_cmp` implements the IEEE 754 totalOrder predicate, which is a
    // total order over f32, so the sort cannot panic. The previous
    // `partial_cmp(..).unwrap()` would panic if any score were NaN.
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results
}

// A single classification result: (class index, probability).
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
102
103