GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/src/backend/winml.rs
//! Implements a `wasi-nn` [`BackendInner`] using WinML.
//!
//! Note that the [docs.rs] documentation for the `windows` crate does not have
//! the right features turned on to read about the functions used; see
//! Microsoft's own documentation instead: [microsoft.github.io/windows-docs-rs].
//!
//! [docs.rs]: https://docs.rs/windows
//! [microsoft.github.io/windows-docs-rs]: https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning
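
// At a glance, the flow implemented below:
//
//   ONNX bytes  --LearningModel::LoadFromStream-->  WinMLGraph { LearningModel }
//   WinMLGraph  --init_execution_context-->  WinMLExecutionContext (session + binding)
//   wasi-nn Tensor  --to_inspectable/Bind-->  Evaluate  --to_tensor-->  wasi-nn Tensor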

use crate::backend::{
    BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
    NamedTensor,
};
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use std::{fs::File, io::Read, mem::size_of, path::Path};
use windows::AI::MachineLearning::{
    ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice,
    LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession,
    TensorFeatureDescriptor, TensorFloat, TensorFloat16Bit, TensorInt64Bit, TensorKind,
};
use windows::Foundation::Collections::IVectorView;
use windows::Storage::Streams::{
    DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
};
use windows::core::{ComInterface, Error, HSTRING, IInspectable};

#[derive(Default)]
pub struct WinMLBackend();

impl BackendInner for WinMLBackend {
    fn encoding(&self) -> GraphEncoding {
        GraphEncoding::Onnx
    }

    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
        if builders.len() != 1 {
            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
        }

        let model_stream = InMemoryRandomAccessStream::new()?;
        let model_writer = DataWriter::CreateDataWriter(&model_stream)?;
        model_writer.WriteBytes(&builders[0])?;
        model_writer.StoreAsync()?;
        model_writer.FlushAsync()?;
        let model = LearningModel::LoadFromStream(&RandomAccessStreamReference::CreateFromStream(
            &model_stream,
        )?)?;
        let device_kind = match target {
            ExecutionTarget::Cpu => LearningModelDeviceKind::Cpu,
            ExecutionTarget::Gpu => LearningModelDeviceKind::DirectX,
            ExecutionTarget::Tpu => unimplemented!(),
        };
        let graph = WinMLGraph { model, device_kind };

        let box_: Box<dyn BackendGraph> = Box::new(graph);
        Ok(box_.into())
    }

    fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
        Some(self)
    }
}
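
// A minimal usage sketch (illustrative only; assumes a Windows host with WinML
// available, and a hypothetical `model.onnx` on disk):
//
//     let mut backend = WinMLBackend::default();
//     let bytes = std::fs::read("model.onnx")?;
//     let graph = backend.load(&[bytes.as_slice()], ExecutionTarget::Cpu)?;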

impl BackendFromDir for WinMLBackend {
    fn load_from_dir(
        &mut self,
        path: &Path,
        target: ExecutionTarget,
    ) -> Result<Graph, BackendError> {
        let model = read(&path.join("model.onnx"))?;
        self.load(&[&model], target)
    }
}

struct WinMLGraph {
    model: LearningModel,
    device_kind: LearningModelDeviceKind,
}

unsafe impl Send for WinMLGraph {}
unsafe impl Sync for WinMLGraph {}

impl BackendGraph for WinMLGraph {
    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
        let device = LearningModelDevice::Create(self.device_kind)?;
        let session = LearningModelSession::CreateFromModelOnDevice(&self.model, &device)?;
        let box_: Box<dyn BackendExecutionContext> = Box::new(WinMLExecutionContext::new(session));
        Ok(box_.into())
    }
}

struct WinMLExecutionContext {
    session: LearningModelSession,
    binding: LearningModelBinding,
    result: Option<LearningModelEvaluationResult>,
}

impl WinMLExecutionContext {
    fn new(session: LearningModelSession) -> Self {
        Self {
            binding: LearningModelBinding::CreateFromSession(&session).unwrap(),
            session,
            result: None,
        }
    }
}

impl WinMLExecutionContext {
    /// Helper function for finding the internal index of a tensor by [`Id`].
    fn find(
        &self,
        id: Id,
        list: &IVectorView<ILearningModelFeatureDescriptor>,
    ) -> Result<u32, BackendError> {
        let index = match id {
            Id::Index(i) => {
                if i < list.Size()? {
                    i
                } else {
                    return Err(BackendError::BackendAccess(anyhow::anyhow!(
                        "incorrect tensor index: {i} >= {}",
                        list.Size()?
                    )));
                }
            }
            Id::Name(name) => list
                .into_iter()
                .position(|d| d.Name().unwrap() == name)
                .ok_or_else(|| {
                    BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {name}"))
                })? as u32,
        };
        Ok(index)
    }
}

impl BackendExecutionContext for WinMLExecutionContext {
    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
        // TODO: Clear previous bindings when needed.

        let input_features = self.session.Model()?.InputFeatures()?;
        let index = self.find(id, &input_features)?;
        let input = input_features.GetAt(index)?;

        let inspectable = to_inspectable(tensor)?;
        self.binding.Bind(&input.Name()?, &inspectable)?;

        Ok(())
    }

    fn compute(
        &mut self,
        inputs: Option<Vec<NamedTensor>>,
    ) -> Result<Option<Vec<NamedTensor>>, BackendError> {
        // When `inputs` is `Some`, bind each tensor by name, evaluate, and
        // return all named outputs; when `None`, evaluate with the bindings
        // accumulated via `set_input` and let the caller fetch results with
        // `get_output`.
        match inputs {
            Some(inputs) => {
                // Clear previous bindings.
                self.binding = LearningModelBinding::CreateFromSession(&self.session)?;

                let input_features = self.session.Model()?.InputFeatures()?;
                for input in &inputs {
                    let index = input_features
                        .clone()
                        .into_iter()
                        .position(|d| d.Name().unwrap() == input.name)
                        .ok_or_else(|| {
                            BackendError::BackendAccess(anyhow::anyhow!(
                                "unknown input tensor name: {}",
                                input.name
                            ))
                        })? as u32;

                    let input_feature = input_features.GetAt(index)?;
                    let inspectable = to_inspectable(&input.tensor)?;
                    self.binding.Bind(&input_feature.Name()?, &inspectable)?;
                }

                self.result = Some(self.session.Evaluate(&self.binding, &HSTRING::new())?);

                let output_features = self.session.Model()?.OutputFeatures()?;
                let mut output_tensors = Vec::new();
                for i in 0..output_features.Size()? {
                    let output_feature = output_features.GetAt(i)?;
                    let tensor_kind = match output_feature.Kind()? {
                        windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => {
                            output_feature
                                .cast::<TensorFeatureDescriptor>()?
                                .TensorKind()?
                        }
                        _ => unimplemented!(
                            "the WinML backend only supports tensors, found: {:?}",
                            output_feature.Kind()
                        ),
                    };
                    let tensor = to_tensor(
                        self.result
                            .as_ref()
                            .unwrap()
                            .Outputs()?
                            .Lookup(&output_feature.Name()?)?,
                        tensor_kind,
                    )?;
                    output_tensors.push(NamedTensor {
                        name: output_feature.Name()?.to_string(),
                        tensor,
                    });
                }
                Ok(Some(output_tensors))
            }
            None => {
                self.result = Some(self.session.Evaluate(&self.binding, &HSTRING::new())?);
                Ok(None)
            }
        }
    }

    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
        if let Some(result) = &self.result {
            let output_features = self.session.Model()?.OutputFeatures()?;
            let index = self.find(id, &output_features)?;
            let output_feature = output_features.GetAt(index)?;
            let tensor_kind = match output_feature.Kind()? {
                windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => output_feature
                    .cast::<TensorFeatureDescriptor>()?
                    .TensorKind()?,
                _ => unimplemented!(
                    "the WinML backend only supports tensors, found: {:?}",
                    output_feature.Kind()
                ),
            };
            to_tensor(
                result.Outputs()?.Lookup(&output_feature.Name()?)?,
                tensor_kind,
            )
        } else {
            Err(BackendError::BackendAccess(anyhow::Error::msg(
                "Output is not ready.",
            )))
        }
    }
}

/// Read a file into a byte vector.
fn read(path: &Path) -> anyhow::Result<Vec<u8>> {
    let mut file = File::open(path)?;
    let mut buffer = vec![];
    file.read_to_end(&mut buffer)?;
    Ok(buffer)
}

impl From<windows::core::Error> for BackendError {
    fn from(e: windows::core::Error) -> Self {
        BackendError::BackendAccess(anyhow::Error::new(e))
    }
}

fn dimensions_as_u32(dimensions: &IVectorView<i64>) -> Result<Vec<u32>, BackendError> {
    dimensions
        .into_iter()
        // A dimension of -1 marks a dynamic (free) dimension; treat it as 1.
        .map(|d| if d == -1 { Ok(1) } else { convert_i64(d) })
        .collect()
}

fn convert_i64(i: i64) -> Result<u32, BackendError> {
    u32::try_from(i).map_err(|d| -> BackendError {
        anyhow::anyhow!("unable to convert dimension to u32: {d}").into()
    })
}

// Convert from a wasi-nn tensor to a WinML tensor.
fn to_inspectable(tensor: &Tensor) -> Result<IInspectable, Error> {
    let shape = IVectorView::<i64>::try_from(
        tensor
            .dimensions
            .iter()
            .map(|&x| x as i64)
            .collect::<Vec<i64>>(),
    )?;
    match tensor.ty {
        // f16 is not officially supported by the stable version of Rust
        // (https://github.com/rust-lang/rust/issues/116909), so we create a
        // `TensorFloat16Bit` from an f32 array instead:
        // https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning/struct.TensorFloat16Bit.html#method.CreateFromArray
        TensorType::Fp16 => unsafe {
            // Check the raw bytes *before* reinterpreting them as `&[f32]`.
            check_alignment::<f32>(&tensor.data);
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<f32>(),
                tensor.data.len() / size_of::<f32>(),
            );
            TensorFloat16Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        TensorType::Fp32 => unsafe {
            check_alignment::<f32>(&tensor.data);
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<f32>(),
                tensor.data.len() / size_of::<f32>(),
            );
            TensorFloat::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        TensorType::I64 => unsafe {
            check_alignment::<i64>(&tensor.data);
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<i64>(),
                tensor.data.len() / size_of::<i64>(),
            );
            TensorInt64Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        _ => unimplemented!(),
    }
}

// Convert from a WinML tensor to a wasi-nn tensor.
fn to_tensor(inspectable: IInspectable, tensor_kind: TensorKind) -> Result<Tensor, BackendError> {
    let tensor = match tensor_kind {
        TensorKind::Float16 => {
            let output_tensor = inspectable.cast::<TensorFloat16Bit>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            // TODO: Move to f16 when it's available in stable.
            let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
            Tensor {
                ty: TensorType::Fp16,
                dimensions,
                data,
            }
        }
        TensorKind::Float => {
            let output_tensor = inspectable.cast::<TensorFloat>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
            Tensor {
                ty: TensorType::Fp32,
                dimensions,
                data,
            }
        }
        TensorKind::Int64 => {
            let output_tensor = inspectable.cast::<TensorInt64Bit>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            let data = view.into_iter().flat_map(i64::to_le_bytes).collect();
            Tensor {
                ty: TensorType::I64,
                dimensions,
                data,
            }
        }
        _ => unimplemented!(),
    };
    Ok(tensor)
}

/// Assert that `bytes` is correctly aligned for reads of type `T`; this must
/// be checked on the raw bytes *before* they are reinterpreted as a `&[T]`
/// (checking an already-constructed `&[T]` is vacuous).
fn check_alignment<T>(bytes: &[u8]) {
    let (prefix, _slice, suffix) = unsafe { bytes.align_to::<T>() };
    assert!(
        prefix.is_empty() && suffix.is_empty(),
        "data is not aligned to {}'s alignment",
        std::any::type_name::<T>()
    );
}

#[cfg(test)]
mod tests {
    use super::*;

    // Unit tests for different data types: convert from a wasi-nn tensor to a
    // WinML tensor and back.
    #[test]
    fn fp16() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let buffer = data
            .iter()
            .flat_map(|f| f.to_ne_bytes())
            .collect::<Vec<u8>>();
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::Fp16,
            dimensions: vec![2, 3],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable
            .as_ref()
            .unwrap()
            .cast::<TensorFloat16Bit>()
            .unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<f32>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Float16);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }

    #[test]
    fn fp32() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut buffer = Vec::with_capacity(data.len() * size_of::<f32>());
        for f in &data {
            buffer.extend(f.to_ne_bytes());
        }
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::Fp32,
            dimensions: vec![2, 3],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable.as_ref().unwrap().cast::<TensorFloat>().unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<f32>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Float);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }

    #[test]
    fn i64() {
        let data = vec![6i64, 5, 4, 3, 2, 1];
        let mut buffer = Vec::with_capacity(data.len() * size_of::<i64>());
        for f in &data {
            buffer.extend(f.to_ne_bytes());
        }
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::I64,
            dimensions: vec![1, 6],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable
            .as_ref()
            .unwrap()
            .cast::<TensorInt64Bit>()
            .unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<i64>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Int64);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }
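
    // An extra sanity check (illustrative; assumes `BackendError` implements
    // `Debug` so `unwrap` compiles): `dimensions_as_u32` maps WinML's
    // dynamic-dimension marker (-1) to 1 and converts other dimensions
    // losslessly.
    #[test]
    fn dynamic_dimensions() {
        let shape = IVectorView::<i64>::try_from(vec![-1i64, 3, 224, 224]).unwrap();
        assert_eq!(dimensions_as_u32(&shape).unwrap(), vec![1u32, 3, 224, 224]);
    }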
}