use crate::{Backend, Registry};
use anyhow::anyhow;
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt, str::FromStr};
use wasmtime::component::{HasData, Resource, ResourceTable};
/// Host-side state for wasi-nn: the available [`Backend`]s, keyed by the
/// graph encoding each one handles, plus a [`Registry`] of graphs the
/// embedder preloaded for lookup by name.
pub struct WasiNnCtx {
    // One backend per encoding; `load` dispatches on the guest-requested encoding.
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    // Preloaded graphs, looked up by `load_by_name`.
    pub(crate) registry: Registry,
}
impl WasiNnCtx {
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
Self { backends, registry }
}
}
/// A mutable view over everything needed to implement the wasi-nn host
/// traits: the [`WasiNnCtx`] plus the component [`ResourceTable`] in which
/// graph, tensor, execution-context and error resources are stored.
pub struct WasiNnView<'a> {
    ctx: &'a mut WasiNnCtx,
    table: &'a mut ResourceTable,
}
impl<'a> WasiNnView<'a> {
pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
Self { ctx, table }
}
}
/// A host-side wasi-nn error: a wit-level [`ErrorCode`] paired with the
/// underlying `anyhow` error that carries the human-readable message.
#[derive(Debug)]
pub struct Error {
    code: ErrorCode,
    data: anyhow::Error,
}
/// Log an error, push it into `$self`'s resource table, and early-return it
/// to the guest as `Ok(Err(<resource>))` from the enclosing host function.
/// `$self` must have a `table` field accepting `push`.
macro_rules! bail {
    ($self:ident, $code:expr, $data:expr) => {
        let e = Error {
            code: $code,
            data: $data.into(),
        };
        tracing::error!("failure: {e:?}");
        // If pushing into the table itself fails, `?` converts that into an
        // `ErrorCode::Trap` via `From<ResourceTableError>` and propagates it.
        let r = $self.table.push(e)?;
        return Ok(Err(r));
    };
}
impl From<wasmtime::component::ResourceTableError> for Error {
fn from(error: wasmtime::component::ResourceTableError) -> Self {
Self {
code: ErrorCode::Trap,
data: error.into(),
}
}
}
/// Error codes attached to host-side wasi-nn errors. All variants except
/// `Trap` map onto the generated `wasi:nn/errors` error codes (see the
/// `HostError::code` implementation); `Trap` is host-internal and is
/// surfaced as a wasmtime trap instead of being returned to the guest.
#[derive(Debug)]
pub enum ErrorCode {
    InvalidArgument,
    // Used when no backend is registered for the requested graph encoding.
    InvalidEncoding,
    Timeout,
    // Used when a backend fails to load a graph or run a computation.
    RuntimeError,
    UnsupportedOperation,
    TooLarge,
    // Used when a named graph is absent from the registry.
    NotFound,
    // Host-internal failure (e.g. resource-table errors); traps the guest.
    Trap,
}
/// Bindings generated from the `ml` world in `wit/wasi-nn.wit`, with each
/// wit resource mapped onto this crate's host type and all interface errors
/// routed through [`Error`] via `trappable_error_type`.
pub(crate) mod generated_ {
    wasmtime::component::bindgen!({
        world: "ml",
        path: "wit/wasi-nn.wit",
        with: {
            "wasi:nn/graph/graph": crate::Graph,
            "wasi:nn/tensor/tensor": crate::Tensor,
            "wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
            "wasi:nn/errors/error": super::Error,
        },
        imports: { default: trappable },
        trappable_error_type: {
            "wasi:nn/errors/error" => super::Error,
        },
    });
}
use generated_::wasi::nn::{self as generated};
/// Grouped re-exports of the generated wit types for embedder convenience.
pub mod types {
    use super::generated;
    pub use generated::errors::Error;
    pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
    pub use generated::inference::GraphExecutionContext;
    pub use generated::tensor::{Tensor, TensorType};
}
pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
pub use generated::inference::{GraphExecutionContext, NamedTensor};
pub use generated::tensor::{Tensor, TensorData, TensorDimensions, TensorType};
pub use generated_::Ml as ML;
/// Add all wasi-nn host interfaces to the linker `l`, using `f` to project
/// a [`WasiNnView`] out of the store's data `T`.
pub fn add_to_linker<T: 'static>(
    l: &mut wasmtime::component::Linker<T>,
    f: fn(&mut T) -> WasiNnView<'_>,
) -> anyhow::Result<()> {
    // Wire up each generated interface: graph, tensor, inference, errors.
    generated::graph::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::tensor::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::inference::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::errors::add_to_linker::<_, HasWasiNnView>(l, f)?;
    Ok(())
}
/// Marker type connecting the generated bindings to [`WasiNnView`] through
/// wasmtime's [`HasData`] mechanism.
struct HasWasiNnView;
impl HasData for HasWasiNnView {
    type Data<'a> = WasiNnView<'a>;
}
impl generated::graph::Host for WasiNnView<'_> {
    /// Load a graph from raw `builders` using the backend registered for
    /// `encoding`, targeting `target`. On success the graph is stored in the
    /// resource table and its handle returned; backend failures become
    /// guest-visible `RuntimeError`s and a missing backend an
    /// `InvalidEncoding` error.
    fn load(
        &mut self,
        builders: Vec<GraphBuilder>,
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        tracing::debug!("load {encoding:?} {target:?}");
        if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
            // Backends take borrowed byte slices; borrow each builder.
            let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
            match backend.load(&slices, target) {
                Ok(graph) => {
                    let graph = self.table.push(graph)?;
                    Ok(Ok(graph))
                }
                Err(error) => {
                    bail!(self, ErrorCode::RuntimeError, error);
                }
            }
        } else {
            bail!(
                self,
                ErrorCode::InvalidEncoding,
                anyhow!("unable to find a backend for this encoding")
            );
        }
    }
    /// Look up a preloaded graph by `name` in the registry; returns a
    /// `NotFound` error to the guest when no such graph exists.
    fn load_by_name(
        &mut self,
        name: String,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        tracing::debug!("load by name {name:?}");
        let registry = &self.ctx.registry;
        if let Some(graph) = registry.get(&name) {
            let graph = graph.clone();
            let graph = self.table.push(graph)?;
            Ok(Ok(graph))
        } else {
            bail!(
                self,
                ErrorCode::NotFound,
                anyhow!("failed to find graph with name: {name}")
            );
        }
    }
}
impl generated::graph::HostGraph for WasiNnView<'_> {
    /// Create an execution context for the given graph resource, storing it
    /// in the table; backend failures become guest-visible `RuntimeError`s.
    fn init_execution_context(
        &mut self,
        graph: Resource<Graph>,
    ) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
        tracing::debug!("initialize execution context");
        let graph = self.table.get(&graph)?;
        match graph.init_execution_context() {
            Ok(exec_context) => {
                let exec_context = self.table.push(exec_context)?;
                Ok(Ok(exec_context))
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }
    /// Remove the graph resource from the table when the guest drops it.
    fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
        self.table.delete(graph)?;
        Ok(())
    }
}
impl generated::inference::HostGraphExecutionContext for WasiNnView<'_> {
    /// Run inference: resolve each input tensor resource into an owned
    /// backend tensor, hand the named tensors to the backend, then push
    /// each output tensor back into the table as a new guest resource.
    fn compute(
        &mut self,
        exec_context: Resource<GraphExecutionContext>,
        inputs: Vec<NamedTensor>,
    ) -> wasmtime::Result<Result<Vec<NamedTensor>, Resource<Error>>> {
        tracing::debug!("compute with {} inputs", inputs.len());
        // Resolve guest tensor handles up front; size is known, so reserve.
        let mut named_tensors = Vec::with_capacity(inputs.len());
        for (name, tensor_resource) in inputs.iter() {
            let tensor = self.table.get(tensor_resource)?;
            named_tensors.push(crate::backend::NamedTensor {
                name: name.clone(),
                tensor: tensor.clone(),
            });
        }
        let exec_context = self.table.get_mut(&exec_context)?;
        match exec_context.compute_with_io(named_tensors) {
            Ok(named_tensors) => {
                // Turn each output tensor into a guest-visible resource; a
                // table failure on any push short-circuits the whole result.
                let result: Result<Vec<_>, _> = named_tensors
                    .into_iter()
                    .map(|crate::backend::NamedTensor { name, tensor }| {
                        self.table.push(tensor).map(|resource| (name, resource))
                    })
                    .collect();
                match result {
                    Ok(tuples) => Ok(Ok(tuples)),
                    Err(error) => {
                        bail!(self, ErrorCode::RuntimeError, error);
                    }
                }
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }
    /// Remove the execution-context resource from the table when dropped.
    fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
        self.table.delete(exec_context)?;
        Ok(())
    }
}
impl generated::tensor::HostTensor for WasiNnView<'_> {
    /// Construct a tensor resource from its dimensions, element type, and
    /// raw byte data, storing it in the resource table.
    fn new(
        &mut self,
        dimensions: TensorDimensions,
        ty: TensorType,
        data: TensorData,
    ) -> wasmtime::Result<Resource<Tensor>> {
        let tensor = Tensor {
            dimensions,
            ty,
            data,
        };
        Ok(self.table.push(tensor)?)
    }
    /// Return a copy of the tensor's dimensions.
    fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
        Ok(self.table.get(&tensor)?.dimensions.clone())
    }
    /// Return the tensor's element type.
    fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
        Ok(self.table.get(&tensor)?.ty)
    }
    /// Return a copy of the tensor's raw byte data.
    fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
        Ok(self.table.get(&tensor)?.data.clone())
    }
    /// Remove the tensor resource from the table when the guest drops it.
    fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
        self.table.delete(tensor)?;
        Ok(())
    }
}
impl generated::errors::HostError for WasiNnView<'_> {
fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<generated::errors::ErrorCode> {
let error = self.table.get(&error)?;
match error.code {
ErrorCode::InvalidArgument => Ok(generated::errors::ErrorCode::InvalidArgument),
ErrorCode::InvalidEncoding => Ok(generated::errors::ErrorCode::InvalidEncoding),
ErrorCode::Timeout => Ok(generated::errors::ErrorCode::Timeout),
ErrorCode::RuntimeError => Ok(generated::errors::ErrorCode::RuntimeError),
ErrorCode::UnsupportedOperation => {
Ok(generated::errors::ErrorCode::UnsupportedOperation)
}
ErrorCode::TooLarge => Ok(generated::errors::ErrorCode::TooLarge),
ErrorCode::NotFound => Ok(generated::errors::ErrorCode::NotFound),
ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
}
}
fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
let error = self.table.get(&error)?;
Ok(error.data.to_string())
}
fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
self.table.delete(error)?;
Ok(())
}
}
impl generated::errors::Host for WasiNnView<'_> {
    /// Decide how a host [`Error`] reaches the guest: `Trap` propagates as
    /// a wasmtime error (trapping the guest); everything else is returned
    /// as an ordinary wit-level error.
    fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
        match err.code {
            ErrorCode::Trap => Err(err.data),
            _ => Ok(err),
        }
    }
}
// These interfaces define no free functions beyond the per-resource methods
// implemented above, so their `Host` traits are empty.
impl generated::tensor::Host for WasiNnView<'_> {}
impl generated::inference::Host for WasiNnView<'_> {}
impl Hash for generated::graph::GraphEncoding {
    /// Hash by variant discriminant.
    ///
    /// All `GraphEncoding` variants are unit variants, so two values are
    /// equal exactly when they share a discriminant — this stays consistent
    /// with the generated equality required by `HashMap` — while avoiding
    /// the per-call `String` allocation of the previous `to_string()`-based
    /// implementation.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::mem::discriminant(self).hash(state)
    }
}
impl fmt::Display for generated::graph::GraphEncoding {
    /// Render the encoding as its lowercase canonical name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use generated::graph::GraphEncoding::*;
        let name = match self {
            Openvino => "openvino",
            Onnx => "onnx",
            Pytorch => "pytorch",
            Tensorflow => "tensorflow",
            Tensorflowlite => "tensorflowlite",
            Autodetect => "autodetect",
            Ggml => "ggml",
        };
        f.write_str(name)
    }
}
impl FromStr for generated::graph::GraphEncoding {
    type Err = GraphEncodingParseError;
    /// Parse an encoding name, case-insensitively.
    ///
    /// Accepts every name that the `Display` impl produces, so parsing
    /// round-trips with formatting; previously `"ggml"` was formattable
    /// but not parseable.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "openvino" => Ok(generated::graph::GraphEncoding::Openvino),
            "onnx" => Ok(generated::graph::GraphEncoding::Onnx),
            "pytorch" => Ok(generated::graph::GraphEncoding::Pytorch),
            "tensorflow" => Ok(generated::graph::GraphEncoding::Tensorflow),
            "tensorflowlite" => Ok(generated::graph::GraphEncoding::Tensorflowlite),
            "autodetect" => Ok(generated::graph::GraphEncoding::Autodetect),
            "ggml" => Ok(generated::graph::GraphEncoding::Ggml),
            _ => Err(GraphEncodingParseError(s.into())),
        }
    }
}
/// Error returned when a string does not name a known graph encoding;
/// carries the rejected input for diagnostics.
#[derive(Debug)]
pub struct GraphEncodingParseError(String);
impl fmt::Display for GraphEncodingParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let GraphEncodingParseError(input) = self;
        write!(f, "unknown graph encoding: {input}")
    }
}
impl std::error::Error for GraphEncodingParseError {}