//! Core library for Wizer.
//!
//! The central type is [`Wizer`], a builder-style configuration object. The
//! workflow visible in this module has two steps:
//!
//! 1. [`Wizer::instrument`] validates an input Wasm binary and produces an
//!    instrumented copy of it together with a [`ModuleContext`].
//! 2. [`Wizer::snapshot`] reads the state of an instance through the
//!    [`InstanceState`] trait and produces the rewritten Wasm binary.

#![deny(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
// Internal modules. `info` defines `ModuleContext`, `parse` provides
// `parse::parse`, `instrument` provides `instrument::instrument`, and
// `snapshot` provides `snapshot::snapshot` and `SnapshotVal`.
// NOTE(review): `rewrite` presumably defines `Wizer::rewrite`, which
// `Wizer::snapshot` calls -- confirm.
mod info;
mod instrument;
mod parse;
mod rewrite;
mod snapshot;
// Optional integration, only compiled with the `wasmtime` cargo feature.
#[cfg(feature = "wasmtime")]
mod wasmtime;
#[cfg(feature = "wasmtime")]
pub use wasmtime::*;
// Optional integration, only compiled with the `component-model` cargo feature.
#[cfg(feature = "component-model")]
mod component;
#[cfg(feature = "component-model")]
pub use component::*;
pub use crate::info::ModuleContext;
pub use crate::snapshot::SnapshotVal;
use anyhow::Context;
use std::collections::{HashMap, HashSet};
// Value reported by `Wizer::get_keep_init_func` when the keep-init-func
// option was never set.
const DEFAULT_KEEP_INIT_FUNC: bool = false;
/// Configuration for instrumenting a Wasm binary and snapshotting an instance
/// of it.
///
/// Create one with [`Wizer::new`], adjust it with the builder-style setters,
/// then call [`Wizer::instrument`] followed by [`Wizer::snapshot`]. With the
/// `clap` cargo feature enabled the same options can be parsed from
/// command-line arguments.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
pub struct Wizer {
/// The name of the initialization function.
///
/// Defaults to `wizer-initialize`.
// NOTE(review): the code that looks this function up lives outside this
// file; only the getter/setter are visible here.
#[cfg_attr(
feature = "clap",
arg(short = 'f', long, default_value = "wizer-initialize")
)]
init_func: String,
/// Function renames, each written as `dst=src`.
///
/// Every destination and every source may be used at most once.
// Stored as `(dst, src)` pairs; duplicates are rejected by
// `FuncRenames::parse` when `Wizer::snapshot` runs.
#[cfg_attr(
feature = "clap",
arg(
short = 'r',
long = "rename-func",
alias = "func-rename",
value_name = "dst=src",
value_parser = parse_rename,
),
)]
func_renames: Vec<(String, String)>,
/// Whether the initialization function is kept.
///
/// Passing the flag without a value means `true`; omitting the flag means
/// `false`.
// `None`: never set, resolved to `DEFAULT_KEEP_INIT_FUNC`.
// `Some(None)`: flag given without a value, resolved to `true`.
// `Some(Some(v))`: explicit value `v`.
// NOTE(review): the consumer of this setting is outside this file
// (presumably `rewrite`); confirm exactly what "keep" preserves.
#[cfg_attr(
feature = "clap",
arg(long, require_equals = true, value_name = "true|false")
)]
keep_init_func: Option<Option<bool>>,
}
/// Parses one `dst=src` command-line value into a `(dst, src)` pair.
///
/// The input is split at the first `=`. Everything after it, including any
/// further `=` characters, becomes `src`.
///
/// # Errors
///
/// Fails when `s` contains no `=` at all.
#[cfg(feature = "clap")]
fn parse_rename(s: &str) -> anyhow::Result<(String, String)> {
    let Some((dst, src)) = s.split_once('=') else {
        anyhow::bail!("must contain exactly one equals character ('=')");
    };
    Ok((dst.to_owned(), src.to_owned()))
}
/// Validated function renames, built by `FuncRenames::parse`.
#[derive(Default)]
struct FuncRenames {
/// Maps each rename source name to the destination name it is renamed to.
rename_src_to_dst: HashMap<String, String>,
/// Every destination name recorded so far; used to reject duplicates.
rename_dsts: HashSet<String>,
}
impl FuncRenames {
fn parse(renames: &[(String, String)]) -> anyhow::Result<FuncRenames> {
let mut ret = FuncRenames {
rename_src_to_dst: HashMap::new(),
rename_dsts: HashSet::new(),
};
if renames.is_empty() {
return Ok(ret);
}
for (dst, src) in renames {
if ret.rename_dsts.contains(dst) {
anyhow::bail!("Duplicated function rename dst {dst}");
}
if ret.rename_src_to_dst.contains_key(src) {
anyhow::bail!("Duplicated function rename src {src}");
}
ret.rename_dsts.insert(dst.clone());
ret.rename_src_to_dst.insert(src.clone(), dst.clone());
}
Ok(ret)
}
}
impl Wizer {
    /// Creates a `Wizer` with the default configuration.
    ///
    /// The initialization function is named `wizer-initialize`, no function
    /// renames are registered, and the keep-init-func setting is left unset.
    pub fn new() -> Self {
        Self {
            init_func: String::from("wizer-initialize"),
            func_renames: Vec::new(),
            keep_init_func: None,
        }
    }

    /// Sets the name of the initialization function.
    ///
    /// Returns `self` so calls can be chained.
    pub fn init_func(&mut self, init_func: impl Into<String>) -> &mut Self {
        let name: String = init_func.into();
        self.init_func = name;
        self
    }

    /// Returns the configured name of the initialization function.
    pub fn get_init_func(&self) -> &str {
        self.init_func.as_str()
    }

    /// Registers a function rename from `old_name` to `new_name`.
    ///
    /// The pair is only recorded here. Duplicate new or old names are
    /// reported as an error later, by [`Wizer::snapshot`].
    ///
    /// Returns `self` so calls can be chained.
    pub fn func_rename(&mut self, new_name: &str, old_name: &str) -> &mut Self {
        let pair = (new_name.to_owned(), old_name.to_owned());
        self.func_renames.push(pair);
        self
    }

    /// Explicitly sets whether the initialization function is kept.
    ///
    /// Returns `self` so calls can be chained.
    pub fn keep_init_func(&mut self, keep: bool) -> &mut Self {
        self.keep_init_func = Some(Some(keep));
        self
    }

    /// Validates `wasm` and produces an instrumented copy of it.
    ///
    /// On success returns the [`ModuleContext`] parsed from `wasm` together
    /// with the bytes of the instrumented Wasm. In builds with debug
    /// assertions the instrumented output is validated again.
    ///
    /// # Errors
    ///
    /// Fails if `wasm` does not validate, uses an unsupported instruction or
    /// a mutable global of reference type, cannot be parsed, or imports a
    /// global, table, or memory.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions, panics if the instrumented output is
    /// not valid.
    pub fn instrument<'a>(&self, wasm: &'a [u8]) -> anyhow::Result<(ModuleContext<'a>, Vec<u8>)> {
        self.wasm_validate(wasm)?;
        let mut cx = parse::parse(wasm)?;

        for import in cx.imports() {
            // The match stays exhaustive so that a new import kind forces a
            // decision here.
            let message = match import.ty {
                wasmparser::TypeRef::Global(_) => "imported globals are not supported",
                wasmparser::TypeRef::Table(_) => "imported tables are not supported",
                wasmparser::TypeRef::Memory(_) => "imported memories are not supported",
                wasmparser::TypeRef::Func(_)
                | wasmparser::TypeRef::FuncExact(_)
                | wasmparser::TypeRef::Tag(_) => continue,
            };
            anyhow::bail!(message);
        }

        let instrumented = instrument::instrument(&mut cx);
        self.debug_assert_valid_wasm(&instrumented);
        Ok((cx, instrumented))
    }

    /// Snapshots the state of `instance` and produces the rewritten Wasm.
    ///
    /// `cx` is the context returned by [`Wizer::instrument`]. The registered
    /// function renames are checked first, then the instance state is read
    /// through [`InstanceState`], and finally the Wasm is rewritten. In
    /// builds with debug assertions the rewritten output is validated.
    ///
    /// # Errors
    ///
    /// Fails if the registered function renames contain a duplicate
    /// destination or source name.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions, panics if the rewritten output is
    /// not valid.
    pub async fn snapshot(
        &self,
        mut cx: ModuleContext<'_>,
        instance: &mut impl InstanceState,
    ) -> anyhow::Result<Vec<u8>> {
        // Reject bad rename configuration before touching the instance.
        let renames = FuncRenames::parse(&self.func_renames)?;
        let state = snapshot::snapshot(&cx, instance).await;
        let rewritten = self.rewrite(&mut cx, &state, &renames);
        self.debug_assert_valid_wasm(&rewritten);
        Ok(rewritten)
    }

    /// Panics if `wasm` fails validation; does nothing in builds without
    /// debug assertions.
    ///
    /// The panic message contains the validation error and, when the
    /// `wasmprinter` cargo feature is enabled, a WAT disassembly of `wasm`.
    fn debug_assert_valid_wasm(&self, wasm: &[u8]) {
        if !cfg!(debug_assertions) {
            return;
        }
        let Err(error) = self.wasm_validate(wasm) else {
            return;
        };

        #[cfg(feature = "wasmprinter")]
        let wat = wasmprinter::print_bytes(wasm)
            .unwrap_or_else(|e| format!("Disassembling to WAT failed: {}", e));
        #[cfg(not(feature = "wasmprinter"))]
        let wat = "`wasmprinter` cargo feature is not enabled".to_string();

        panic!("instrumented Wasm is not valid: {error:?}\n\nWAT:\n{wat}");
    }

    /// Checks that `wasm` is valid and only uses supported constructs.
    ///
    /// Validation runs with every Wasm feature enabled. Afterwards each
    /// function body and each global section is scanned.
    ///
    /// # Errors
    ///
    /// Fails if validation or parsing fails, if any function body contains
    /// one of the rejected table, element, data, struct, or array
    /// instructions listed below, or if a mutable global has a reference
    /// type.
    fn wasm_validate(&self, wasm: &[u8]) -> anyhow::Result<()> {
        log::debug!("Validating input Wasm");

        let mut validator =
            wasmparser::Validator::new_with_features(wasmparser::WasmFeatures::all());
        validator
            .validate_all(wasm)
            .context("wasm validation failed")?;

        for payload in wasmparser::Parser::new(0).parse_all(wasm) {
            match payload? {
                wasmparser::Payload::CodeSectionEntry(code) => {
                    let mut ops = code.get_operators_reader()?;
                    while !ops.eof() {
                        // Pick the error message for a rejected instruction;
                        // everything else moves on to the next operator.
                        let message = match ops.read()? {
                            wasmparser::Operator::TableCopy { .. } => {
                                "unsupported `table.copy` instruction"
                            }
                            wasmparser::Operator::TableInit { .. } => {
                                "unsupported `table.init` instruction"
                            }
                            wasmparser::Operator::TableSet { .. } => {
                                "unsupported `table.set` instruction"
                            }
                            wasmparser::Operator::TableGrow { .. } => {
                                "unsupported `table.grow` instruction"
                            }
                            wasmparser::Operator::TableFill { .. } => {
                                "unsupported `table.fill` instruction"
                            }
                            wasmparser::Operator::ElemDrop { .. } => {
                                "unsupported `elem.drop` instruction"
                            }
                            wasmparser::Operator::DataDrop { .. } => {
                                "unsupported `data.drop` instruction"
                            }
                            wasmparser::Operator::StructSet { .. } => {
                                "unsupported `struct.set` instruction"
                            }
                            wasmparser::Operator::ArraySet { .. } => {
                                "unsupported `array.set` instruction"
                            }
                            wasmparser::Operator::ArrayFill { .. } => {
                                "unsupported `array.fill` instruction"
                            }
                            wasmparser::Operator::ArrayCopy { .. } => {
                                "unsupported `array.copy` instruction"
                            }
                            wasmparser::Operator::ArrayInitData { .. } => {
                                "unsupported `array.init_data` instruction"
                            }
                            wasmparser::Operator::ArrayInitElem { .. } => {
                                "unsupported `array.init_elem` instruction"
                            }
                            _ => continue,
                        };
                        anyhow::bail!(message);
                    }
                }
                wasmparser::Payload::GlobalSection(globals) => {
                    for global in globals {
                        let ty = global?.ty;
                        // Only mutable globals are restricted; immutable ones
                        // may have any value type.
                        match ty.content_type {
                            wasmparser::ValType::Ref(_) if ty.mutable => {
                                anyhow::bail!(
                                    "unsupported mutable global containing a reference type"
                                )
                            }
                            wasmparser::ValType::I32
                            | wasmparser::ValType::I64
                            | wasmparser::ValType::F32
                            | wasmparser::ValType::F64
                            | wasmparser::ValType::V128
                            | wasmparser::ValType::Ref(_) => {}
                        }
                    }
                }
                _ => {}
            }
        }

        Ok(())
    }

    /// Resolves the keep-init-func setting to a concrete boolean.
    ///
    /// Unset resolves to `DEFAULT_KEEP_INIT_FUNC`; set without a value
    /// resolves to `true`; otherwise the explicit value is returned.
    fn get_keep_init_func(&self) -> bool {
        self.keep_init_func
            .map_or(DEFAULT_KEEP_INIT_FUNC, |explicit| explicit.unwrap_or(true))
    }
}
/// Access to the state of a Wasm instance that a snapshot is taken from.
///
/// An implementation is passed to [`Wizer::snapshot`], which forwards it to
/// the snapshotting code. Both methods are asynchronous and their futures
/// must be `Send`.
pub trait InstanceState {
/// Returns the current value of the global identified by `name`.
// NOTE(review): behavior for an unknown `name` is not specified in this
// file -- confirm against the implementations.
fn global_get(&mut self, name: &str) -> impl Future<Output = SnapshotVal> + Send;
/// Gives access to the bytes of the memory identified by `name`.
///
/// The implementation is expected to invoke `contents` with the memory's
/// bytes; the slice is only borrowed for the duration of that call.
fn memory_contents(
&mut self,
name: &str,
contents: impl FnOnce(&[u8]) + Send,
) -> impl Future<Output = ()> + Send;
}