pub mod backend;
mod registry;
pub mod wit;
pub mod witx;
use crate::backend::{BackendError, Id, NamedTensor as BackendNamedTensor};
use crate::wit::generated_::wasi::nn::tensor::TensorType;
use anyhow::anyhow;
use core::fmt;
pub use registry::{GraphRegistry, InMemoryRegistry};
use std::path::Path;
use std::sync::Arc;
/// Build the list of available backends and a [`Registry`] preloaded with graphs.
///
/// Each `(kind, path)` pair in `preload_graphs` is handled as follows:
/// 1. `kind` is parsed into a graph encoding.
/// 2. The backend from [`backend::list`] whose `encoding()` matches is selected.
/// 3. That backend's graph at `path` is loaded into an [`InMemoryRegistry`] through its
///    directory-loading support.
///
/// Returns every backend (including ones not used for preloading) together with the
/// populated registry.
///
/// # Errors
///
/// Returns an error if a `kind` fails to parse, if no backend has the parsed encoding,
/// if the matching backend does not support directory loading, or if loading the graph
/// from `path` fails.
pub fn preload(preload_graphs: &[(String, String)]) -> anyhow::Result<(Vec<Backend>, Registry)> {
    let mut backends = backend::list();
    let mut registry = InMemoryRegistry::new();
    for (kind, path) in preload_graphs {
        let encoding = kind.parse()?;
        // `ok_or_else` defers formatting the error until a lookup actually fails;
        // `ok_or(anyhow!(..))` would build (and allocate) both errors every iteration.
        let backend = backends
            .iter_mut()
            .find(|b| b.encoding() == encoding)
            .ok_or_else(|| anyhow!("unsupported backend: {}", kind))?
            .as_dir_loadable()
            .ok_or_else(|| anyhow!("{} does not support directory loading", kind))?;
        registry.load(backend, Path::new(path))?;
    }
    Ok((backends, Registry::from(registry)))
}
/// An owned, type-erased backend.
///
/// Dereferences (mutably and immutably) to the boxed [`backend::BackendInner`].
pub struct Backend(Box<dyn backend::BackendInner>);

impl std::ops::Deref for Backend {
    type Target = dyn backend::BackendInner;

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}

impl std::ops::DerefMut for Backend {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}

/// Box any concrete backend implementation into a [`Backend`].
impl<T: backend::BackendInner + 'static> From<T> for Backend {
    fn from(value: T) -> Self {
        let boxed: Box<dyn backend::BackendInner> = Box::new(value);
        Backend(boxed)
    }
}
/// A shared handle to a loaded backend graph.
///
/// The graph is held in an `Arc`, so cloning a `Graph` shares the same underlying
/// graph rather than copying it.
#[derive(Clone)]
pub struct Graph(Arc<dyn backend::BackendGraph>);

/// Convert a uniquely-owned boxed graph into a shareable [`Graph`].
impl From<Box<dyn backend::BackendGraph>> for Graph {
    fn from(value: Box<dyn backend::BackendGraph>) -> Self {
        let shared = Arc::from(value);
        Graph(shared)
    }
}

impl std::ops::Deref for Graph {
    type Target = dyn backend::BackendGraph;

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
/// A tensor value: its dimensions, element type, and raw byte contents.
#[derive(Clone, PartialEq)]
pub struct Tensor {
    /// The tensor's dimensions.
    pub dimensions: Vec<u32>,
    /// The tensor's element type.
    pub ty: TensorType,
    /// The tensor's contents as raw bytes.
    pub data: Vec<u8>,
}

impl fmt::Debug for Tensor {
    /// Formats `dimensions` and `ty` in full, but reports only the byte length of
    /// `data` so that large tensors do not flood debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Tensor {
            dimensions,
            ty,
            data,
        } = self;
        let mut out = f.debug_struct("Tensor");
        out.field("dimensions", dimensions);
        out.field("ty", ty);
        out.field("data (bytes)", &data.len());
        out.finish()
    }
}
/// An owned, type-erased backend execution context.
///
/// Dereferences (mutably and immutably) to the boxed
/// [`backend::BackendExecutionContext`].
pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);

/// Wrap an already-boxed backend execution context.
impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {
    fn from(value: Box<dyn backend::BackendExecutionContext>) -> Self {
        ExecutionContext(value)
    }
}

impl std::ops::Deref for ExecutionContext {
    type Target = dyn backend::BackendExecutionContext;

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}

impl std::ops::DerefMut for ExecutionContext {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}
/// An owned, type-erased [`GraphRegistry`].
///
/// Dereferences (mutably and immutably) to the boxed registry.
pub struct Registry(Box<dyn GraphRegistry>);

impl std::ops::Deref for Registry {
    type Target = dyn GraphRegistry;

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}

impl std::ops::DerefMut for Registry {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}

/// Box any concrete [`GraphRegistry`] implementation into a [`Registry`].
impl<T: GraphRegistry + 'static> From<T> for Registry {
    fn from(value: T) -> Self {
        Registry(Box::new(value))
    }
}
impl ExecutionContext {
    /// Forward `tensor` as the input identified by `id` to the backend context.
    pub fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
        let inner = &mut self.0;
        inner.set_input(id, tensor)
    }

    /// Run the backend computation without passing inputs inline.
    ///
    /// Any outputs the backend returns from this call are discarded; only success or
    /// failure is reported.
    pub fn compute(&mut self) -> Result<(), BackendError> {
        self.0.compute(None)?;
        Ok(())
    }

    /// Fetch the output identified by `id` from the backend context.
    pub fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
        let inner = &mut self.0;
        inner.get_output(id)
    }

    /// Run the backend computation with `inputs` passed inline, returning its outputs.
    ///
    /// If the backend reports no outputs (`None`), an empty vector is returned.
    pub fn compute_with_io(
        &mut self,
        inputs: Vec<BackendNamedTensor>,
    ) -> Result<Vec<BackendNamedTensor>, BackendError> {
        let outputs = self.0.compute(Some(inputs))?;
        Ok(outputs.unwrap_or_default())
    }
}
impl Tensor {
pub fn new(dimensions: Vec<u32>, ty: TensorType, data: Vec<u8>) -> Self {
Self {
dimensions,
ty,
data,
}
}
}