// Source: GitHub repository bytecodealliance/wasmtime
// Path: blob/main/crates/wasi-nn/wit/wasi-nn.wit
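// NOTE(review): the text after `wasi:` was mangled to "[email protected]" by e-mail obfuscation
// during extraction. The name/version below is restored from the upstream wasi-nn package
// declaration -- TODO confirm against upstream before relying on it.
package wasi:nn@0.2.0-rc-2024-10-28;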

/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet)
/// capable of performing ML training. WebAssembly programs that want to use a host's ML
/// capabilities can access these capabilities through `wasi-nn`'s core abstractions: _graphs_ and
/// _tensors_. A user `load`s an ML model -- instantiated as a _graph_ -- to use in an ML _backend_.
/// Then, the user passes _tensor_ inputs to the _graph_, computes the inference, and retrieves the
/// _tensor_ outputs.
///
/// This example world shows how to use these primitives together: it imports the `tensor`,
/// `graph`, `inference`, and `errors` interfaces.
world ml {
    // The `tensor` resource plus its dimension, element-type, and data types.
    import tensor;
    // The `graph` resource and the `load` / `load-by-name` functions.
    import graph;
    // The `graph-execution-context` resource used to compute an inference.
    import inference;
    // The `error` resource and `error-code` enum returned by fallible functions.
    import errors;
}

/// All inputs and outputs to an ML inference are represented as `tensor`s.
interface tensor {
    /// The dimensions of a tensor.
    ///
    /// The array length matches the tensor rank and each element in the array describes the size of
    /// each dimension.
    type tensor-dimensions = list<u32>;

    /// The type of the elements in a tensor.
    enum tensor-type {
        FP16,
        FP32,
        FP64,
        BF16,
        U8,
        I32,
        I64
    }

    /// The tensor data, as raw bytes.
    ///
    /// Initially conceived as a sparse representation, each empty cell would be filled with zeros
    /// and the array length must match the product of all of the dimensions and the number of bytes
    /// in the type (e.g., a 2x2 tensor with 4-byte f32 elements would have a data array of length
    /// 16). Naturally, this representation requires some knowledge of how to lay out data in
    /// memory--e.g., using row-major ordering--and could perhaps be improved.
    type tensor-data = list<u8>;

    /// A tensor: dimensions, an element type, and the raw data bytes.
    resource tensor {
        /// Construct a tensor from its dimensions, element type, and raw data bytes.
        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);

        /// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
        /// containing a single value, use `[1]` for the tensor dimensions.
        dimensions: func() -> tensor-dimensions;

        /// Describe the type of element in the tensor (e.g., `f32`).
        ty: func() -> tensor-type;

        /// Return the tensor data.
        data: func() -> tensor-data;
    }
}

/// A `graph` is a loaded instance of a specific ML model (e.g., MobileNet) for a specific ML
/// framework (e.g., TensorFlow).
interface graph {
    use errors.{error};
    use tensor.{tensor};
    use inference.{graph-execution-context};

    /// An execution graph for performing inference (i.e., a model).
    resource graph {
        /// Initialize a `graph-execution-context` for this graph, or return an `error`.
        init-execution-context: func() -> result<graph-execution-context, error>;
    }

    /// Describes the encoding of the graph. This allows the API to be implemented by various
    /// backends that encode (i.e., serialize) their graph IR with different formats.
    enum graph-encoding {
        openvino,
        onnx,
        tensorflow,
        pytorch,
        tensorflowlite,
        ggml,
        autodetect,
    }

    /// Define where the graph should be executed.
    enum execution-target {
        cpu,
        gpu,
        tpu
    }

    /// One buffer of graph initialization data.
    ///
    /// `load` takes a list of these buffers because implementing backends may encode their
    /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
    type graph-builder = list<u8>;

    /// Load a `graph` from an opaque sequence of bytes to use for inference.
    ///
    /// `builder` holds the graph's buffers, `encoding` names the format they are serialized in,
    /// and `target` selects where the graph should be executed. Returns an `error` on failure.
    load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;

    /// Load a `graph` by name.
    ///
    /// How the host expects the names to be passed and how it stores the graphs for retrieval via
    /// this function is **implementation-specific**. This allows hosts to choose name schemes that
    /// range from simple to complex (e.g., URLs?) and caching mechanisms of various kinds.
    load-by-name: func(name: string) -> result<graph, error>;
}

/// An inference "session" is encapsulated by a `graph-execution-context`. This structure binds a
/// `graph` to input tensors before `compute`-ing an inference.
interface inference {
    use errors.{error};
    use tensor.{tensor};

    /// Identify a tensor by name; this is necessary to associate tensors to
    /// graph inputs and outputs. The tuple is `(name, tensor)`.
    type named-tensor = tuple<string, tensor>;

    /// Bind a `graph` to the input and output tensors for an inference.
    ///
    /// TODO: this may no longer be necessary in WIT
    /// (https://github.com/WebAssembly/wasi-nn/issues/43)
    resource graph-execution-context {
        /// Compute the inference on the given inputs.
        ///
        /// Takes a list of named input tensors and returns either a list of named tensors
        /// or an `error`.
        compute: func(inputs: list<named-tensor>) -> result<list<named-tensor>, error>;
    }
}

/// Error types: the `error-code` enum and the `error` resource that carries it.
///
/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
interface errors {
    /// The category of a failure.
    enum error-code {
        /// Caller module passed an invalid argument.
        invalid-argument,
        /// Invalid encoding.
        invalid-encoding,
        /// The operation timed out.
        timeout,
        /// Runtime error.
        runtime-error,
        /// Unsupported operation.
        unsupported-operation,
        /// Graph is too large.
        too-large,
        /// Graph not found.
        not-found,
        /// The operation is insecure or has insufficient privilege to be performed
        /// (e.g., cannot access a requested hardware feature).
        security,
        /// The operation failed for an unspecified reason.
        unknown
    }

    /// An error, exposing an `error-code` and a string value.
    resource error {
        /// Return the error code.
        code: func() -> error-code;

        /// Errors can be propagated with backend-specific status through a string value.
        data: func() -> string;
    }
}