Path: blob/main/crates/wasi-nn/tests/test-programs.rs
//! Run the wasi-nn tests in `crates/test-programs`.
//!
//! It may be difficult to run all tests on all platforms; we check the
//! pre-requisites for each test dynamically (see [`check`]). Using
//! `libtest-mimic` then allows us to dynamically ignore tests that cannot run
//! on the current machine.
//!
//! There are two modes these tests run in:
//! - "ignore if unavailable" mode: if the checks for a test fail (e.g., the
//!   backend is not installed, test artifacts cannot be downloaded, we're on
//!   the wrong platform), the test is ignored.
//! - "fail if unavailable" mode: when the `CI` or `FORCE_WASINN_TEST_CHECK`
//!   environment variable is set, any checks that fail cause the test to fail
//!   early.

mod check;
mod exec;

use anyhow::Result;
use libtest_mimic::{Arguments, Trial};
use std::{borrow::Cow, env};
use test_programs_artifacts::*;
use wasmtime_wasi_nn::{Backend, backend};

fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    if cfg!(miri) {
        return Ok(());
    }

    // Gather a list of the test-program names.
    let mut programs = Vec::new();
    macro_rules! add_to_list {
        ($name:ident) => {
            programs.push(stringify!($name));
        };
    }
    foreach_nn!(add_to_list);

    // Make ignored tests turn into failures.
    let error_on_failed_check =
        env::var_os("CI").is_some() || env::var_os("FORCE_WASINN_TEST_CHECK").is_some();

    // Inform `libtest-mimic` how to run each test program.
    let arguments = Arguments::from_args();
    let mut trials = Vec::new();
    for program in programs {
        // Either ignore the test if it cannot run (i.e., downgrade `Fail` to
        // `Ignore`) or preemptively fail it if `error_on_failed_check` is set.
        let (run_test, mut check) = check_test_program(program);
        if !error_on_failed_check {
            check = check.downgrade_failure(); // Downgrade `Fail` to `Ignore`.
        }
        let should_ignore = check.is_ignore();
        if arguments.nocapture && should_ignore {
            println!("> ignoring {program}: {}", check.reason());
        }
        let trial = Trial::test(program, move || {
            run_test().map_err(|e| format!("{e:?}").into())
        })
        .with_ignored_flag(should_ignore);
        trials.push(trial);
    }

    // Run the tests.
    libtest_mimic::run(&arguments, trials).exit()
}

/// Return the test program to run and a check that must pass for the test to
/// run.
fn check_test_program(name: &str) -> (fn() -> Result<()>, IgnoreCheck) {
    match name {
        // Legacy WITX-based tests:
        "nn_witx_image_classification_openvino" => (
            nn_witx_image_classification_openvino,
            IgnoreCheck::for_openvino(),
        ),
        "nn_witx_image_classification_openvino_named" => (
            nn_witx_image_classification_openvino_named,
            IgnoreCheck::for_openvino(),
        ),
        "nn_witx_image_classification_onnx" => {
            (nn_witx_image_classification_onnx, IgnoreCheck::for_onnx())
        }
        "nn_witx_image_classification_winml_named" => (
            nn_witx_image_classification_winml_named,
            IgnoreCheck::for_winml(),
        ),
        "nn_witx_image_classification_pytorch" => (
            nn_witx_image_classification_pytorch,
            IgnoreCheck::for_pytorch(),
        ),
        // WIT-based tests:
        "nn_wit_image_classification_openvino" => (
            nn_wit_image_classification_openvino,
            IgnoreCheck::for_openvino(),
        ),
        "nn_wit_image_classification_openvino_named" => (
            nn_wit_image_classification_openvino_named,
            IgnoreCheck::for_openvino(),
        ),
        "nn_wit_image_classification_onnx" => {
            (nn_wit_image_classification_onnx, IgnoreCheck::for_onnx())
        }
        "nn_wit_image_classification_winml_named" => (
            nn_wit_image_classification_winml_named,
            IgnoreCheck::for_winml(),
        ),
        "nn_wit_image_classification_pytorch" => (
            nn_wit_image_classification_pytorch,
            IgnoreCheck::for_pytorch(),
        ),
        _ => panic!("unknown test program: {name} (add to this `match`)"),
    }
}

fn nn_witx_image_classification_openvino() -> Result<()> {
    check::openvino::is_installed()?;
    check::openvino::are_artifacts_available()?;
    let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
    exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO, backend, false)
}

fn nn_witx_image_classification_openvino_named() -> Result<()> {
    check::openvino::is_installed()?;
    check::openvino::are_artifacts_available()?;
    let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
    exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO_NAMED, backend, true)
}

#[cfg(feature = "onnx")]
fn nn_witx_image_classification_onnx() -> Result<()> {
    check::onnx::are_artifacts_available()?;
    let backend = Backend::from(backend::onnx::OnnxBackend::default());
    exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false)
}
#[cfg(not(feature = "onnx"))]
fn nn_witx_image_classification_onnx() -> Result<()> {
    anyhow::bail!("this test requires the `onnx` feature")
}

#[cfg(all(feature = "winml", target_os = "windows"))]
fn nn_witx_image_classification_winml_named() -> Result<()> {
    check::winml::is_available()?;
    check::onnx::are_artifacts_available()?;
    let backend = Backend::from(backend::winml::WinMLBackend::default());
    exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false)
}
#[cfg(not(all(feature = "winml", target_os = "windows")))]
fn nn_witx_image_classification_winml_named() -> Result<()> {
    anyhow::bail!("this test requires the `winml` feature and only runs on windows")
}

#[cfg(feature = "pytorch")]
fn nn_witx_image_classification_pytorch() -> Result<()> {
    check::pytorch::are_artifacts_available()?;
    let backend = Backend::from(backend::pytorch::PytorchBackend::default());
    exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_PYTORCH, backend, false)
}
#[cfg(not(feature = "pytorch"))]
fn nn_witx_image_classification_pytorch() -> Result<()> {
    anyhow::bail!("this test requires the `pytorch` feature")
}

fn nn_wit_image_classification_openvino() -> Result<()> {
    check::openvino::is_installed()?;
    check::openvino::are_artifacts_available()?;
    let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
    exec::wit::run(
        NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_COMPONENT,
        backend,
        false,
    )
}

fn nn_wit_image_classification_openvino_named() -> Result<()> {
    check::openvino::is_installed()?;
    check::openvino::are_artifacts_available()?;
    let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
    exec::wit::run(
        NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_NAMED_COMPONENT,
        backend,
        true,
    )
}

#[cfg(feature = "onnx")]
fn nn_wit_image_classification_onnx() -> Result<()> {
    check::onnx::are_artifacts_available()?;
    let backend = Backend::from(backend::onnx::OnnxBackend::default());
    exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false)
}
#[cfg(not(feature = "onnx"))]
fn nn_wit_image_classification_onnx() -> Result<()> {
    anyhow::bail!("this test requires the `onnx` feature")
}

#[cfg(feature = "pytorch")]
fn nn_wit_image_classification_pytorch() -> Result<()> {
    check::pytorch::are_artifacts_available()?;
    let backend = Backend::from(backend::pytorch::PytorchBackend::default());
    exec::wit::run(
        NN_WIT_IMAGE_CLASSIFICATION_PYTORCH_COMPONENT,
        backend,
        false,
    )
}
#[cfg(not(feature = "pytorch"))]
fn nn_wit_image_classification_pytorch() -> Result<()> {
    anyhow::bail!("this test requires the `pytorch` feature")
}

#[cfg(all(feature = "winml", target_os = "windows"))]
fn nn_wit_image_classification_winml_named() -> Result<()> {
    check::winml::is_available()?;
    check::onnx::are_artifacts_available()?;
    let backend = Backend::from(backend::winml::WinMLBackend::default());
    exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false)
}
#[cfg(not(all(feature = "winml", target_os = "windows")))]
fn nn_wit_image_classification_winml_named() -> Result<()> {
    anyhow::bail!("this test requires the `winml` feature and only runs on windows")
}

/// Helper for keeping track of what tests should do when pre-test checks fail.
#[derive(Clone)]
enum IgnoreCheck {
    Run,
    Ignore(Cow<'static, str>),
    Fail(Cow<'static, str>),
}

impl IgnoreCheck {
    fn reason(&self) -> &str {
        match self {
            IgnoreCheck::Run => panic!("cannot get reason for `Run`"),
            IgnoreCheck::Ignore(reason) => reason,
            IgnoreCheck::Fail(reason) => reason,
        }
    }

    fn downgrade_failure(self) -> Self {
        if let IgnoreCheck::Fail(reason) = self {
            IgnoreCheck::Ignore(reason)
        } else {
            self
        }
    }

    fn is_ignore(&self) -> bool {
        matches!(self, IgnoreCheck::Ignore(_))
    }
}

/// Some pre-test checks for various backends.
impl IgnoreCheck {
    fn for_openvino() -> IgnoreCheck {
        use IgnoreCheck::*;
        if !cfg!(target_arch = "x86_64") {
            Fail("requires x86_64".into())
        } else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") {
            Fail("requires linux or windows".into())
        } else if let Err(e) = check::openvino::is_installed() {
            Fail(e.to_string().into())
        } else {
            Run
        }
    }

    fn for_onnx() -> Self {
        use IgnoreCheck::*;
        #[cfg(feature = "onnx")]
        if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") {
            Fail("requires x86_64 or aarch64".into())
        } else if !cfg!(target_os = "linux")
            && !cfg!(target_os = "windows")
            && !cfg!(target_os = "macos")
        {
            Fail("requires linux, windows, or macos".into())
        } else {
            Run
        }
        #[cfg(not(feature = "onnx"))]
        Ignore("requires the `onnx` feature".into())
    }

    fn for_pytorch() -> Self {
        use IgnoreCheck::*;
        #[cfg(feature = "pytorch")]
        if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") {
            Fail("requires x86_64 or aarch64".into())
        } else if !cfg!(target_os = "linux")
            && !cfg!(target_os = "windows")
            && !cfg!(target_os = "macos")
        {
            Fail("requires linux, windows, or macos".into())
        } else {
            Run
        }
        #[cfg(not(feature = "pytorch"))]
        Ignore("requires the `pytorch` feature".into())
    }

    fn for_winml() -> IgnoreCheck {
        use IgnoreCheck::*;
        #[cfg(all(feature = "winml", target_os = "windows"))]
        if !cfg!(target_arch = "x86_64") {
            Fail("requires x86_64".into())
        } else if !cfg!(target_os = "windows") {
            Fail("requires windows".into())
        } else if let Err(e) = check::winml::is_available() {
            Fail(e.to_string().into())
        } else {
            Run
        }
        #[cfg(not(all(feature = "winml", target_os = "windows")))]
        Ignore("requires the `winml` feature on windows".into())
    }
}
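
// NOTE: a usage sketch, assuming the containing package is named
// `wasmtime-wasi-nn`; `libtest-mimic`'s `Arguments::from_args` accepts the
// usual libtest-style filter and `--nocapture` arguments:
//
//     cargo test -p wasmtime-wasi-nn --test test-programs -- nn_wit --nocapture
//
// The onnx, pytorch, and winml tests also need the matching cargo feature
// (e.g. `--features onnx`); without it their pre-test check marks them
// ignored. Setting `CI` or `FORCE_WASINN_TEST_CHECK` turns failed checks into
// hard failures, as described in the module docs above.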