Path: blob/main/crates/wasi-nn/src/backend/winml.rs
//! Implements a `wasi-nn` [`BackendInner`] using WinML.
//!
//! Note that the [docs.rs] documentation for the `windows` crate does not have
//! the right features turned on to read about the functions used; see
//! Microsoft's private documentation instead: [microsoft.github.io/windows-docs-rs].
//!
//! [docs.rs]: https://docs.rs/windows
//! [microsoft.github.io/windows-docs-rs]: https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning
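//!
//! A rough usage sketch (illustrative only; the `Graph` and `ExecutionContext`
//! wrappers come from the surrounding crate, so the exact calling convention
//! may differ from what is shown here):
//!
//! ```ignore
//! let mut backend = WinMLBackend::default();
//! let onnx_bytes = std::fs::read("model.onnx")?;
//! let graph = backend.load(&[&onnx_bytes], ExecutionTarget::Cpu)?;
//! let mut context = graph.init_execution_context()?;
//! ```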

use crate::backend::{
    BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
    NamedTensor,
};
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use std::{fs::File, io::Read, mem::size_of, path::Path};
use windows::AI::MachineLearning::{
    ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice,
    LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession,
    TensorFeatureDescriptor, TensorFloat, TensorFloat16Bit, TensorInt64Bit, TensorKind,
};
use windows::Foundation::Collections::IVectorView;
use windows::Storage::Streams::{
    DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
};
use windows::core::{ComInterface, Error, HSTRING, IInspectable};

#[derive(Default)]
pub struct WinMLBackend();

impl BackendInner for WinMLBackend {
    fn encoding(&self) -> GraphEncoding {
        GraphEncoding::Onnx
    }

    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
        if builders.len() != 1 {
            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()));
        }

        let model_stream = InMemoryRandomAccessStream::new()?;
        let model_writer = DataWriter::CreateDataWriter(&model_stream)?;
        model_writer.WriteBytes(&builders[0])?;
        model_writer.StoreAsync()?;
        model_writer.FlushAsync()?;
        let model = LearningModel::LoadFromStream(&RandomAccessStreamReference::CreateFromStream(
            &model_stream,
        )?)?;
        let device_kind = match target {
            ExecutionTarget::Cpu => LearningModelDeviceKind::Cpu,
            ExecutionTarget::Gpu => LearningModelDeviceKind::DirectX,
            ExecutionTarget::Tpu => unimplemented!(),
        };
        let graph = WinMLGraph { model, device_kind };

        let box_: Box<dyn BackendGraph> = Box::new(graph);
        Ok(box_.into())
    }

    fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
        Some(self)
    }
}

impl BackendFromDir for WinMLBackend {
    fn load_from_dir(
        &mut self,
        path: &Path,
        target: ExecutionTarget,
    ) -> Result<Graph, BackendError> {
        let model = read(&path.join("model.onnx"))?;
        self.load(&[&model], target)
    }
}

struct WinMLGraph {
    model: LearningModel,
    device_kind: LearningModelDeviceKind,
}

unsafe impl Send for WinMLGraph {}
unsafe impl Sync for WinMLGraph {}

impl BackendGraph for WinMLGraph {
    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
        let device = LearningModelDevice::Create(self.device_kind)?;
        let session = LearningModelSession::CreateFromModelOnDevice(&self.model, &device)?;
        let box_: Box<dyn BackendExecutionContext> = Box::new(WinMLExecutionContext::new(session));
        Ok(box_.into())
    }
}

struct WinMLExecutionContext {
    session: LearningModelSession,
    binding: LearningModelBinding,
    result: Option<LearningModelEvaluationResult>,
}

impl WinMLExecutionContext {
    fn new(session: LearningModelSession) -> Self {
        Self {
            binding: LearningModelBinding::CreateFromSession(&session).unwrap(),
            session,
            result: None,
        }
    }
}

impl WinMLExecutionContext {
    /// Helper function for finding the internal index of a tensor by [`Id`].
    fn find(
        &self,
        id: Id,
        list: &IVectorView<ILearningModelFeatureDescriptor>,
    ) -> Result<u32, BackendError> {
        let index = match id {
            Id::Index(i) => {
                if i < list.Size()? {
                    i
                } else {
                    return Err(BackendError::BackendAccess(wasmtime::format_err!(
                        "incorrect tensor index: {i} >= {}",
                        list.Size()?
                    )));
                }
            }
            Id::Name(name) => list
                .into_iter()
                .position(|d| d.Name().unwrap() == name)
                .ok_or_else(|| {
                    BackendError::BackendAccess(wasmtime::format_err!(
                        "unknown tensor name: {name}"
                    ))
                })? as u32,
        };
        Ok(index)
    }
}

impl BackendExecutionContext for WinMLExecutionContext {
    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
        // TODO: Clear previous bindings when needed.

        let input_features = self.session.Model()?.InputFeatures()?;
        let index = self.find(id, &input_features)?;
        let input = input_features.GetAt(index)?;

        let inspectable = to_inspectable(tensor)?;
        self.binding.Bind(&input.Name()?, &inspectable)?;

        Ok(())
    }

    fn compute(
        &mut self,
        inputs: Option<Vec<NamedTensor>>,
    ) -> Result<Option<Vec<NamedTensor>>, BackendError> {
        match inputs {
            Some(inputs) => {
                // Clear previous bindings.
                self.binding = LearningModelBinding::CreateFromSession(&self.session)?;

                let input_features = self.session.Model()?.InputFeatures()?;
                for input in &inputs {
                    let index = input_features
                        .clone()
                        .into_iter()
                        .position(|d| d.Name().unwrap() == input.name)
                        .ok_or_else(|| {
                            BackendError::BackendAccess(wasmtime::format_err!(
                                "Unknown input tensor name: {}",
                                input.name
                            ))
                        })? as u32;

                    let input_feature = input_features.GetAt(index)?;
                    let inspectable = to_inspectable(&input.tensor)?;
                    self.binding.Bind(&input_feature.Name()?, &inspectable)?;
                }

                self.result = Some(self.session.Evaluate(&self.binding, &HSTRING::new())?);

                let output_features = self.session.Model()?.OutputFeatures()?;
                let mut output_tensors = Vec::new();
                for i in 0..output_features.Size()? {
                    let output_feature = output_features.GetAt(i)?;
                    let tensor_kind = match output_feature.Kind()? {
                        windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => {
                            output_feature
                                .cast::<TensorFeatureDescriptor>()?
                                .TensorKind()?
                        }
                        _ => unimplemented!(
                            "the WinML backend only supports tensors, found: {:?}",
                            output_feature.Kind()
                        ),
                    };
                    let tensor = to_tensor(
                        self.result
                            .as_ref()
                            .unwrap()
                            .Outputs()?
                            .Lookup(&output_feature.Name()?)?,
                        tensor_kind,
                    )?;
                    output_tensors.push(NamedTensor {
                        name: output_feature.Name()?.to_string(),
                        tensor,
                    });
                }
                Ok(Some(output_tensors))
            }
            None => {
                self.result = Some(self.session.Evaluate(&self.binding, &HSTRING::new())?);
                Ok(None)
            }
        }
    }

    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
        if let Some(result) = &self.result {
            let output_features = self.session.Model()?.OutputFeatures()?;
            let index = self.find(id, &output_features)?;
            let output_feature = output_features.GetAt(index)?;
            let tensor_kind = match output_feature.Kind()? {
                windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => output_feature
                    .cast::<TensorFeatureDescriptor>()?
                    .TensorKind()?,
                _ => unimplemented!(
                    "the WinML backend only supports tensors, found: {:?}",
                    output_feature.Kind()
                ),
            };
            let tensor = to_tensor(
                result.Outputs()?.Lookup(&output_feature.Name()?)?,
                tensor_kind,
            );
            tensor
        } else {
            return Err(BackendError::BackendAccess(wasmtime::Error::msg(
                "Output is not ready.",
            )));
        }
    }
}

/// Read a file into a byte vector.
fn read(path: &Path) -> wasmtime::Result<Vec<u8>> {
    let mut file = File::open(path)?;
    let mut buffer = vec![];
    file.read_to_end(&mut buffer)?;
    Ok(buffer)
}

impl From<windows::core::Error> for BackendError {
    fn from(e: windows::core::Error) -> Self {
        BackendError::BackendAccess(wasmtime::Error::new(e))
    }
}

fn dimensions_as_u32(dimensions: &IVectorView<i64>) -> Result<Vec<u32>, BackendError> {
    dimensions
        .into_iter()
        .map(|d| if d == -1 { Ok(1) } else { convert_i64(d) })
        .collect()
}

fn convert_i64(i: i64) -> Result<u32, BackendError> {
    u32::try_from(i).map_err(|d| -> BackendError {
        wasmtime::format_err!("unable to convert dimension to u32: {d}").into()
    })
}

// Convert from a wasi-nn tensor to a WinML tensor.
fn to_inspectable(tensor: &Tensor) -> Result<IInspectable, Error> {
    let shape = IVectorView::<i64>::try_from(
        tensor
            .dimensions
            .iter()
            .map(|&x| x as i64)
            .collect::<Vec<i64>>(),
    )?;
    match tensor.ty {
        // f16 is not officially supported by stable Rust (https://github.com/rust-lang/rust/issues/116909),
        // so we create a `TensorFloat16Bit` from an f32 array instead:
        // https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning/struct.TensorFloat16Bit.html#method.CreateFromArray
        TensorType::Fp16 => unsafe {
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<f32>(),
                tensor.data.len() / size_of::<f32>(),
            );
            check_alignment::<f32>(data);
            TensorFloat16Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        TensorType::Fp32 => unsafe {
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<f32>(),
                tensor.data.len() / size_of::<f32>(),
            );
            check_alignment::<f32>(data);
            TensorFloat::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        TensorType::I64 => unsafe {
            let data = std::slice::from_raw_parts(
                tensor.data.as_ptr().cast::<i64>(),
                tensor.data.len() / size_of::<i64>(),
            );
            check_alignment::<i64>(data);
            TensorInt64Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
        },
        _ => unimplemented!(),
    }
}

// Convert from a WinML tensor to a wasi-nn tensor.
fn to_tensor(inspectable: IInspectable, tensor_kind: TensorKind) -> Result<Tensor, BackendError> {
    let tensor = match tensor_kind {
        TensorKind::Float16 => {
            let output_tensor = inspectable.cast::<TensorFloat16Bit>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            // TODO: Move to f16 when it's available in stable Rust.
            let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
            Tensor {
                ty: TensorType::Fp16,
                dimensions,
                data,
            }
        }
        TensorKind::Float => {
            let output_tensor = inspectable.cast::<TensorFloat>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
            Tensor {
                ty: TensorType::Fp32,
                dimensions,
                data,
            }
        }
        TensorKind::Int64 => {
            let output_tensor = inspectable.cast::<TensorInt64Bit>()?;
            let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
            let view = output_tensor.GetAsVectorView()?;
            let data = view.into_iter().flat_map(i64::to_le_bytes).collect();
            Tensor {
                ty: TensorType::I64,
                dimensions,
                data,
            }
        }
        _ => unimplemented!(),
    };
    Ok(tensor)
}

fn check_alignment<T>(data: &[T]) {
    let (prefix, _slice, suffix) = unsafe { data.align_to::<T>() };
    assert!(
        prefix.is_empty() && suffix.is_empty(),
        "Data is not aligned to {:?}'s alignment",
        std::any::type_name::<T>()
    );
}

#[cfg(test)]
mod tests {
    use super::*;

    // Unit tests for different data types. Convert from wasi-nn tensor to WinML tensor and back.
    #[test]
    fn fp16() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let buffer = data
            .iter()
            .map(|f| f.to_ne_bytes())
            .flatten()
            .collect::<Vec<u8>>();
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::Fp16,
            dimensions: vec![2, 3],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable
            .as_ref()
            .unwrap()
            .cast::<TensorFloat16Bit>()
            .unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<f32>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Float16);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }

    #[test]
    fn fp32() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut buffer = Vec::with_capacity(data.len() * size_of::<f32>());
        for f in &data {
            buffer.extend(f.to_ne_bytes());
        }
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::Fp32,
            dimensions: vec![2, 3],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable.as_ref().unwrap().cast::<TensorFloat>().unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<f32>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Float);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }

    #[test]
    fn i64() {
        let data = vec![6i64, 5, 4, 3, 2, 1];
        let mut buffer = Vec::with_capacity(data.len() * size_of::<i64>());
        for f in &data {
            buffer.extend(f.to_ne_bytes());
        }
        let buffer_copy = buffer.clone();
        let tensor = Tensor {
            ty: TensorType::I64,
            dimensions: vec![1, 6],
            data: buffer_copy,
        };
        let inspectable = to_inspectable(&tensor);
        assert!(inspectable.is_ok());
        let winml_tensor = inspectable
            .as_ref()
            .unwrap()
            .cast::<TensorInt64Bit>()
            .unwrap();
        let view = winml_tensor.GetAsVectorView().unwrap();
        assert_eq!(view.into_iter().collect::<Vec<i64>>(), data);
        // Convert back.
        let t = to_tensor(inspectable.unwrap(), TensorKind::Int64);
        assert!(t.as_ref().is_ok());
        assert_eq!(t.unwrap(), tensor);
    }
}
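// A small additional sketch (not part of the original tests): exercises the
// dimension helpers above, which map WinML's free dimension (`-1`) to `1` and
// reject other negative sizes. Like the rest of this file, it only runs on
// Windows with the WinRT runtime available.
#[cfg(test)]
mod dimension_tests {
    use super::*;

    #[test]
    fn free_dimension_maps_to_one() {
        // `-1` marks a free (unknown) dimension in a WinML shape.
        let shape = IVectorView::<i64>::try_from(vec![-1i64, 3, 224, 224]).unwrap();
        assert_eq!(dimensions_as_u32(&shape).unwrap(), vec![1u32, 3, 224, 224]);
    }

    #[test]
    fn other_negative_dimensions_are_rejected() {
        assert!(convert_i64(-2).is_err());
    }
}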