//! Implements a `wasi-nn` [`BackendInner`] using ONNX via the `ort` crate.

use super::{
    BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, NamedTensor,
};
use crate::backend::{Id, read};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use ort::{
    execution_providers::{CPUExecutionProvider, ExecutionProviderDispatch},
    inputs,
    session::{Input, Output},
    session::{Session, SessionInputValue, builder::GraphOptimizationLevel},
    tensor::TensorElementType,
    value::{Tensor as OrtTensor, ValueType},
};

#[cfg(feature = "onnx-cuda")]
use ort::execution_providers::CUDAExecutionProvider;

use std::path::Path;
use std::sync::{Arc, Mutex};

#[derive(Default)]
pub struct OnnxBackend();
unsafe impl Send for OnnxBackend {}
unsafe impl Sync for OnnxBackend {}

impl BackendInner for OnnxBackend {
    fn encoding(&self) -> GraphEncoding {
        GraphEncoding::Onnx
    }

    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
        if builders.len() != 1 {
            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
        }

        // Configure execution providers based on target
        let execution_providers = configure_execution_providers(target)?;

        let session = Session::builder()?
            .with_execution_providers(execution_providers)?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .commit_from_memory(builders[0])?;

        let box_: Box<dyn BackendGraph> =
            Box::new(OnnxGraph(Arc::new(Mutex::new(session)), target));
        Ok(box_.into())
    }

    fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {
        Some(self)
    }
}

/// Configure execution providers based on the target.
fn configure_execution_providers(
    target: ExecutionTarget,
) -> Result<Vec<ExecutionProviderDispatch>, BackendError> {
    match target {
        ExecutionTarget::Cpu => {
            // Use CPU execution provider with default configuration
            tracing::debug!("Using CPU execution provider");
            Ok(vec![CPUExecutionProvider::default().build()])
        }
        ExecutionTarget::Gpu => {
            #[cfg(feature = "onnx-cuda")]
            {
                // Use CUDA execution provider for GPU acceleration
                tracing::debug!("Using Nvidia GPU CUDA execution provider");
                Ok(vec![CUDAExecutionProvider::default().build()])
            }
            #[cfg(not(feature = "onnx-cuda"))]
            {
                tracing::warn!("GPU CUDA execution provider is not enabled, falling back to CPU");
                Ok(vec![CPUExecutionProvider::default().build()])
            }
        }
        ExecutionTarget::Tpu => {
            tracing::warn!(
                "TPU execution target is not supported for ONNX backend yet, falling back to CPU"
            );
            Ok(vec![CPUExecutionProvider::default().build()])
        }
    }
}
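
// NOTE: an illustrative sketch, not part of the upstream file. Every target
// above currently resolves to exactly one provider (CUDA for `Gpu` when the
// `onnx-cuda` feature is enabled, CPU otherwise), so the selection logic can
// be spot-checked without any particular hardware.
#[cfg(test)]
mod provider_selection_tests {
    use super::*;

    #[test]
    fn every_target_yields_exactly_one_provider() {
        for target in [ExecutionTarget::Cpu, ExecutionTarget::Gpu, ExecutionTarget::Tpu] {
            let providers = match configure_execution_providers(target) {
                Ok(p) => p,
                Err(_) => panic!("provider selection should not fail"),
            };
            assert_eq!(providers.len(), 1);
        }
    }
}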

impl BackendFromDir for OnnxBackend {
    fn load_from_dir(
        &mut self,
        path: &Path,
        target: ExecutionTarget,
    ) -> Result<Graph, BackendError> {
        let model = read(&path.join("model.onnx"))?;
        self.load(&[&model], target)
    }
}

struct OnnxGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget);
unsafe impl Send for OnnxGraph {}
unsafe impl Sync for OnnxGraph {}

impl BackendGraph for OnnxGraph {
    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
        let session = self.0.lock().unwrap();
        // We need to hold on to the names of the inputs in order for
        // `set_input` to work with both indexes and names. Having the
        // dimensions and type around is useful for validation but could be
        // retrieved from the session.
        let mut inputs = vec![];
        for input in &session.inputs {
            let shape = Shape::from_onnx_input(input)?;
            inputs.push(TensorSlot {
                shape,
                tensor: None,
            });
        }
        // We need to keep track of the output shapes since they are used for
        // creating the output tensor.
        let mut outputs = vec![];
        for output in &session.outputs {
            let shape = Shape::from_onnx_output(output)?;
            outputs.push(TensorSlot {
                shape,
                tensor: None,
            });
        }
        let box_: Box<dyn BackendExecutionContext> = Box::new(OnnxExecutionContext {
            session: self.0.clone(),
            inputs,
            outputs,
        });
        Ok(box_.into())
    }
}

struct OnnxExecutionContext {
    session: Arc<Mutex<Session>>,
    inputs: Vec<TensorSlot>,
    outputs: Vec<TensorSlot>,
}

unsafe impl Send for OnnxExecutionContext {}
unsafe impl Sync for OnnxExecutionContext {}

impl OnnxExecutionContext {
    /// Helper function for finding the internal index of a tensor by [`Id`].
    fn find(&self, id: Id, list: &[TensorSlot]) -> Result<usize, BackendError> {
        let index = match id {
            Id::Index(i) => {
                let i = i as usize;
                if i < list.len() {
                    i
                } else {
                    return Err(BackendError::BackendAccess(wasmtime::format_err!(
                        "incorrect tensor index: {i} >= {}",
                        list.len()
                    )));
                }
            }
            Id::Name(n) => list.iter().position(|s| s.shape.name == n).ok_or_else(|| {
                BackendError::BackendAccess(wasmtime::format_err!("unknown tensor name: {n}"))
            })?,
        };
        Ok(index)
    }
}

impl BackendExecutionContext for OnnxExecutionContext {
    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
        let index = self.find(id, &self.inputs)?;
        let input = &mut self.inputs[index];
        if let Err(e) = input.shape.matches(tensor) {
            return Err(e.into());
        }
        // Hold the tensor data on the context until `compute` is called.
        input.tensor.replace(tensor.clone());
        Ok(())
    }

    fn compute(
        &mut self,
        inputs: Option<Vec<NamedTensor>>,
    ) -> Result<Option<Vec<NamedTensor>>, BackendError> {
        fn dimensions_as_u32(shape: &ort::tensor::Shape) -> Result<Vec<u32>, BackendError> {
            shape
                .iter()
                // Dynamic dimensions are reported as `-1`; default them to 1.
                .map(|d| if is_dynamic_dimension(*d) { Ok(1) } else { convert_i64(d) })
                .collect()
        }

        match inputs {
            // WIT path: tensors are passed directly to `compute`, identified
            // by name (or by a numeric index encoded as a name).
            Some(inputs) => {
                for slot in &mut self.inputs {
                    slot.tensor = None;
                }
                for input in &inputs {
                    let index = self
                        .inputs
                        .iter()
                        .position(|slot| slot.shape.name == input.name);
                    let index = match index {
                        Some(idx) => idx,
                        None => {
                            // Try to interpret the name as a numeric index.
                            if let Ok(idx) = input.name.parse::<usize>() {
                                if idx < self.inputs.len() {
                                    idx
                                } else {
                                    return Err(BackendError::BackendAccess(
                                        wasmtime::format_err!("Input index out of range: {idx}"),
                                    ));
                                }
                            } else {
                                return Err(BackendError::BackendAccess(wasmtime::format_err!(
                                    "Unknown input tensor name: {}",
                                    input.name
                                )));
                            }
                        }
                    };

                    let input_slot = &mut self.inputs[index];
                    if let Err(e) = input_slot.shape.matches(&input.tensor) {
                        return Err(e.into());
                    }
                    input_slot.tensor.replace(input.tensor.clone());
                }

                let mut session_inputs: Vec<SessionInputValue<'_>> = vec![];
                for i in &self.inputs {
                    session_inputs.extend(to_input_value(i)?);
                }
                let mut session = self.session.lock().unwrap();
                let session_outputs = session.run(session_inputs.as_slice())?;

                let mut output_tensors = Vec::new();
                for i in 0..self.outputs.len() {
                    // TODO: fix preexisting gap--this only handles f32 tensors.
                    let (shape, data): (&ort::tensor::Shape, &[f32]) =
                        session_outputs[i].try_extract_tensor()?;
                    let f32s = data.to_vec();
                    let output = &mut self.outputs[i];
                    let dimensions: Vec<u32> = dimensions_as_u32(shape)?;
                    let tensor = Tensor {
                        dimensions,
                        ty: output.shape.ty,
                        data: f32_vec_to_bytes(f32s),
                    };
                    output.tensor.replace(tensor.clone());
                    output_tensors.push(NamedTensor {
                        name: output.shape.name.clone(),
                        tensor,
                    });
                }
                Ok(Some(output_tensors))
            }

            // WITX path: inputs were staged beforehand via `set_input`, so run
            // the session with whatever is already held on the context.
            None => {
                let mut session_inputs: Vec<SessionInputValue<'_>> = vec![];
                for i in &self.inputs {
                    session_inputs.extend(to_input_value(i)?);
                }
                let mut session = self.session.lock().unwrap();
                let session_outputs = session.run(session_inputs.as_slice())?;
                for i in 0..self.outputs.len() {
                    // TODO: fix preexisting gap--this only handles f32 tensors.
                    let (shape, data): (&ort::tensor::Shape, &[f32]) =
                        session_outputs[i].try_extract_tensor()?;
                    let f32s = data.to_vec();
                    let output = &mut self.outputs[i];
                    let dimensions: Vec<u32> = dimensions_as_u32(shape)?;
                    output.tensor.replace(Tensor {
                        dimensions,
                        ty: output.shape.ty,
                        data: f32_vec_to_bytes(f32s),
                    });
                }
                Ok(None)
            }
        }
    }

    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
        let index = self.find(id, &self.outputs)?;
        let output = &self.outputs[index];
        if let Some(tensor) = &output.tensor {
            Ok(tensor.clone())
        } else {
            Err(BackendError::BackendAccess(wasmtime::format_err!(
                "missing output tensor: {}; has `compute` been called?",
                output.shape.name
            )))
        }
    }
}

impl From<ort::Error> for BackendError {
    fn from(e: ort::Error) -> Self {
        BackendError::BackendAccess(wasmtime::format_err!("{e}"))
    }
}

/// Holds a slot for ONNX session inputs and outputs.
///
/// TODO: it seems unfortunate that we have to "hold" some extra data per
/// session but in the input case, this is necessary for name-based indexing.
struct TensorSlot {
    shape: Shape,
    tensor: Option<Tensor>,
}

/// Describes a tensor in ONNX terms.
struct Shape {
    name: String,
    dimensions: Vec<i64>,
    ty: TensorType,
}

impl Shape {
    fn from_onnx_input(input: &Input) -> Result<Self, BackendError> {
        let name = input.name.clone();
        let (dimensions, ty) = convert_value_type(&input.input_type)?;
        Ok(Self {
            name,
            dimensions,
            ty,
        })
    }

    fn from_onnx_output(output: &Output) -> Result<Self, BackendError> {
        let name = output.name.clone();
        let (dimensions, ty) = convert_value_type(&output.output_type)?;
        Ok(Self {
            name,
            dimensions,
            ty,
        })
    }

    fn matches(&self, tensor: &Tensor) -> wasmtime::Result<()> {
        if self.dimensions.len() != tensor.dimensions.len() {
            return Err(wasmtime::format_err!(
                "input tensor cardinality does not match model: {:?} != {:?}",
                self.dimensions,
                tensor.dimensions
            ));
        }
        for (&shape_dim, &tensor_dim) in self.dimensions.iter().zip(tensor.dimensions.iter()) {
            let tensor_dim = tensor_dim as i64;
            if !is_dynamic_dimension(shape_dim) && shape_dim != tensor_dim {
                return Err(wasmtime::format_err!(
                    "input tensor dimensions do not match model: {:?} != {:?}",
                    self.dimensions,
                    tensor.dimensions
                ));
            }
        }
        if self.ty != tensor.ty {
            return Err(wasmtime::format_err!(
                "input tensor type does not match model: {:?} != {:?}",
                self.ty,
                tensor.ty
            ));
        }
        Ok(())
    }
}
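
// NOTE: an illustrative sketch, not part of the upstream file. `Tensor` is
// constructed here with the same three fields `compute` uses above, so this
// only exercises the `matches` validation logic.
#[cfg(test)]
mod shape_tests {
    use super::*;

    #[test]
    fn dynamic_dimensions_match_any_size() {
        let shape = Shape {
            name: "input".to_string(),
            dimensions: vec![-1, 3, 224, 224],
            ty: TensorType::Fp32,
        };
        let batch_of_eight = Tensor {
            dimensions: vec![8, 3, 224, 224],
            ty: TensorType::Fp32,
            data: vec![],
        };
        assert!(shape.matches(&batch_of_eight).is_ok());

        // A concrete (non-dynamic) dimension must match exactly.
        let wrong_channels = Tensor {
            dimensions: vec![8, 1, 224, 224],
            ty: TensorType::Fp32,
            data: vec![],
        };
        assert!(shape.matches(&wrong_channels).is_err());
    }
}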

fn convert_value_type(vt: &ValueType) -> Result<(Vec<i64>, TensorType), BackendError> {
    match vt {
        ValueType::Tensor { ty, shape, .. } => {
            let dimensions = shape.to_vec();
            let ty = (*ty).try_into()?;
            Ok((dimensions, ty))
        }
        _ => Err(BackendError::BackendAccess(wasmtime::format_err!(
            "unsupported input type: {vt:?}"
        ))),
    }
}

fn convert_i64(i: &i64) -> Result<u32, BackendError> {
    u32::try_from(*i).map_err(|_| -> BackendError {
        wasmtime::format_err!("unable to convert dimension to u32: {i}").into()
    })
}

impl TryFrom<TensorElementType> for TensorType {
    type Error = BackendError;
    fn try_from(ty: TensorElementType) -> Result<Self, Self::Error> {
        match ty {
            TensorElementType::Float32 => Ok(TensorType::Fp32),
            TensorElementType::Float64 => Ok(TensorType::Fp64),
            TensorElementType::Uint8 => Ok(TensorType::U8),
            TensorElementType::Int32 => Ok(TensorType::I32),
            TensorElementType::Int64 => Ok(TensorType::I64),
            _ => Err(BackendError::BackendAccess(wasmtime::format_err!(
                "unsupported tensor type: {ty:?}"
            ))),
        }
    }
}
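
// NOTE: an illustrative sketch, not part of the upstream file; it spot-checks
// the element-type mapping above. `TensorElementType::Bool` is assumed to be
// one of the `ort` variants that the mapping intentionally rejects.
#[cfg(test)]
mod element_type_tests {
    use super::*;

    #[test]
    fn maps_supported_element_types() {
        assert!(matches!(
            TensorType::try_from(TensorElementType::Float32),
            Ok(TensorType::Fp32)
        ));
        assert!(matches!(
            TensorType::try_from(TensorElementType::Int64),
            Ok(TensorType::I64)
        ));
        // Unsupported element types are reported as errors rather than mapped.
        assert!(TensorType::try_from(TensorElementType::Bool).is_err());
    }
}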

fn to_input_value(slot: &TensorSlot) -> Result<[SessionInputValue<'_>; 1], BackendError> {
    match &slot.tensor {
        Some(tensor) => match tensor.ty {
            TensorType::Fp32 => {
                let data = bytes_to_f32_vec(tensor.data.to_vec());
                let dimensions: Vec<i64> = tensor
                    .dimensions
                    .iter()
                    .map(|d| *d as i64) // TODO: fewer conversions
                    .collect();
                let ort_tensor = OrtTensor::<f32>::from_array((dimensions, data)).map_err(|e| {
                    BackendError::BackendAccess(wasmtime::format_err!(
                        "failed to create ONNX session input: {e}"
                    ))
                })?;
                Ok(inputs![ort_tensor])
            }
            _ => {
                unimplemented!("{:?} not supported by ONNX", tensor.ty);
            }
        },
        None => Err(BackendError::BackendAccess(wasmtime::format_err!(
            "missing input tensor: {}",
            slot.shape.name
        ))),
    }
}

/// Converts a vector of `f32`s into its little-endian byte representation.
pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {
    data.into_iter().flat_map(|f| f.to_le_bytes()).collect()
}

/// Reassembles `f32`s from little-endian bytes; the input length must be a
/// multiple of four.
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
    data.chunks(4)
        .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
        .collect()
}
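
// NOTE: an illustrative sketch, not part of the upstream file, showing the
// round-trip property the two helpers above are expected to satisfy.
#[cfg(test)]
mod byte_conversion_tests {
    use super::{bytes_to_f32_vec, f32_vec_to_bytes};

    #[test]
    fn f32_bytes_round_trip() {
        let original = vec![0.0f32, 1.5, -2.25, f32::MAX];
        let bytes = f32_vec_to_bytes(original.clone());
        // Each f32 occupies exactly four little-endian bytes.
        assert_eq!(bytes.len(), original.len() * 4);
        assert_eq!(bytes_to_f32_vec(bytes), original);
    }
}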

/// Returns whether the dimension is dynamic.
///
/// ONNX uses [dimensional variables] (i.e., name strings) to indicate that the
/// value of a tensor dimension is user-defined rather than fixed by the model,
/// which is useful, e.g., for batching up several inference requests. When
/// `ort` returns a dimension of this kind, though, it uses `-1` to indicate
/// that the dimension is dynamic.
///
/// [dimensional variables]:
/// https://onnx.ai/onnx/repo-docs/IR.html#static-tensor-shapes
fn is_dynamic_dimension(d: i64) -> bool {
    d == -1
}
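
// NOTE: an illustrative sketch, not part of the upstream file.
#[cfg(test)]
mod dimension_tests {
    use super::is_dynamic_dimension;

    #[test]
    fn detects_dynamic_dimensions() {
        assert!(is_dynamic_dimension(-1));
        assert!(!is_dynamic_dimension(0));
        assert!(!is_dynamic_dimension(224));
    }
}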