GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/src/backend/winml.rs
//! Implements a `wasi-nn` [`BackendInner`] using WinML.
//!
//! Note that the [docs.rs] documentation for the `windows` crate does not have
//! the right features turned on to read about the functions used; see
//! Microsoft's own rendering of the documentation instead:
//! [microsoft.github.io/windows-docs-rs].
//!
//! [docs.rs]: https://docs.rs/windows
//! [microsoft.github.io/windows-docs-rs]: https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning

use crate::backend::{
    BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
    NamedTensor,
};
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use std::{fs::File, io::Read, mem::size_of, path::Path};
use windows::AI::MachineLearning::{
    ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice,
    LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession,
    TensorFeatureDescriptor, TensorFloat, TensorFloat16Bit, TensorInt64Bit, TensorKind,
};
use windows::Foundation::Collections::IVectorView;
use windows::Storage::Streams::{
    DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
};
use windows::core::{ComInterface, Error, HSTRING, IInspectable};
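
// A rough sketch of how this backend is typically driven through the traits
// below; `model_bytes` and `input_tensor` are illustrative placeholders, and
// the real wiring goes through this crate's backend registry rather than
// direct calls:
//
//     let mut backend = WinMLBackend::default();
//     let graph = backend.load(&[&model_bytes], ExecutionTarget::Cpu)?;
//     let mut ctx = graph.init_execution_context()?;
//     ctx.set_input(Id::Index(0), &input_tensor)?;
//     ctx.compute(None)?;
//     let output = ctx.get_output(Id::Index(0))?;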

#[derive(Default)]
pub struct WinMLBackend();

impl BackendInner for WinMLBackend {
    fn encoding(&self) -> GraphEncoding {
        GraphEncoding::Onnx
    }

    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
        if builders.len() != 1 {
            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
        }

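        // WinML loads models from a WinRT stream, so the in-memory ONNX bytes
        // are first written into an `InMemoryRandomAccessStream`.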
        let model_stream = InMemoryRandomAccessStream::new()?;
        let model_writer = DataWriter::CreateDataWriter(&model_stream)?;
        model_writer.WriteBytes(&builders[0])?;
        model_writer.StoreAsync()?;
        model_writer.FlushAsync()?;
        let model = LearningModel::LoadFromStream(&RandomAccessStreamReference::CreateFromStream(
            &model_stream,
        )?)?;
        let device_kind = match target {
            ExecutionTarget::Cpu => LearningModelDeviceKind::Cpu,
            ExecutionTarget::Gpu => LearningModelDeviceKind::DirectX,
            ExecutionTarget::Tpu => unimplemented!(),
        };
        let graph = WinMLGraph { model, device_kind };

        let box_: Box<dyn BackendGraph> = Box::new(graph);
        Ok(box_.into())
    }

    fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
        Some(self)
    }
}

impl BackendFromDir for WinMLBackend {
    fn load_from_dir(
        &mut self,
        path: &Path,
        target: ExecutionTarget,
    ) -> Result<Graph, BackendError> {
        let model = read(&path.join("model.onnx"))?;
        self.load(&[&model], target)
    }
}

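/// A WinML model loaded into memory, plus the device kind it should be
/// evaluated on; sessions are created from this in
/// [`BackendGraph::init_execution_context`].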
struct WinMLGraph {
    model: LearningModel,
    device_kind: LearningModelDeviceKind,
}

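// Safety: these impls assert that the wrapped WinRT objects may be sent and
// shared across threads, under the assumption that WinML's runtime classes
// are agile COM objects.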
unsafe impl Send for WinMLGraph {}
unsafe impl Sync for WinMLGraph {}

impl BackendGraph for WinMLGraph {
    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
        let device = LearningModelDevice::Create(self.device_kind)?;
        let session = LearningModelSession::CreateFromModelOnDevice(&self.model, &device)?;
        let box_: Box<dyn BackendExecutionContext> = Box::new(WinMLExecutionContext::new(session));
        Ok(box_.into())
    }
}

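/// Per-session state used to bind input tensors, evaluate the model, and read
/// back the results.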
struct WinMLExecutionContext {
    session: LearningModelSession,
    binding: LearningModelBinding,
    result: Option<LearningModelEvaluationResult>,
}

impl WinMLExecutionContext {
    fn new(session: LearningModelSession) -> Self {
        Self {
            binding: LearningModelBinding::CreateFromSession(&session).unwrap(),
            session,
            result: None,
        }
    }
}

impl WinMLExecutionContext {
    /// Helper function for finding the internal index of a tensor by [`Id`].
    fn find(
        &self,
        id: Id,
        list: &IVectorView<ILearningModelFeatureDescriptor>,
    ) -> Result<u32, BackendError> {
        let index = match id {
            Id::Index(i) => {
                if i < list.Size()? {
                    i
                } else {
                    return Err(BackendError::BackendAccess(wasmtime::format_err!(
                        "incorrect tensor index: {i} >= {}",
                        list.Size()?
                    )));
                }
            }
            Id::Name(name) => list
                .into_iter()
                .position(|d| d.Name().unwrap() == name)
                .ok_or_else(|| {
                    BackendError::BackendAccess(wasmtime::format_err!(
                        "unknown tensor name: {name}"
                    ))
                })? as u32,
        };
        Ok(index)
    }
}

impl BackendExecutionContext for WinMLExecutionContext {
    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
        // TODO: Clear previous bindings when needed.

        let input_features = self.session.Model()?.InputFeatures()?;
        let index = self.find(id, &input_features)?;
        let input = input_features.GetAt(index)?;

        let inspectable = to_inspectable(tensor)?;
        self.binding.Bind(&input.Name()?, &inspectable)?;

        Ok(())
    }

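    /// Evaluate the model. With `Some(inputs)`, previous bindings are
    /// replaced, the named tensors are bound, and all outputs are returned;
    /// with `None`, the tensors previously bound via `set_input` are used and
    /// results are later read with `get_output`.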
    fn compute(
        &mut self,
        inputs: Option<Vec<NamedTensor>>,
    ) -> Result<Option<Vec<NamedTensor>>, BackendError> {
        match inputs {
            Some(inputs) => {
                // Clear previous bindings.
                self.binding = LearningModelBinding::CreateFromSession(&self.session)?;

                let input_features = self.session.Model()?.InputFeatures()?;
                for input in &inputs {
                    let index = input_features
                        .clone()
                        .into_iter()
                        .position(|d| d.Name().unwrap() == input.name)
                        .ok_or_else(|| {
                            BackendError::BackendAccess(wasmtime::format_err!(
                                "unknown input tensor name: {}",
                                input.name
                            ))
                        })? as u32;

                    let input_feature = input_features.GetAt(index)?;
                    let inspectable = to_inspectable(&input.tensor)?;
                    self.binding.Bind(&input_feature.Name()?, &inspectable)?;
                }

                self.result = Some(self.session.Evaluate(&self.binding, &HSTRING::new())?);

                let output_features = self.session.Model()?.OutputFeatures()?;
                let mut output_tensors = Vec::new();
                for i in 0..output_features.Size()? {
                    let output_feature = output_features.GetAt(i)?;
                    let tensor_kind = match output_feature.Kind()? {
                        windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => {
                            output_feature
                                .cast::<TensorFeatureDescriptor>()?
                                .TensorKind()?
                        }
                        _ => unimplemented!(
                            "the WinML backend only supports tensors, found: {:?}",
                            output_feature.Kind()
                        ),
                    };
                    let tensor = to_tensor(
                        self.result
                            .as_ref()
                            .unwrap()
                            .Outputs()?
                            .Lookup(&output_feature.Name()?)?,
                        tensor_kind,
                    )?;
                    output_tensors.push(NamedTensor {
                        name: output_feature.Name()?.to_string(),
                        tensor,
                    });
                }
                Ok(Some(output_tensors))
            }
            None => {
                self.result = Some(self.session.Evaluate(&self.binding, &HSTRING::new())?);
                Ok(None)
            }
        }
    }

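    /// Look up one output of the most recent evaluation by index or name.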
    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
        if let Some(result) = &self.result {
            let output_features = self.session.Model()?.OutputFeatures()?;
            let index = self.find(id, &output_features)?;
            let output_feature = output_features.GetAt(index)?;
            let tensor_kind = match output_feature.Kind()? {
                windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => output_feature
                    .cast::<TensorFeatureDescriptor>()?
                    .TensorKind()?,
                _ => unimplemented!(
                    "the WinML backend only supports tensors, found: {:?}",
                    output_feature.Kind()
                ),
            };
            to_tensor(
                result.Outputs()?.Lookup(&output_feature.Name()?)?,
                tensor_kind,
            )
        } else {
            Err(BackendError::BackendAccess(wasmtime::Error::msg(
                "output is not ready; call `compute` first",
            )))
        }
    }
}

/// Read a file into a byte vector.
fn read(path: &Path) -> wasmtime::Result<Vec<u8>> {
    let mut file = File::open(path)?;
    let mut buffer = vec![];
    file.read_to_end(&mut buffer)?;
    Ok(buffer)
}

impl From<windows::core::Error> for BackendError {
    fn from(e: windows::core::Error) -> Self {
        BackendError::BackendAccess(wasmtime::Error::new(e))
    }
}

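/// Convert a WinML shape's `i64` entries to `u32` dimensions, mapping a free
/// dimension (`-1`) to `1`.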
fn dimensions_as_u32(dimensions: &IVectorView<i64>) -> Result<Vec<u32>, BackendError> {
    dimensions
        .into_iter()
        .map(|d| if d == -1 { Ok(1) } else { convert_i64(d) })
        .collect()
}

fn convert_i64(i: i64) -> Result<u32, BackendError> {
    u32::try_from(i).map_err(|d| -> BackendError {
        wasmtime::format_err!("unable to convert dimension to u32: {d}").into()
    })
}

// Convert a wasi-nn tensor to a WinML tensor.
fn to_inspectable(tensor: &Tensor) -> Result<IInspectable, Error> {
    let shape = IVectorView::<i64>::try_from(
        tensor
            .dimensions
            .iter()
            .map(|&x| x as i64)
            .collect::<Vec<i64>>(),
    )?;
    match tensor.ty {
        // f16 is not officially supported by stable Rust
        // (https://github.com/rust-lang/rust/issues/116909), so the tensor
        // data is interpreted as f32 values and a TensorFloat16Bit is created
        // from an f32 array:
        // https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning/struct.TensorFloat16Bit.html#method.CreateFromArray
        TensorType::Fp16 => unsafe {
            check_alignment::<f32>(&tensor.data);
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<f32>(),
                tensor.data.len() / size_of::<f32>(),
            );
            TensorFloat16Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        TensorType::Fp32 => unsafe {
            check_alignment::<f32>(&tensor.data);
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<f32>(),
                tensor.data.len() / size_of::<f32>(),
            );
            TensorFloat::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        TensorType::I64 => unsafe {
            check_alignment::<i64>(&tensor.data);
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<i64>(),
                tensor.data.len() / size_of::<i64>(),
            );
            TensorInt64Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        _ => unimplemented!(),
    }
}

// Convert a WinML tensor to a wasi-nn tensor.
fn to_tensor(inspectable: IInspectable, tensor_kind: TensorKind) -> Result<Tensor, BackendError> {
    let tensor = match tensor_kind {
        TensorKind::Float16 => {
            let output_tensor = inspectable.cast::<TensorFloat16Bit>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            // TODO: Move to f16 when it's available in stable Rust.
            let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
            Tensor {
                ty: TensorType::Fp16,
                dimensions,
                data,
            }
        }
        TensorKind::Float => {
            let output_tensor = inspectable.cast::<TensorFloat>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
            Tensor {
                ty: TensorType::Fp32,
                dimensions,
                data,
            }
        }
        TensorKind::Int64 => {
            let output_tensor = inspectable.cast::<TensorInt64Bit>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            let data = view.into_iter().flat_map(i64::to_le_bytes).collect();
            Tensor {
                ty: TensorType::I64,
                dimensions,
                data,
            }
        }
        _ => unimplemented!(),
    };
    Ok(tensor)
}

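/// Assert that a byte buffer is sufficiently aligned to be reinterpreted as a
/// slice of `T`; `to_inspectable` would otherwise construct a misaligned
/// slice, which is undefined behavior.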
fn check_alignment<T>(data: &[u8]) {
    assert!(
        data.as_ptr() as usize % std::mem::align_of::<T>() == 0,
        "data is not aligned to {}'s alignment",
        std::any::type_name::<T>()
    );
}

#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip unit tests for each supported data type: convert from a
    // wasi-nn tensor to a WinML tensor and back.
    #[test]
    fn fp16() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let buffer = data
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect::<Vec<u8>>();
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::Fp16,
            dimensions: vec![2, 3],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable
            .as_ref()
            .unwrap()
            .cast::<TensorFloat16Bit>()
            .unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<f32>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Float16);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }

    #[test]
    fn fp32() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut buffer = Vec::with_capacity(data.len() * size_of::<f32>());
        for f in &data {
            buffer.extend(f.to_le_bytes());
        }
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::Fp32,
            dimensions: vec![2, 3],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable.as_ref().unwrap().cast::<TensorFloat>().unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<f32>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Float);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }

    #[test]
    fn i64() {
        let data = vec![6i64, 5, 4, 3, 2, 1];
        let mut buffer = Vec::with_capacity(data.len() * size_of::<i64>());
        for i in &data {
            buffer.extend(i.to_le_bytes());
        }
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::I64,
            dimensions: vec![1, 6],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable
            .as_ref()
            .unwrap()
            .cast::<TensorInt64Bit>()
            .unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<i64>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Int64);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }
}