Path: blob/main/crates/wasi-nn/tests/check/pytorch.rs
1693 views
use super::{DOWNLOAD_LOCK, artifacts_dir, download};1use anyhow::{Context, Result};2use std::{env, fs};34/// Return `Ok` if we find the cached MobileNet test artifacts; this will5/// download the artifacts if necessary.6pub fn are_artifacts_available() -> Result<()> {7let _exclusively_retrieve_artifacts = DOWNLOAD_LOCK.lock().unwrap();8const PYTORCH_BASE_URL: &str = "https://github.com/rahulchaphalkar/libtorch-models/releases/download/v0.1/squeezenet1_1.pt";9let artifacts_dir = artifacts_dir();10if !artifacts_dir.is_dir() {11fs::create_dir(&artifacts_dir)?;12}1314let local_path = artifacts_dir.join("model.pt");15let remote_url = PYTORCH_BASE_URL;16if !local_path.is_file() {17download(&remote_url, &local_path).with_context(|| "unable to retrieve test artifact")?;18} else {19println!("> using cached artifact: {}", local_path.display())20}2122// Copy image from source tree to artifact directory.23let image_path = env::current_dir()?24.join("tests")25.join("fixtures")26.join("kitten.tensor");27let dest_path = artifacts_dir.join("kitten.tensor");28fs::copy(&image_path, &dest_path)?;29Ok(())30}313233