// Path: crates/wasi-nn/src/backend/pytorch.rs
//! Implements a `wasi-nn` [`BackendInner`] using PyTorch.1//!2use super::{3BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,4NamedTensor,5};6use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};7use crate::{ExecutionContext, Graph};8use std::path::Path;9use std::sync::{Arc, Mutex};10use tch::{CModule, Device, Kind, TchError, Tensor as TchTensor};1112#[derive(Default)]13pub struct PytorchBackend();14unsafe impl Send for PytorchBackend {}15unsafe impl Sync for PytorchBackend {}1617impl BackendInner for PytorchBackend {18fn encoding(&self) -> GraphEncoding {19GraphEncoding::Pytorch20}2122fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {23if builders.len() != 1 {24return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());25}26// Load the torchscript saved module.27let mut saved_module = builders[0];2829// Load the saved model on the device.30let mut compiled_module = CModule::load_data_on_device(31&mut saved_module,32map_execution_target_to_string(target),33)?;3435// Set the model to be used for inference (eval), default mode is training.36compiled_module.f_set_eval()?;3738let graph = PytorchGraph {39module: Arc::new(Mutex::new(compiled_module)),40target,41};42let box_: Box<dyn BackendGraph> = Box::new(graph);43Ok(box_.into())44}4546fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {47Some(self)48}49}5051impl BackendFromDir for PytorchBackend {52fn load_from_dir(53&mut self,54path: &Path,55target: ExecutionTarget,56) -> Result<Graph, BackendError> {57// Load the model from the file path.58let compiled_module = CModule::load_on_device(59path.join("model.pt"),60map_execution_target_to_string(target),61)?;62let graph = PytorchGraph {63module: Arc::new(Mutex::new(compiled_module)),64target,65};66let box_: Box<dyn BackendGraph> = Box::new(graph);67Ok(box_.into())68}69}7071struct PytorchGraph {72module: 
Arc<Mutex<tch::CModule>>,73target: ExecutionTarget,74}7576unsafe impl Send for PytorchGraph {}77unsafe impl Sync for PytorchGraph {}7879impl BackendGraph for PytorchGraph {80fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {81let box_: Box<dyn BackendExecutionContext> = Box::new(PytorchExecutionContext {82module: self.module.clone(),83inputs: Vec::new(),84output: TchTensor::new(),85id_type: None,86target: self.target,87});8889Ok(box_.into())90}91}9293unsafe impl Sync for PytorchExecutionContext {}94struct PytorchExecutionContext {95module: Arc<Mutex<tch::CModule>>,96inputs: Vec<Option<tch::Tensor>>,97output: tch::Tensor,98id_type: Option<Id>,99target: ExecutionTarget,100}101102/// `set_input` supports multiple positional parameters with `Id::Index`, and a single named parameter with `Id::Name`.103/// `set_input` may be removed in the future, with `compute` method taking a list of named parameters.104/// See [PR #77](https://github.com/WebAssembly/wasi-nn/pull/77), at which point multiple named parameters for `Tensor` inputs is planned to be supported in pytorch backend.105impl BackendExecutionContext for PytorchExecutionContext {106fn set_input(&mut self, id: Id, input_tensor: &Tensor) -> Result<(), BackendError> {107let kind = input_tensor.ty.try_into()?;108let dimensions = input_tensor109.dimensions110.iter()111.map(|&dim| dim as i64)112.collect::<Vec<_>>();113let tensor = TchTensor::from_data_size(&input_tensor.data, &dimensions, kind)114.to_device(map_execution_target_to_string(self.target));115match id {116Id::Index(i) => {117// Check if id_type is already set and if it matches the current id type118if let Some(Id::Name(_)) = self.id_type {119return Err(BackendError::BackendAccess(anyhow::anyhow!(120"Cannot mix u32 and str indexes"121)));122}123// Set id_type if not already set124if self.id_type.is_none() {125self.id_type = Some(Id::Index(0)); // Provide a u32 value for Index126}127let i = i as usize;128if i >= self.inputs.len() 
{129self.inputs.resize_with(i + 1, || None);130}131self.inputs[i] = Some(tensor);132Ok(())133}134Id::Name(_) => {135// Check if id_type is already set and if it matches the current id type136if let Some(Id::Index(_)) = self.id_type {137return Err(BackendError::BackendAccess(anyhow::anyhow!(138"Cannot mix u32 and str indexes"139)));140}141// Set id_type if not already set142if self.id_type.is_none() {143self.id_type = Some(Id::Name(String::new())); // Provide a str value for Name144}145if self.inputs.get(0).is_some() {146return Err(BackendError::BackendAccess(anyhow::anyhow!(147"The pytorch backend does not support multiple named inputs"148)));149} else {150self.inputs.push(Some(tensor));151}152Ok(())153}154}155}156157fn compute(158&mut self,159inputs: Option<Vec<NamedTensor>>,160) -> Result<Option<Vec<NamedTensor>>, BackendError> {161match inputs {162// WIT-style compute with named tensors163Some(inputs) => {164self.inputs.clear();165self.id_type = None;166for input in &inputs {167let kind = input.tensor.ty.try_into()?;168let dimensions = input169.tensor170.dimensions171.iter()172.map(|&dim| dim as i64)173.collect::<Vec<_>>();174175let tensor = TchTensor::from_data_size(&input.tensor.data, &dimensions, kind)176.to_device(map_execution_target_to_string(self.target));177self.inputs.push(Some(tensor));178179// Set ID type to Name since we're using named tensors180if self.id_type.is_none() {181self.id_type = Some(Id::Name(String::new()));182}183}184// Run the forward pass185let inputs: Vec<tch::Tensor> = self186.inputs187.iter()188.enumerate()189.map(|(index, opt)| {190opt.as_ref()191.expect(&format!("Input tensor at index {} not set up", index))192.shallow_clone()193})194.collect();195self.output = self.module.lock().unwrap().forward_ts(&inputs)?;196let numel = self.output.numel();197let dimensions = self.output.size();198let ty = self.output.kind().try_into()?;199let mut data = vec![0u8; kind_to_size(self.output.kind())? 
* numel];200self.output.copy_data_u8(&mut data, numel);201let output_tensor = Tensor {202dimensions: dimensions.iter().map(|&dim| dim as u32).collect(),203ty,204data,205};206let output = NamedTensor {207name: "output".to_string(),208tensor: output_tensor,209};210Ok(Some(vec![output]))211}212213// WITX-style compute with previously set inputs214None => {215if self.inputs.is_empty() {216return Err(BackendError::BackendAccess(anyhow::anyhow!(217"No inputs provided for inference"218)));219}220let inputs: Vec<tch::Tensor> = self221.inputs222.iter()223.enumerate()224.map(|(index, opt)| {225opt.as_ref()226.expect(&format!("Input tensor at index {} not set up", index))227.shallow_clone()228})229.collect();230// Perform forward pass231self.output = self.module.lock().unwrap().forward_ts(&inputs)?;232Ok(None)233}234}235}236237fn get_output(&mut self, _index: Id) -> Result<Tensor, BackendError> {238// Output index is not used. The forward_ts method to a model returns a single output tensor.239let numel = self.output.numel();240let dimensions = self.output.size();241let ty = self.output.kind().try_into()?;242let mut data = vec![0u8; kind_to_size(self.output.kind())? 
* numel];243self.output.copy_data_u8(&mut data, numel);244Ok(Tensor {245dimensions: dimensions.iter().map(|&dim| dim as u32).collect(),246ty,247data,248})249}250}251252fn map_execution_target_to_string(target: ExecutionTarget) -> Device {253match target {254ExecutionTarget::Cpu => Device::Cpu,255ExecutionTarget::Gpu => Device::Cuda(0),256ExecutionTarget::Tpu => {257unimplemented!("the pytorch backend does not yet support TPU execution targets")258}259}260}261262fn kind_to_size(kind: Kind) -> Result<usize, BackendError> {263match kind {264Kind::Float | Kind::Half => Ok(std::mem::size_of::<f32>()), // f16 is unstable https://github.com/rust-lang/rust/issues/116909265Kind::Double => Ok(std::mem::size_of::<f64>()),266Kind::Int => Ok(std::mem::size_of::<i32>()),267Kind::Uint8 => Ok(std::mem::size_of::<u8>()),268Kind::Int64 => Ok(std::mem::size_of::<i64>()),269_ => Err(BackendError::UnsupportedTensorType(format!("{:?}", kind))),270}271}272273/// Returns the PyTorch [`Kind`] from wasi-nn's [`TensorType`].274impl TryFrom<TensorType> for Kind {275type Error = BackendError;276277fn try_from(tensor_type: TensorType) -> Result<Self, Self::Error> {278match tensor_type {279TensorType::Fp16 => Ok(Kind::Half),280TensorType::Fp32 => Ok(Kind::Float),281TensorType::Fp64 => Ok(Kind::Double),282TensorType::U8 => Ok(Kind::Uint8),283TensorType::I32 => Ok(Kind::Int),284TensorType::I64 => Ok(Kind::Int64),285_ => Err(BackendError::UnsupportedTensorType(format!(286"{:?}",287tensor_type288))),289}290}291}292293/// Returns wasi-nn [`TensorType`] from PyTorch's [`Kind`].294impl TryFrom<Kind> for TensorType {295type Error = BackendError;296297fn try_from(kind: Kind) -> Result<Self, Self::Error> {298match kind {299Kind::Half => Ok(TensorType::Fp16),300Kind::Float => Ok(TensorType::Fp32),301Kind::Double => Ok(TensorType::Fp64),302Kind::Uint8 => Ok(TensorType::U8),303Kind::Int => Ok(TensorType::I32),304Kind::Int64 => Ok(TensorType::I64),305_ => 
Err(BackendError::UnsupportedTensorType(format!("{:?}", kind))),306}307}308}309310impl From<TchError> for BackendError {311fn from(e: TchError) -> Self {312BackendError::BackendAccess(anyhow::Error::new(e))313}314}315316317