use crate::{FuncRenames, SnapshotVal, Wizer, info::ModuleContext, snapshot::Snapshot};
use std::cell::Cell;
use std::convert::TryFrom;
use wasm_encoder::reencode::{Reencode, RoundtripReencoder};
use wasm_encoder::{ConstExpr, SectionId};
impl Wizer {
    /// Re-encodes `module` section by section, substituting the state captured
    /// in `snapshot`, and returns the bytes of the resulting Wasm module.
    ///
    /// Section handling:
    ///
    /// * **Memory**: each defined memory's minimum is replaced by the
    ///   corresponding entry of `snapshot.memory_mins`.
    /// * **Global**: globals that have an associated export name get their
    ///   initializer replaced by the next value in `snapshot.globals`; all
    ///   other globals are re-encoded unchanged.
    /// * **Export**: the init-function export (unless it is configured to be
    ///   kept) and, when `module.has_wasi_initialize()` is true, the
    ///   `_initialize` export are dropped; the remaining exports are renamed
    ///   according to `renames`.
    /// * **Start**: dropped.
    /// * **DataCount**: the count is increased by the number of snapshot data
    ///   segments.
    /// * **Data**: original active segments become empty passive segments,
    ///   passive segments are kept, and the snapshot's segments are appended
    ///   as active segments.
    /// * Everything else is copied through verbatim.
    ///
    /// If the input has no data section, one containing only the snapshot's
    /// segments is emitted either right before the `name` custom section or,
    /// failing that, at the very end of the module.
    ///
    /// # Panics
    ///
    /// Panics if a section fails to parse or re-encode, if the number of
    /// defined memories differs from `snapshot.memory_mins.len()`, if a
    /// global's mutability disagrees with whether it has an export name, if
    /// the snapshot runs out of global values, or if a 32-bit segment offset
    /// or the segment count does not fit in a `u32`.
    pub(crate) fn rewrite(
        &self,
        module: &mut ModuleContext<'_>,
        snapshot: &Snapshot,
        renames: &FuncRenames,
    ) -> Vec<u8> {
        log::debug!("Rewriting input Wasm to pre-initialized state");

        let mut encoder = wasm_encoder::Module::new();

        let has_wasi_initialize = module.has_wasi_initialize();

        // Appends every snapshot data segment to `data_section` as an active
        // segment, using an i64 or i32 offset expression depending on
        // `seg.is64`.
        let add_data_segments = |data_section: &mut wasm_encoder::DataSection| {
            for seg in &snapshot.data_segments {
                let offset = if seg.is64 {
                    ConstExpr::i64_const(seg.offset.cast_signed())
                } else {
                    ConstExpr::i32_const(u32::try_from(seg.offset).unwrap().cast_signed())
                };
                data_section.active(seg.memory_index, &offset, seg.data.iter().copied());
            }
        };

        // The data section may need to be emitted from several places (before
        // the name section, or after the last section when the input had no
        // data section at all). The flag makes this closure a no-op once a
        // data section has been written.
        let added_data_section = Cell::new(false);
        let add_data_section = |encoder: &mut wasm_encoder::Module| {
            if added_data_section.get() {
                return;
            }
            added_data_section.set(true);
            let mut data_section = wasm_encoder::DataSection::new();
            add_data_segments(&mut data_section);
            encoder.section(&data_section);
        };

        for section in module.raw_sections() {
            match section {
                // Ensure the data section has been emitted before the name
                // section is copied over, so the name section stays after it.
                s if is_name_section(s) => {
                    add_data_section(&mut encoder);
                    encoder.section(s);
                }

                // Replace each defined memory's minimum size with the value
                // recorded in the snapshot.
                s if s.id == u8::from(SectionId::Memory) => {
                    let mut memories = wasm_encoder::MemorySection::new();
                    assert_eq!(module.defined_memories_len(), snapshot.memory_mins.len());
                    for ((_, mem), new_min) in module
                        .defined_memories()
                        .zip(snapshot.memory_mins.iter().copied())
                    {
                        let mut mem = RoundtripReencoder.memory_type(mem).unwrap();
                        mem.minimum = new_min;
                        memories.memory(mem);
                    }
                    encoder.section(&memories);
                }

                // Walk the module's defined globals in lockstep with the
                // entries of the original global section, substituting
                // snapshot values where applicable.
                s if s.id == u8::from(SectionId::Global) => {
                    let original_globals = wasmparser::GlobalSectionReader::new(
                        wasmparser::BinaryReader::new(s.data, 0),
                    )
                    .unwrap();
                    let mut globals = wasm_encoder::GlobalSection::new();
                    let mut snapshot_globals = snapshot.globals.iter();
                    for ((_, glob_ty, export_name), global) in
                        module.defined_globals().zip(original_globals)
                    {
                        let global = global.unwrap();
                        if export_name.is_some() {
                            // Globals with an export name are required to be
                            // mutable; their initializer is replaced by the
                            // next snapshot value as a constant expression.
                            assert!(glob_ty.mutable);
                            let (_, val) = snapshot_globals.next().unwrap();
                            let init = match val {
                                SnapshotVal::I32(x) => ConstExpr::i32_const(*x),
                                SnapshotVal::I64(x) => ConstExpr::i64_const(*x),
                                SnapshotVal::F32(x) => {
                                    ConstExpr::f32_const(wasm_encoder::Ieee32::new(*x))
                                }
                                SnapshotVal::F64(x) => {
                                    ConstExpr::f64_const(wasm_encoder::Ieee64::new(*x))
                                }
                                SnapshotVal::V128(x) => ConstExpr::v128_const(x.cast_signed()),
                            };
                            let glob_ty = RoundtripReencoder.global_type(glob_ty).unwrap();
                            globals.global(glob_ty, &init);
                        } else {
                            // Globals without an export name are required to
                            // be immutable; re-encode the original as-is.
                            assert!(!glob_ty.mutable);
                            RoundtripReencoder
                                .parse_global(&mut globals, global)
                                .unwrap();
                        };
                    }
                    encoder.section(&globals);
                }

                // Filter and rename exports.
                s if s.id == u8::from(SectionId::Export) => {
                    let mut exports = wasm_encoder::ExportSection::new();
                    for export in module.exports() {
                        // Drop the init-function export unless it is
                        // configured to be kept, and drop `_initialize` when
                        // the module reported a WASI initialize function.
                        if (export.name == self.get_init_func() && !self.get_keep_init_func())
                            || (has_wasi_initialize && export.name == "_initialize")
                        {
                            continue;
                        }

                        // This export's name is the destination of a rename
                        // and the export is not itself renamed: skip it,
                        // presumably so the renamed export can take the name.
                        if !renames.rename_src_to_dst.contains_key(export.name)
                            && renames.rename_dsts.contains(export.name)
                        {
                            continue;
                        }

                        // Use the rename destination if one exists, otherwise
                        // keep the original export name.
                        let field = renames
                            .rename_src_to_dst
                            .get(export.name)
                            .map_or(export.name, |f| f.as_str());

                        let kind = RoundtripReencoder.export_kind(export.kind).unwrap();
                        exports.export(field, kind, export.index);
                    }
                    encoder.section(&exports);
                }

                // The start section is not carried over to the output
                // (presumably because it already ran before the snapshot was
                // taken; confirm against the caller).
                s if s.id == u8::from(SectionId::Start) => {
                    continue;
                }

                // Account for the snapshot's extra data segments in the
                // declared data count. The section must contain exactly one
                // LEB128 u32.
                s if s.id == u8::from(SectionId::DataCount) => {
                    let mut data = wasmparser::BinaryReader::new(s.data, 0);
                    let prev = data.read_var_u32().unwrap();
                    assert!(data.eof());
                    encoder.section(&wasm_encoder::DataCountSection {
                        count: prev + u32::try_from(snapshot.data_segments.len()).unwrap(),
                    });
                }

                s if s.id == u8::from(SectionId::Data) => {
                    let mut section = wasm_encoder::DataSection::new();
                    let data = wasmparser::BinaryReader::new(s.data, 0);
                    for data in wasmparser::DataSectionReader::new(data).unwrap() {
                        let data = data.unwrap();
                        match data.kind {
                            // Original active segments are replaced by empty
                            // passive segments. This keeps one entry per
                            // original segment, consistent with the data
                            // count computed above; the contents are
                            // presumably already reflected in the snapshot's
                            // segments.
                            wasmparser::DataKind::Active { .. } => {
                                section.passive([]);
                            }
                            // Passive segments are preserved unchanged.
                            wasmparser::DataKind::Passive => {
                                section.passive(data.data.iter().copied());
                            }
                        }
                    }
                    add_data_segments(&mut section);
                    // Pass the freshly built data section by reference.
                    encoder.section(&section);
                    added_data_section.set(true);
                }

                // Every other section is copied through verbatim.
                s => {
                    encoder.section(s);
                }
            }
        }

        // The input may have had neither a data section nor a name section;
        // make sure the snapshot's data segments are emitted regardless.
        add_data_section(&mut encoder);
        encoder.finish()
    }
}
/// Reports whether `s` is a custom section whose payload begins with the
/// string `"name"`.
///
/// Sections with any other id, and custom sections whose leading string
/// cannot be read, yield `false`.
fn is_name_section(s: &wasm_encoder::RawSection) -> bool {
    // Only custom sections carry a leading name string.
    if s.id != u8::from(SectionId::Custom) {
        return false;
    }

    let mut name_reader = wasmparser::BinaryReader::new(s.data, 0);
    match name_reader.read_string() {
        Ok(section_name) => section_name == "name",
        Err(_) => false,
    }
}