Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/src/backend/onnx.rs
2459 views
1
//! Implements a `wasi-nn` [`BackendInner`] using ONNX via the `ort` crate.
2
3
use super::{
4
BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, NamedTensor,
5
};
6
use crate::backend::{Id, read};
7
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
8
use crate::{ExecutionContext, Graph};
9
use anyhow::Context;
10
use ort::{GraphOptimizationLevel, Session, inputs};
11
use std::path::Path;
12
use std::sync::{Arc, Mutex};
13
14
/// The ONNX backend's top-level state. It is a field-less unit struct: all
/// per-model state lives in [`OnnxGraph`] / `OnnxExecutionContext`.
#[derive(Default)]
pub struct OnnxBackend();
// SAFETY: `OnnxBackend` has no fields, so there is no data to race on.
// NOTE(review): a field-less struct is automatically `Send + Sync`; these
// manual unsafe impls look redundant -- confirm whether they can be removed.
unsafe impl Send for OnnxBackend {}
unsafe impl Sync for OnnxBackend {}
18
19
impl BackendInner for OnnxBackend {
20
fn encoding(&self) -> GraphEncoding {
21
GraphEncoding::Onnx
22
}
23
24
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
25
if builders.len() != 1 {
26
return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());
27
}
28
29
let session = Session::builder()?
30
.with_optimization_level(GraphOptimizationLevel::Level3)?
31
.commit_from_memory(builders[0])?;
32
33
let box_: Box<dyn BackendGraph> =
34
Box::new(OnnxGraph(Arc::new(Mutex::new(session)), target));
35
Ok(box_.into())
36
}
37
38
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {
39
Some(self)
40
}
41
}
42
43
impl BackendFromDir for OnnxBackend {
44
fn load_from_dir(
45
&mut self,
46
path: &Path,
47
target: ExecutionTarget,
48
) -> Result<Graph, BackendError> {
49
let model = read(&path.join("model.onnx"))?;
50
self.load(&[&model], target)
51
}
52
}
53
54
/// A loaded ONNX graph: the shared `ort` session plus the requested execution
/// target. The target field is currently unread (hence the `dead_code`
/// allowance) -- presumably kept for future device selection; confirm.
struct OnnxGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget);
// SAFETY: NOTE(review): these impls assert that sharing the `ort` `Session`
// (behind `Arc<Mutex<..>>`) across threads is sound -- confirm against `ort`'s
// documented thread-safety guarantees for `Session`.
unsafe impl Send for OnnxGraph {}
unsafe impl Sync for OnnxGraph {}
57
58
impl BackendGraph for OnnxGraph {
59
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
60
let session = self.0.lock().unwrap();
61
// We need to hold on to the names of the inputs in order for
62
// `set_input` to work with both indexes and names. Having the
63
// dimensions and type around is useful for validation but could be
64
// retrieved from the session.
65
let mut inputs = vec![];
66
for input in &session.inputs {
67
let shape = Shape::from_onnx_input(input)?;
68
inputs.push(TensorSlot {
69
shape,
70
tensor: None,
71
});
72
}
73
// We need to keep track of the output shapes since they are used for
74
// creating the output tensor.
75
let mut outputs = vec![];
76
for output in &session.outputs {
77
let shape = Shape::from_onnx_output(output)?;
78
outputs.push(TensorSlot {
79
shape,
80
tensor: None,
81
});
82
}
83
let box_: Box<dyn BackendExecutionContext> = Box::new(OnnxExecutionContext {
84
session: self.0.clone(),
85
inputs,
86
outputs,
87
});
88
Ok(box_.into())
89
}
90
}
91
92
/// Per-inference state: a handle to the shared session plus slots that stage
/// input tensors (filled before `compute`) and output tensors (filled by it).
struct OnnxExecutionContext {
    session: Arc<Mutex<Session>>,
    // One slot per model input; `set_input`/`compute` stage data here.
    inputs: Vec<TensorSlot>,
    // One slot per model output; populated by `compute`, read by `get_output`.
    outputs: Vec<TensorSlot>,
}

// SAFETY: NOTE(review): soundness rests on the `ort` `Session` being safe to
// share across threads (it is only reached through `Arc<Mutex<..>>` here) --
// confirm against `ort`'s thread-safety documentation.
unsafe impl Send for OnnxExecutionContext {}
unsafe impl Sync for OnnxExecutionContext {}
100
101
impl OnnxExecutionContext {
102
/// Helper function for finding the internal index of a tensor by [`Id`].
103
fn find(&self, id: Id, list: &[TensorSlot]) -> Result<usize, BackendError> {
104
let index = match id {
105
Id::Index(i) => {
106
let i = i as usize;
107
if i < list.len() {
108
i
109
} else {
110
return Err(BackendError::BackendAccess(anyhow::anyhow!(
111
"incorrect tensor index: {i} >= {}",
112
list.len()
113
)));
114
}
115
}
116
Id::Name(n) => list.iter().position(|s| s.shape.name == n).ok_or_else(|| {
117
BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {n}"))
118
})?,
119
};
120
Ok(index)
121
}
122
}
123
124
impl BackendExecutionContext for OnnxExecutionContext {
125
fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
126
let index = self.find(id, &self.inputs)?;
127
let input = &mut self.inputs[index];
128
if let Err(e) = input.shape.matches(tensor) {
129
return Err(e.into());
130
}
131
// Hold the tensor data on the context until `compute` is called.
132
input.tensor.replace(tensor.clone());
133
Ok(())
134
}
135
136
fn compute(
137
&mut self,
138
inputs: Option<Vec<NamedTensor>>,
139
) -> Result<Option<Vec<NamedTensor>>, BackendError> {
140
match inputs {
141
// WIT
142
Some(inputs) => {
143
for slot in &mut self.inputs {
144
slot.tensor = None;
145
}
146
for input in &inputs {
147
let index = self
148
.inputs
149
.iter()
150
.position(|slot| slot.shape.name == input.name);
151
let index = match index {
152
Some(idx) => idx,
153
None => {
154
// Try to convert name to index
155
if let Ok(idx) = input.name.parse::<usize>() {
156
if idx < self.inputs.len() {
157
idx
158
} else {
159
return Err(BackendError::BackendAccess(anyhow::anyhow!(
160
"Input index out of range: {}",
161
idx
162
)));
163
}
164
} else {
165
return Err(BackendError::BackendAccess(anyhow::anyhow!(
166
"Unknown input tensor name: {}",
167
input.name
168
)));
169
}
170
}
171
};
172
173
let input_slot = &mut self.inputs[index];
174
if let Err(e) = input_slot.shape.matches(&input.tensor) {
175
return Err(e.into());
176
}
177
input_slot.tensor.replace(input.tensor.clone());
178
}
179
180
let mut session_inputs: Vec<ort::SessionInputValue<'_>> = vec![];
181
for i in &self.inputs {
182
session_inputs.extend(to_input_value(i)?);
183
}
184
let session = self.session.lock().unwrap();
185
let session_outputs = session.run(session_inputs.as_slice())?;
186
187
let mut output_tensors = Vec::new();
188
for i in 0..self.outputs.len() {
189
// TODO: fix preexisting gap--this only handles f32 tensors.
190
let raw: (Vec<i64>, &[f32]) = session_outputs[i].try_extract_raw_tensor()?;
191
let f32s = raw.1.to_vec();
192
let output = &mut self.outputs[i];
193
let tensor = Tensor {
194
dimensions: output.shape.dimensions_as_u32()?,
195
ty: output.shape.ty,
196
data: f32_vec_to_bytes(f32s),
197
};
198
output.tensor.replace(tensor.clone());
199
output_tensors.push(NamedTensor {
200
name: output.shape.name.clone(),
201
tensor,
202
});
203
}
204
Ok(Some(output_tensors))
205
}
206
207
// WITX
208
None => {
209
let mut session_inputs: Vec<ort::SessionInputValue<'_>> = vec![];
210
for i in &self.inputs {
211
session_inputs.extend(to_input_value(i)?);
212
}
213
let session = self.session.lock().unwrap();
214
let session_outputs = session.run(session_inputs.as_slice())?;
215
for i in 0..self.outputs.len() {
216
// TODO: fix preexisting gap--this only handles f32 tensors.
217
let raw: (Vec<i64>, &[f32]) = session_outputs[i].try_extract_raw_tensor()?;
218
let f32s = raw.1.to_vec();
219
let output = &mut self.outputs[i];
220
output.tensor.replace(Tensor {
221
dimensions: output.shape.dimensions_as_u32()?,
222
ty: output.shape.ty,
223
data: f32_vec_to_bytes(f32s),
224
});
225
}
226
Ok(None)
227
}
228
}
229
}
230
231
fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
232
let index = self.find(id, &self.outputs)?;
233
let output = &self.outputs[index];
234
if let Some(tensor) = &output.tensor {
235
Ok(tensor.clone())
236
} else {
237
Err(BackendError::BackendAccess(anyhow::anyhow!(
238
"missing output tensor: {}; has `compute` been called?",
239
output.shape.name
240
)))
241
}
242
}
243
}
244
245
/// Surface `ort` errors as generic backend-access failures, enabling `?` on
/// `ort` calls throughout this module.
impl From<ort::Error> for BackendError {
    fn from(e: ort::Error) -> Self {
        BackendError::BackendAccess(e.into())
    }
}
250
251
/// Holds a slot for ONNX session inputs and outputs.
///
/// TODO: it seems unfortunate that we have to "hold" some extra data per
/// session but in the input case, this is necessary for name-based indexing.
struct TensorSlot {
    // The name, dimensions, and element type this slot expects.
    shape: Shape,
    // The staged (input) or computed (output) tensor, if any.
    tensor: Option<Tensor>,
}
259
260
/// Describes a tensor in ONNX terms.
struct Shape {
    // The tensor's name as declared by the model.
    name: String,
    // Dimensions as reported by `ort`; `-1` marks a dynamic dimension (see
    // `is_dynamic_dimension`).
    dimensions: Vec<i64>,
    ty: TensorType,
}
266
267
impl Shape {
268
fn from_onnx_input(input: &ort::Input) -> Result<Self, BackendError> {
269
let name = input.name.clone();
270
let (dimensions, ty) = convert_value_type(&input.input_type)?;
271
Ok(Self {
272
name,
273
dimensions,
274
ty,
275
})
276
}
277
278
fn from_onnx_output(output: &ort::Output) -> Result<Self, BackendError> {
279
let name = output.name.clone();
280
let (dimensions, ty) = convert_value_type(&output.output_type)?;
281
Ok(Self {
282
name,
283
dimensions,
284
ty,
285
})
286
}
287
288
fn dimensions_as_u32(&self) -> Result<Vec<u32>, BackendError> {
289
self.dimensions
290
.iter()
291
.map(|d| if *d == -1 { Ok(1) } else { convert_i64(d) })
292
.collect()
293
}
294
295
fn matches(&self, tensor: &Tensor) -> anyhow::Result<()> {
296
if self.dimensions.len() != tensor.dimensions.len() {
297
return Err(anyhow::anyhow!(
298
"input tensor cardinality does not match model: {:?} != {:?}",
299
self.dimensions,
300
tensor.dimensions
301
));
302
} else {
303
for (&shape_dim, &tensor_dim) in self.dimensions.iter().zip(tensor.dimensions.iter()) {
304
let tensor_dim = tensor_dim as i64;
305
if !is_dynamic_dimension(shape_dim) && shape_dim != tensor_dim {
306
return Err(anyhow::anyhow!(
307
"input tensor dimensions do not match model: {:?} != {:?}",
308
self.dimensions,
309
tensor.dimensions
310
));
311
}
312
}
313
}
314
if self.ty != tensor.ty {
315
return Err(anyhow::anyhow!(
316
"input tensor type does not match model: {:?} != {:?}",
317
self.ty,
318
tensor.ty
319
));
320
}
321
Ok(())
322
}
323
}
324
325
fn convert_value_type(vt: &ort::ValueType) -> Result<(Vec<i64>, TensorType), BackendError> {
326
match vt {
327
ort::ValueType::Tensor { ty, dimensions } => {
328
let dims = dimensions.clone();
329
let ty = (*ty).try_into()?;
330
Ok((dims, ty))
331
}
332
_ => Err(BackendError::BackendAccess(anyhow::anyhow!(
333
"unsupported input type: {vt:?}"
334
))),
335
}
336
}
337
338
fn convert_i64(i: &i64) -> Result<u32, BackendError> {
339
u32::try_from(*i).map_err(|d| -> BackendError {
340
anyhow::anyhow!("unable to convert dimension to u32: {d}").into()
341
})
342
}
343
344
impl TryFrom<ort::TensorElementType> for TensorType {
345
type Error = BackendError;
346
fn try_from(ty: ort::TensorElementType) -> Result<Self, Self::Error> {
347
match ty {
348
ort::TensorElementType::Float32 => Ok(TensorType::Fp32),
349
ort::TensorElementType::Float64 => Ok(TensorType::Fp64),
350
ort::TensorElementType::Uint8 => Ok(TensorType::U8),
351
ort::TensorElementType::Int32 => Ok(TensorType::I32),
352
ort::TensorElementType::Int64 => Ok(TensorType::I64),
353
_ => Err(BackendError::BackendAccess(anyhow::anyhow!(
354
"unsupported tensor type: {ty:?}"
355
))),
356
}
357
}
358
}
359
360
fn to_input_value(slot: &TensorSlot) -> Result<[ort::SessionInputValue<'_>; 1], BackendError> {
361
match &slot.tensor {
362
Some(tensor) => match tensor.ty {
363
TensorType::Fp32 => {
364
let data = bytes_to_f32_vec(tensor.data.to_vec());
365
let dimensions = tensor
366
.dimensions
367
.iter()
368
.map(|d| *d as i64) // TODO: fewer conversions
369
.collect::<Vec<i64>>();
370
Ok(inputs![(dimensions, Arc::new(data.into_boxed_slice()))]
371
.context("failed to create ONNX session input")?)
372
}
373
_ => {
374
unimplemented!("{:?} not supported by ONNX", tensor.ty);
375
}
376
},
377
None => {
378
return Err(BackendError::BackendAccess(anyhow::anyhow!(
379
"missing input tensor: {}",
380
slot.shape.name
381
)));
382
}
383
}
384
}
385
386
/// Serialize `f32` values into their little-endian byte representation,
/// concatenated in order.
pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {
    data.into_iter().flat_map(f32::to_le_bytes).collect()
}
391
392
/// Reinterpret little-endian bytes as `f32` values.
///
/// # Panics
///
/// Panics if `data.len()` is not a multiple of 4 (the trailing short chunk
/// fails the `[u8; 4]` conversion).
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
    // Single lazy pass: the previous version collected an intermediate
    // `Vec<&[u8]>` and then rebuilt the result with a redundant
    // `into_iter().collect()` roundtrip.
    data.chunks(4)
        .map(|c| f32::from_le_bytes(c.try_into().expect("chunk must be 4 bytes")))
        .collect()
}
401
402
/// Returns whether the dimension is dynamic.
///
/// ONNX models use [dimensional variables] (i.e., name strings) for
/// dimensions whose size is user-defined rather than fixed by the model --
/// useful, e.g., for batching several inference requests. `ort` reports such
/// dimensions as `-1`.
///
/// [dimensional variables]:
/// https://onnx.ai/onnx/repo-docs/IR.html#static-tensor-shapes
fn is_dynamic_dimension(d: i64) -> bool {
    matches!(d, -1)
}
415
416