Path: blob/main/crates/wasi-nn/src/backend/winml.rs

//! Implements a `wasi-nn` [`BackendInner`] using WinML.
//!
//! Note that the [docs.rs] documentation for the `windows` crate does not have
//! the right features turned on to read about the functions used; see
//! Microsoft's documentation instead: [microsoft.github.io/windows-docs-rs].
//!
//! [docs.rs]: https://docs.rs/windows
//! [microsoft.github.io/windows-docs-rs]: https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning

use crate::backend::{
    BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
    NamedTensor,
};
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use std::{fs::File, io::Read, mem::size_of, path::Path};
use windows::AI::MachineLearning::{
    ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice,
    LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession,
    TensorFeatureDescriptor, TensorFloat, TensorFloat16Bit, TensorInt64Bit, TensorKind,
};
use windows::Foundation::Collections::IVectorView;
use windows::Storage::Streams::{
    DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
};
use windows::core::{ComInterface, Error, HSTRING, IInspectable};

#[derive(Default)]
pub struct WinMLBackend();

impl BackendInner for WinMLBackend {
    fn encoding(&self) -> GraphEncoding {
        GraphEncoding::Onnx
    }

    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
        if builders.len() != 1 {
            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
        }

        let model_stream = InMemoryRandomAccessStream::new()?;
        let model_writer = DataWriter::CreateDataWriter(&model_stream)?;
        model_writer.WriteBytes(&builders[0])?;
        model_writer.StoreAsync()?;
        model_writer.FlushAsync()?;
        let model = LearningModel::LoadFromStream(&RandomAccessStreamReference::CreateFromStream(
            &model_stream,
        )?)?;
        let device_kind = match target {
            ExecutionTarget::Cpu => LearningModelDeviceKind::Cpu,
            ExecutionTarget::Gpu => LearningModelDeviceKind::DirectX,
            ExecutionTarget::Tpu => unimplemented!(),
        };
        let graph = WinMLGraph { model, device_kind };

        let box_: Box<dyn BackendGraph> = Box::new(graph);
        Ok(box_.into())
    }

    fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
        Some(self)
    }
}

impl BackendFromDir for WinMLBackend {
    fn load_from_dir(
        &mut self,
        path: &Path,
        target: ExecutionTarget,
    ) -> Result<Graph, BackendError> {
        let model = read(&path.join("model.onnx"))?;
        self.load(&[&model], target)
    }
}

struct WinMLGraph {
    model: LearningModel,
    device_kind: LearningModelDeviceKind,
}

unsafe impl Send for WinMLGraph {}
unsafe impl Sync for WinMLGraph {}

impl BackendGraph for WinMLGraph {
    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
        let device = LearningModelDevice::Create(self.device_kind)?;
        let session = LearningModelSession::CreateFromModelOnDevice(&self.model, &device)?;
        let box_: Box<dyn BackendExecutionContext> = Box::new(WinMLExecutionContext::new(session));
        Ok(box_.into())
    }
}

struct WinMLExecutionContext {
    session: LearningModelSession,
    binding: LearningModelBinding,
    result: Option<LearningModelEvaluationResult>,
}

impl WinMLExecutionContext {
    fn new(session: LearningModelSession) -> Self {
        Self {
            binding: LearningModelBinding::CreateFromSession(&session).unwrap(),
            session,
            result: None,
        }
    }
}
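
// A rough sketch of how the pieces in this file fit together; this is not a
// verbatim caller (the surrounding wasi-nn crate drives these traits through
// its own wrapper types), and `onnx_bytes`/`input_tensor` are placeholders:
//
//     let mut backend = WinMLBackend::default();
//     let graph = backend.load(&[&onnx_bytes], ExecutionTarget::Cpu)?;
//     let mut context = graph.init_execution_context()?;
//     context.set_input(Id::Index(0), &input_tensor)?;
//     context.compute(None)?;
//     let output = context.get_output(Id::Index(0))?;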

impl WinMLExecutionContext {
    /// Helper function for finding the internal index of a tensor by [`Id`].
    fn find(
        &self,
        id: Id,
        list: &IVectorView<ILearningModelFeatureDescriptor>,
    ) -> Result<u32, BackendError> {
        let index = match id {
            Id::Index(i) => {
                if i < list.Size()? {
                    i
                } else {
                    return Err(BackendError::BackendAccess(anyhow::anyhow!(
                        "incorrect tensor index: {i} >= {}",
                        list.Size()?
                    )));
                }
            }
            Id::Name(name) => list
                .into_iter()
                .position(|d| d.Name().unwrap() == name)
                .ok_or_else(|| {
                    BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {name}"))
                })? as u32,
        };
        Ok(index)
    }
}

impl BackendExecutionContext for WinMLExecutionContext {
    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
        // TODO: Clear previous bindings when needed.

        let input_features = self.session.Model()?.InputFeatures()?;
        let index = self.find(id, &input_features)?;
        let input = input_features.GetAt(index)?;

        let inspectable = to_inspectable(tensor)?;
        self.binding.Bind(&input.Name()?, &inspectable)?;

        Ok(())
    }

    fn compute(
        &mut self,
        inputs: Option<Vec<NamedTensor>>,
    ) -> Result<Option<Vec<NamedTensor>>, BackendError> {
        match inputs {
            Some(inputs) => {
                // Clear previous bindings
                self.binding = LearningModelBinding::CreateFromSession(&self.session)?;

                let input_features = self.session.Model()?.InputFeatures()?;
                for input in &inputs {
                    let index = input_features
                        .clone()
                        .into_iter()
                        .position(|d| d.Name().unwrap() == input.name)
                        .ok_or_else(|| {
                            BackendError::BackendAccess(anyhow::anyhow!(
                                "Unknown input tensor name: {}",
                                input.name
                            ))
                        })? as u32;

                    let input_feature = input_features.GetAt(index)?;
                    let inspectable = to_inspectable(&input.tensor)?;
                    self.binding.Bind(&input_feature.Name()?, &inspectable)?;
                }

                self.result = Some(self.session.Evaluate(&self.binding, &HSTRING::new())?);

                let output_features = self.session.Model()?.OutputFeatures()?;
                let mut output_tensors = Vec::new();
                for i in 0..output_features.Size()? {
                    let output_feature = output_features.GetAt(i)?;
                    let tensor_kind = match output_feature.Kind()? {
                        windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => {
                            output_feature
                                .cast::<TensorFeatureDescriptor>()?
                                .TensorKind()?
                        }
                        _ => unimplemented!(
                            "the WinML backend only supports tensors, found: {:?}",
                            output_feature.Kind()
                        ),
                    };
                    let tensor = to_tensor(
                        self.result
                            .as_ref()
                            .unwrap()
                            .Outputs()?
                            .Lookup(&output_feature.Name()?)?,
                        tensor_kind,
                    )?;
                    output_tensors.push(NamedTensor {
                        name: output_feature.Name()?.to_string(),
                        tensor,
                    });
                }
                Ok(Some(output_tensors))
            }
            None => {
                self.result = Some(self.session.Evaluate(&self.binding, &HSTRING::new())?);
                Ok(None)
            }
        }
    }

    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
        if let Some(result) = &self.result {
            let output_features = self.session.Model()?.OutputFeatures()?;
            let index = self.find(id, &output_features)?;
            let output_feature = output_features.GetAt(index)?;
            let tensor_kind = match output_feature.Kind()? {
                windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => output_feature
                    .cast::<TensorFeatureDescriptor>()?
                    .TensorKind()?,
                _ => unimplemented!(
                    "the WinML backend only supports tensors, found: {:?}",
                    output_feature.Kind()
                ),
            };
            let tensor = to_tensor(
                result.Outputs()?.Lookup(&output_feature.Name()?)?,
                tensor_kind,
            );
            tensor
        } else {
            return Err(BackendError::BackendAccess(anyhow::Error::msg(
                "Output is not ready.",
            )));
        }
    }
}
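
// For reference, `compute` above supports two calling conventions:
//
// - `compute(Some(inputs))` rebuilds the binding from the given named tensors,
//   evaluates the model, and returns every model output as a `NamedTensor`.
// - `compute(None)` evaluates whatever was previously bound via `set_input`;
//   individual outputs are then fetched with `get_output`.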

/// Read a file into a byte vector.
fn read(path: &Path) -> anyhow::Result<Vec<u8>> {
    let mut file = File::open(path)?;
    let mut buffer = vec![];
    file.read_to_end(&mut buffer)?;
    Ok(buffer)
}

impl From<windows::core::Error> for BackendError {
    fn from(e: windows::core::Error) -> Self {
        BackendError::BackendAccess(anyhow::Error::new(e))
    }
}

fn dimensions_as_u32(dimensions: &IVectorView<i64>) -> Result<Vec<u32>, BackendError> {
    dimensions
        .into_iter()
        .map(|d| if d == -1 { Ok(1) } else { convert_i64(d) })
        .collect()
}

fn convert_i64(i: i64) -> Result<u32, BackendError> {
    u32::try_from(i).map_err(|d| -> BackendError {
        anyhow::anyhow!("unable to convert dimension to u32: {d}").into()
    })
}

// Convert from wasi-nn tensor to WinML tensor.
fn to_inspectable(tensor: &Tensor) -> Result<IInspectable, Error> {
    let shape = IVectorView::<i64>::try_from(
        tensor
            .dimensions
            .iter()
            .map(|&x| x as i64)
            .collect::<Vec<i64>>(),
    )?;
    match tensor.ty {
        // `f16` is not officially supported by the stable version of Rust
        // (https://github.com/rust-lang/rust/issues/116909), so we create a
        // `TensorFloat16Bit` from an `f32` array instead:
        // https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning/struct.TensorFloat16Bit.html#method.CreateFromArray
        TensorType::Fp16 => unsafe {
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<f32>(),
                tensor.data.len() / size_of::<f32>(),
            );
            check_alignment::<f32>(data);
            TensorFloat16Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        TensorType::Fp32 => unsafe {
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<f32>(),
                tensor.data.len() / size_of::<f32>(),
            );
            check_alignment::<f32>(data);
            TensorFloat::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        TensorType::I64 => unsafe {
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<i64>(),
                tensor.data.len() / size_of::<i64>(),
            );
            check_alignment::<i64>(data);
            TensorInt64Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        _ => unimplemented!(),
    }
}
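
// Note on the `Fp16` layout assumed by `to_inspectable` above and `to_tensor`
// below: since stable Rust has no `f16` type, the `data` buffer of an `Fp16`
// wasi-nn tensor is treated as little-endian `f32` values in both directions.
// As an illustrative example, a `[2, 3]` `Fp16` tensor is expected to carry
// 6 * size_of::<f32>() = 24 bytes rather than 12 bytes of packed `f16` values.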

// Convert from WinML tensor to wasi-nn tensor.
fn to_tensor(inspectable: IInspectable, tensor_kind: TensorKind) -> Result<Tensor, BackendError> {
    let tensor = match tensor_kind {
        TensorKind::Float16 => {
            let output_tensor = inspectable.cast::<TensorFloat16Bit>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            // TODO: Move to f16 when it's available in stable.
            let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
            Tensor {
                ty: TensorType::Fp16,
                dimensions,
                data,
            }
        }
        TensorKind::Float => {
            let output_tensor = inspectable.cast::<TensorFloat>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
            Tensor {
                ty: TensorType::Fp32,
                dimensions,
                data,
            }
        }
        TensorKind::Int64 => {
            let output_tensor = inspectable.cast::<TensorInt64Bit>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            let data = view.into_iter().flat_map(i64::to_le_bytes).collect();
            Tensor {
                ty: TensorType::I64,
                dimensions,
                data,
            }
        }
        _ => unimplemented!(),
    };
    Ok(tensor)
}

fn check_alignment<T>(data: &[T]) {
    let (prefix, _slice, suffix) = unsafe { data.align_to::<T>() };
    assert!(
        prefix.is_empty() && suffix.is_empty(),
        "Data is not aligned to {:?}'s alignment",
        std::any::type_name::<T>()
    );
}

#[cfg(test)]
mod tests {
    use super::*;

    // Unit tests for different data types. Convert from wasi-nn tensor to WinML tensor and back.
    #[test]
    fn fp16() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let buffer = data
            .iter()
            .flat_map(|f| f.to_ne_bytes())
            .collect::<Vec<u8>>();
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::Fp16,
            dimensions: vec![2, 3],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable
            .as_ref()
            .unwrap()
            .cast::<TensorFloat16Bit>()
            .unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<f32>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Float16);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }

    #[test]
    fn fp32() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut buffer = Vec::with_capacity(data.len() * size_of::<f32>());
        for f in &data {
            buffer.extend(f.to_ne_bytes());
        }
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::Fp32,
            dimensions: vec![2, 3],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable.as_ref().unwrap().cast::<TensorFloat>().unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<f32>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Float);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }

    #[test]
    fn i64() {
        let data = vec![6i64, 5, 4, 3, 2, 1];
        let mut buffer = Vec::with_capacity(data.len() * size_of::<i64>());
        for f in &data {
            buffer.extend(f.to_ne_bytes());
        }
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::I64,
            dimensions: vec![1, 6],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable
            .as_ref()
            .unwrap()
            .cast::<TensorInt64Bit>()
            .unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<i64>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Int64);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }
}