Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/wasi-nn/tests/check/pytorch.rs
1693 views
1
use super::{DOWNLOAD_LOCK, artifacts_dir, download};
2
use anyhow::{Context, Result};
3
use std::{env, fs};
4
5
/// Return `Ok` if we find the cached MobileNet test artifacts; this will
6
/// download the artifacts if necessary.
7
pub fn are_artifacts_available() -> Result<()> {
8
let _exclusively_retrieve_artifacts = DOWNLOAD_LOCK.lock().unwrap();
9
const PYTORCH_BASE_URL: &str = "https://github.com/rahulchaphalkar/libtorch-models/releases/download/v0.1/squeezenet1_1.pt";
10
let artifacts_dir = artifacts_dir();
11
if !artifacts_dir.is_dir() {
12
fs::create_dir(&artifacts_dir)?;
13
}
14
15
let local_path = artifacts_dir.join("model.pt");
16
let remote_url = PYTORCH_BASE_URL;
17
if !local_path.is_file() {
18
download(&remote_url, &local_path).with_context(|| "unable to retrieve test artifact")?;
19
} else {
20
println!("> using cached artifact: {}", local_path.display())
21
}
22
23
// Copy image from source tree to artifact directory.
24
let image_path = env::current_dir()?
25
.join("tests")
26
.join("fixtures")
27
.join("kitten.tensor");
28
let dest_path = artifacts_dir.join("kitten.tensor");
29
fs::copy(&image_path, &dest_path)?;
30
Ok(())
31
}
32
33