GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/src/backend/pytorch.rs
//! Implements a `wasi-nn` [`BackendInner`] using PyTorch.
//!
use super::{
    BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
    NamedTensor,
};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use std::path::Path;
use std::sync::{Arc, Mutex};
use tch::{CModule, Device, Kind, TchError, Tensor as TchTensor};

#[derive(Default)]
pub struct PytorchBackend();
unsafe impl Send for PytorchBackend {}
unsafe impl Sync for PytorchBackend {}

impl BackendInner for PytorchBackend {
    fn encoding(&self) -> GraphEncoding {
        GraphEncoding::Pytorch
    }

    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
        if builders.len() != 1 {
            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());
        }
        // Load the serialized TorchScript module from the single builder.
        let mut saved_module = builders[0];

        // Load the saved module onto the requested device.
        let mut compiled_module = CModule::load_data_on_device(
            &mut saved_module,
            map_execution_target_to_string(target),
        )?;

        // Put the module in evaluation mode for inference; the default mode is training.
        compiled_module.f_set_eval()?;

        let graph = PytorchGraph {
            module: Arc::new(Mutex::new(compiled_module)),
            target,
        };
        let box_: Box<dyn BackendGraph> = Box::new(graph);
        Ok(box_.into())
    }

    fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {
        Some(self)
    }
}
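
// Illustrative only: a minimal sketch of driving `load` directly with
// TorchScript bytes. The `model.pt` path is a hypothetical fixture, not part
// of this module, so the test is ignored by default.
#[cfg(test)]
mod load_sketch {
    use super::*;

    #[test]
    #[ignore = "requires a serialized TorchScript module at model.pt"]
    fn load_single_builder() {
        let bytes = std::fs::read("model.pt").expect("hypothetical model file");
        let mut backend = PytorchBackend::default();
        // `load` expects exactly one builder containing the serialized module.
        let graph = backend.load(&[bytes.as_slice()], ExecutionTarget::Cpu);
        assert!(graph.is_ok());
        // Passing zero (or more than one) builders is rejected.
        assert!(backend.load(&[], ExecutionTarget::Cpu).is_err());
    }
}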

impl BackendFromDir for PytorchBackend {
    fn load_from_dir(
        &mut self,
        path: &Path,
        target: ExecutionTarget,
    ) -> Result<Graph, BackendError> {
        // Load the module from a `model.pt` file in the given directory.
        let compiled_module = CModule::load_on_device(
            path.join("model.pt"),
            map_execution_target_to_string(target),
        )?;
        let graph = PytorchGraph {
            module: Arc::new(Mutex::new(compiled_module)),
            target,
        };
        let box_: Box<dyn BackendGraph> = Box::new(graph);
        Ok(box_.into())
    }
}

struct PytorchGraph {
    module: Arc<Mutex<tch::CModule>>,
    target: ExecutionTarget,
}

unsafe impl Send for PytorchGraph {}
unsafe impl Sync for PytorchGraph {}

impl BackendGraph for PytorchGraph {
    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
        let box_: Box<dyn BackendExecutionContext> = Box::new(PytorchExecutionContext {
            module: self.module.clone(),
            inputs: Vec::new(),
            output: TchTensor::new(),
            id_type: None,
            target: self.target,
        });

        Ok(box_.into())
    }
}

struct PytorchExecutionContext {
    module: Arc<Mutex<tch::CModule>>,
    inputs: Vec<Option<tch::Tensor>>,
    output: tch::Tensor,
    id_type: Option<Id>,
    target: ExecutionTarget,
}

unsafe impl Sync for PytorchExecutionContext {}

/// `set_input` supports multiple positional parameters with `Id::Index` and a single named parameter with `Id::Name`.
/// `set_input` may be removed in the future, with the `compute` method taking a list of named parameters instead.
/// See [PR #77](https://github.com/WebAssembly/wasi-nn/pull/77), at which point multiple named `Tensor` inputs are planned to be supported in the PyTorch backend.
impl BackendExecutionContext for PytorchExecutionContext {
    fn set_input(&mut self, id: Id, input_tensor: &Tensor) -> Result<(), BackendError> {
        let kind = input_tensor.ty.try_into()?;
        let dimensions = input_tensor
            .dimensions
            .iter()
            .map(|&dim| dim as i64)
            .collect::<Vec<_>>();
        let tensor = TchTensor::from_data_size(&input_tensor.data, &dimensions, kind)
            .to_device(map_execution_target_to_string(self.target));
        match id {
            Id::Index(i) => {
                // Reject mixing positional and named ids on the same context.
                if let Some(Id::Name(_)) = self.id_type {
                    return Err(BackendError::BackendAccess(anyhow::anyhow!(
                        "Cannot mix u32 and str indexes"
                    )));
                }
                // Record the id type on first use.
                if self.id_type.is_none() {
                    self.id_type = Some(Id::Index(0)); // The u32 value is a placeholder; only the variant matters.
                }
                let i = i as usize;
                if i >= self.inputs.len() {
                    self.inputs.resize_with(i + 1, || None);
                }
                self.inputs[i] = Some(tensor);
                Ok(())
            }
            Id::Name(_) => {
                // Reject mixing positional and named ids on the same context.
                if let Some(Id::Index(_)) = self.id_type {
                    return Err(BackendError::BackendAccess(anyhow::anyhow!(
                        "Cannot mix u32 and str indexes"
                    )));
                }
                // Record the id type on first use.
                if self.id_type.is_none() {
                    self.id_type = Some(Id::Name(String::new())); // The string value is a placeholder; only the variant matters.
                }
                if self.inputs.get(0).is_some() {
                    return Err(BackendError::BackendAccess(anyhow::anyhow!(
                        "The pytorch backend does not support multiple named inputs"
                    )));
                } else {
                    self.inputs.push(Some(tensor));
                }
                Ok(())
            }
        }
    }
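
    // Illustrative only: how the id-type guard above behaves, assuming some
    // already-created context `ctx` and tensor `t` (both hypothetical):
    //
    //     ctx.set_input(Id::Index(0), &t)?;   // ok: first positional input
    //     ctx.set_input(Id::Index(1), &t)?;   // ok: more positional inputs
    //     assert!(ctx.set_input(Id::Name("x".to_string()), &t).is_err()); // mixing is rejected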

    fn compute(
        &mut self,
        inputs: Option<Vec<NamedTensor>>,
    ) -> Result<Option<Vec<NamedTensor>>, BackendError> {
        match inputs {
            // WIT-style compute with named tensors.
            Some(inputs) => {
                self.inputs.clear();
                self.id_type = None;
                for input in &inputs {
                    let kind = input.tensor.ty.try_into()?;
                    let dimensions = input
                        .tensor
                        .dimensions
                        .iter()
                        .map(|&dim| dim as i64)
                        .collect::<Vec<_>>();

                    let tensor = TchTensor::from_data_size(&input.tensor.data, &dimensions, kind)
                        .to_device(map_execution_target_to_string(self.target));
                    self.inputs.push(Some(tensor));

                    // Record the id type as `Name` since named tensors are in use.
                    if self.id_type.is_none() {
                        self.id_type = Some(Id::Name(String::new()));
                    }
                }
                // Run the forward pass.
                let inputs: Vec<tch::Tensor> = self
                    .inputs
                    .iter()
                    .enumerate()
                    .map(|(index, opt)| {
                        opt.as_ref()
                            .expect(&format!("Input tensor at index {} not set up", index))
                            .shallow_clone()
                    })
                    .collect();
                self.output = self.module.lock().unwrap().forward_ts(&inputs)?;
                let numel = self.output.numel();
                let dimensions = self.output.size();
                let ty = self.output.kind().try_into()?;
                let mut data = vec![0u8; kind_to_size(self.output.kind())? * numel];
                self.output.copy_data_u8(&mut data, numel);
                let output_tensor = Tensor {
                    dimensions: dimensions.iter().map(|&dim| dim as u32).collect(),
                    ty,
                    data,
                };
                let output = NamedTensor {
                    name: "output".to_string(),
                    tensor: output_tensor,
                };
                Ok(Some(vec![output]))
            }

            // WITX-style compute with previously set inputs.
            None => {
                if self.inputs.is_empty() {
                    return Err(BackendError::BackendAccess(anyhow::anyhow!(
                        "No inputs provided for inference"
                    )));
                }
                let inputs: Vec<tch::Tensor> = self
                    .inputs
                    .iter()
                    .enumerate()
                    .map(|(index, opt)| {
                        opt.as_ref()
                            .expect(&format!("Input tensor at index {} not set up", index))
                            .shallow_clone()
                    })
                    .collect();
                // Run the forward pass.
                self.output = self.module.lock().unwrap().forward_ts(&inputs)?;
                Ok(None)
            }
        }
    }
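
    // Illustrative only: the WIT-style path above, assuming a hypothetical
    // context `ctx` and a 1x2 f32 input:
    //
    //     let input = NamedTensor {
    //         name: "input".to_string(),
    //         tensor: Tensor {
    //             dimensions: vec![1, 2],
    //             ty: TensorType::Fp32,
    //             data: vec![0; 8], // two f32 values as raw bytes
    //         },
    //     };
    //     let outputs = ctx.compute(Some(vec![input]))?; // yields one tensor named "output"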

    fn get_output(&mut self, _index: Id) -> Result<Tensor, BackendError> {
        // The output index is unused: `forward_ts` on a module returns a single output tensor.
        let numel = self.output.numel();
        let dimensions = self.output.size();
        let ty = self.output.kind().try_into()?;
        let mut data = vec![0u8; kind_to_size(self.output.kind())? * numel];
        self.output.copy_data_u8(&mut data, numel);
        Ok(Tensor {
            dimensions: dimensions.iter().map(|&dim| dim as u32).collect(),
            ty,
            data,
        })
    }
}

fn map_execution_target_to_string(target: ExecutionTarget) -> Device {
    // Despite its name, this maps a wasi-nn execution target to a `tch::Device`.
    match target {
        ExecutionTarget::Cpu => Device::Cpu,
        ExecutionTarget::Gpu => Device::Cuda(0),
        ExecutionTarget::Tpu => {
            unimplemented!("the pytorch backend does not yet support TPU execution targets")
        }
    }
}

fn kind_to_size(kind: Kind) -> Result<usize, BackendError> {
    match kind {
        // f16 is unstable in Rust (https://github.com/rust-lang/rust/issues/116909),
        // so `Half` is sized as an `f32`.
        Kind::Float | Kind::Half => Ok(std::mem::size_of::<f32>()),
        Kind::Double => Ok(std::mem::size_of::<f64>()),
        Kind::Int => Ok(std::mem::size_of::<i32>()),
        Kind::Uint8 => Ok(std::mem::size_of::<u8>()),
        Kind::Int64 => Ok(std::mem::size_of::<i64>()),
        _ => Err(BackendError::UnsupportedTensorType(format!("{:?}", kind))),
    }
}
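
// Illustrative only: the byte widths `kind_to_size` reports. Note that
// `Kind::Half` is sized as an `f32` per the comment above, and unsupported
// kinds (e.g. `Kind::Bool`) are rejected.
#[cfg(test)]
mod kind_size_sketch {
    use super::*;

    #[test]
    fn reported_sizes() {
        assert_eq!(kind_to_size(Kind::Uint8).unwrap(), 1);
        assert_eq!(kind_to_size(Kind::Int).unwrap(), 4);
        assert_eq!(kind_to_size(Kind::Float).unwrap(), 4);
        assert_eq!(kind_to_size(Kind::Half).unwrap(), 4);
        assert_eq!(kind_to_size(Kind::Double).unwrap(), 8);
        assert_eq!(kind_to_size(Kind::Int64).unwrap(), 8);
        assert!(kind_to_size(Kind::Bool).is_err());
    }
}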

/// Returns the PyTorch [`Kind`] from wasi-nn's [`TensorType`].
impl TryFrom<TensorType> for Kind {
    type Error = BackendError;

    fn try_from(tensor_type: TensorType) -> Result<Self, Self::Error> {
        match tensor_type {
            TensorType::Fp16 => Ok(Kind::Half),
            TensorType::Fp32 => Ok(Kind::Float),
            TensorType::Fp64 => Ok(Kind::Double),
            TensorType::U8 => Ok(Kind::Uint8),
            TensorType::I32 => Ok(Kind::Int),
            TensorType::I64 => Ok(Kind::Int64),
            _ => Err(BackendError::UnsupportedTensorType(format!(
                "{:?}",
                tensor_type
            ))),
        }
    }
}

/// Returns the wasi-nn [`TensorType`] from PyTorch's [`Kind`].
impl TryFrom<Kind> for TensorType {
    type Error = BackendError;

    fn try_from(kind: Kind) -> Result<Self, Self::Error> {
        match kind {
            Kind::Half => Ok(TensorType::Fp16),
            Kind::Float => Ok(TensorType::Fp32),
            Kind::Double => Ok(TensorType::Fp64),
            Kind::Uint8 => Ok(TensorType::U8),
            Kind::Int => Ok(TensorType::I32),
            Kind::Int64 => Ok(TensorType::I64),
            _ => Err(BackendError::UnsupportedTensorType(format!("{:?}", kind))),
        }
    }
}
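
// Illustrative only: the two conversions above are inverses for the supported
// types. `matches!` is used to avoid assuming `PartialEq`/`Debug` derives on
// the wit-generated `TensorType`.
#[cfg(test)]
mod conversion_sketch {
    use super::*;

    #[test]
    fn tensor_type_round_trips() {
        assert!(matches!(Kind::try_from(TensorType::Fp32), Ok(Kind::Float)));
        assert!(matches!(TensorType::try_from(Kind::Float), Ok(TensorType::Fp32)));
        assert!(matches!(Kind::try_from(TensorType::I64), Ok(Kind::Int64)));
        assert!(matches!(TensorType::try_from(Kind::Int64), Ok(TensorType::I64)));
        // Unsupported kinds fail to convert.
        assert!(TensorType::try_from(Kind::Bool).is_err());
    }
}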

impl From<TchError> for BackendError {
    fn from(e: TchError) -> Self {
        BackendError::BackendAccess(anyhow::Error::new(e))
    }
}