use std::{any::Any, sync::Arc};
use tokio::{
sync::{Mutex, mpsc},
task::JoinHandle,
};
use wasmtime::{
AsContextMut, DebugEvent, DebugHandler, ExnRef, OwnedRooted, Result, Store, StoreContextMut,
Trap,
};
/// Caller-side handle for a Wasm computation that runs on a separately
/// spawned Tokio task with a debug handler installed on its store.
///
/// The handle talks to that task over a pair of channels: commands go out on
/// `in_tx`, and pause/query/finish notifications come back on `out_rx`.
pub struct Debugger<T: Send + 'static> {
/// Join handle of the task spawned by `Debugger::new`; the task resolves to
/// the store once the wrapped body finishes. `None` after `take_store` has
/// consumed it.
inner: Option<JoinHandle<Result<Store<T>>>>,
/// Where this side believes the command/response protocol currently is.
state: DebuggerState,
/// Sends `Command`s to the spawned task.
in_tx: mpsc::Sender<Command<T>>,
/// Receives `Response`s from the spawned task.
out_rx: mpsc::Receiver<Response>,
}
/// Protocol state as tracked by the `Debugger` (caller) side.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DebuggerState {
/// A `Command::Continue` has been sent and no `Response::Paused` or
/// `Response::Finished` has been received for it yet.
Running,
/// The task is waiting for a command. This is also the initial state of a
/// freshly constructed `Debugger`.
Paused,
/// A `Command::Query` has been sent and its `Response::QueryResponse` has
/// not been received yet.
Queried,
/// `Response::Finished` has been received; the wrapped body is done.
Complete,
}
/// Messages sent from the `Debugger` to the spawned task.
enum Command<T: 'static> {
/// Stop waiting for commands and resume execution.
Continue,
/// Run the boxed closure against the store context and send its
/// type-erased result back as a `Response::QueryResponse`.
Query(Box<dyn FnOnce(StoreContextMut<'_, T>) -> Box<dyn Any + Send> + Send>),
}
/// Messages sent from the spawned task back to the `Debugger`.
enum Response {
/// The debug handler was invoked for an event; the payload says which one.
Paused(DebugRunResult),
/// Type-erased result of a `Command::Query` closure.
QueryResponse(Box<dyn Any + Send>),
/// The wrapped body's future has resolved (successfully or with an error).
Finished,
}
/// Channel endpoints owned by the task side and shared by every clone of
/// `Handler`.
struct HandlerInner<T: Send + 'static> {
/// Incoming commands. Wrapped in an async mutex so the receiver can be used
/// mutably through the shared `&self` that `DebugHandler::handle` receives.
in_rx: Mutex<mpsc::Receiver<Command<T>>>,
/// Outgoing responses to the `Debugger`.
out_tx: mpsc::Sender<Response>,
}
/// Debug handler installed on the store by `Debugger::new`; a reference-counted
/// handle to the shared channel endpoints in `HandlerInner`.
struct Handler<T: Send + 'static>(Arc<HandlerInner<T>>);
impl<T: Send + 'static> std::clone::Clone for Handler<T> {
fn clone(&self) -> Self {
Handler(self.0.clone())
}
}
impl<T: Send + 'static> DebugHandler for Handler<T> {
    type Data = T;

    /// Reports `event` to the `Debugger` as a `Response::Paused`, then serves
    /// commands until told to continue.
    ///
    /// Each `Command::Query` closure is run against `store` and its result is
    /// sent back as a `Response::QueryResponse`. The method returns when a
    /// `Command::Continue` arrives or when the command channel is closed.
    ///
    /// # Panics
    ///
    /// Panics if the outbound channel is closed while sending either kind of
    /// response.
    async fn handle(&self, mut store: StoreContextMut<'_, T>, event: DebugEvent<'_>) {
        let channels = &*self.0;

        // Hold the command receiver for the entire pause.
        let mut commands = channels.in_rx.lock().await;

        // Translate the engine's event into the value handed to the caller.
        let reason = match event {
            DebugEvent::Breakpoint => DebugRunResult::Breakpoint,
            DebugEvent::EpochYield => DebugRunResult::EpochYield,
            DebugEvent::Trap(trap) => DebugRunResult::Trap(trap),
            DebugEvent::HostcallError(_) => DebugRunResult::HostcallError,
            DebugEvent::CaughtExceptionThrown(exn) => DebugRunResult::CaughtExceptionThrown(exn),
            DebugEvent::UncaughtExceptionThrown(exn) => {
                DebugRunResult::UncaughtExceptionThrown(exn)
            }
        };
        channels
            .out_tx
            .send(Response::Paused(reason))
            .await
            .expect("outbound channel closed prematurely");

        loop {
            let query = match commands.recv().await {
                Some(Command::Query(query)) => query,
                // Either an explicit resume or the command sender is gone;
                // in both cases stop serving commands and return.
                Some(Command::Continue) | None => return,
            };
            let answer = query(store.as_context_mut());
            channels
                .out_tx
                .send(Response::QueryResponse(answer))
                .await
                .expect("outbound channel closed prematurely");
        }
    }
}
impl<T: Send + 'static> Debugger<T> {
pub fn new<F, I>(mut store: Store<T>, inner: F) -> Debugger<T>
where
I: Future<Output = Result<Store<T>>> + Send + 'static,
F: for<'a> FnOnce(Store<T>) -> I + Send + 'static,
{
let (in_tx, mut in_rx) = mpsc::channel(1);
let (out_tx, out_rx) = mpsc::channel(1);
let inner = tokio::spawn(async move {
match in_rx.recv().await {
Some(cmd) => {
assert!(matches!(cmd, Command::Continue));
}
None => {
wasmtime::bail!("Debugger channel dropped");
}
}
let out_tx_clone = out_tx.clone();
store.set_debug_handler(Handler(Arc::new(HandlerInner {
in_rx: Mutex::new(in_rx),
out_tx,
})));
let result = inner(store).await;
let _ = out_tx_clone.send(Response::Finished).await;
result
});
Debugger {
inner: Some(inner),
state: DebuggerState::Paused,
in_tx,
out_rx,
}
}
pub fn is_complete(&self) -> bool {
match self.state {
DebuggerState::Complete => true,
_ => false,
}
}
pub async fn run(&mut self) -> Result<DebugRunResult> {
log::trace!("running: state is {:?}", self.state);
match self.state {
DebuggerState::Paused => {
log::trace!("sending Continue");
self.in_tx
.send(Command::Continue)
.await
.map_err(|_| wasmtime::format_err!("Failed to send over debug channel"))?;
log::trace!("sent Continue");
self.state = DebuggerState::Running;
}
DebuggerState::Running => {
}
DebuggerState::Queried => {
log::trace!("in Queried; receiving");
let response =
self.out_rx.recv().await.ok_or_else(|| {
wasmtime::format_err!("Premature close of debugger channel")
})?;
log::trace!("in Queried; received, dropping");
assert!(matches!(response, Response::QueryResponse(_)));
self.state = DebuggerState::Paused;
log::trace!("in Paused; sending Continue");
self.in_tx
.send(Command::Continue)
.await
.map_err(|_| wasmtime::format_err!("Failed to send over debug channel"))?;
self.state = DebuggerState::Running;
}
DebuggerState::Complete => {
panic!("Cannot `run()` an already-complete Debugger");
}
}
log::trace!("waiting for response");
let response = self
.out_rx
.recv()
.await
.ok_or_else(|| wasmtime::format_err!("Premature close of debugger channel"))?;
match response {
Response::Finished => {
log::trace!("got Finished");
self.state = DebuggerState::Complete;
Ok(DebugRunResult::Finished)
}
Response::Paused(result) => {
log::trace!("got Paused");
self.state = DebuggerState::Paused;
Ok(result)
}
Response::QueryResponse(_) => {
panic!("Invalid debug response");
}
}
}
pub async fn finish(&mut self) -> Result<()> {
if self.is_complete() {
return Ok(());
}
loop {
match self.run().await? {
DebugRunResult::Finished => break,
e => {
log::trace!("finish: event {e:?}");
}
}
}
assert!(self.is_complete());
Ok(())
}
pub async fn with_store<
F: FnOnce(StoreContextMut<'_, T>) -> R + Send + 'static,
R: Send + 'static,
>(
&mut self,
f: F,
) -> Result<R> {
assert!(!self.is_complete());
match self.state {
DebuggerState::Queried => {
let response =
self.out_rx.recv().await.ok_or_else(|| {
wasmtime::format_err!("Premature close of debugger channel")
})?;
assert!(matches!(response, Response::QueryResponse(_)));
self.state = DebuggerState::Paused;
}
DebuggerState::Running => {
panic!("Cannot query in Running state");
}
DebuggerState::Complete => {
panic!("Cannot query when complete");
}
DebuggerState::Paused => {
}
}
self.in_tx
.send(Command::Query(Box::new(|store| Box::new(f(store)))))
.await
.map_err(|_| wasmtime::format_err!("Premature close of debugger channel"))?;
self.state = DebuggerState::Queried;
let response = self
.out_rx
.recv()
.await
.ok_or_else(|| wasmtime::format_err!("Premature close of debugger channel"))?;
let Response::QueryResponse(resp) = response else {
wasmtime::bail!("Incorrect response from debugger task");
};
self.state = DebuggerState::Paused;
Ok(*resp.downcast::<R>().expect("type mismatch"))
}
pub async fn take_store(&mut self) -> Result<Option<Store<T>>> {
match self.state {
DebuggerState::Complete => {
let inner = match self.inner.take() {
Some(inner) => inner,
None => return Ok(None),
};
let mut store = inner.await??;
store.clear_debug_handler();
Ok(Some(store))
}
_ => panic!("Invalid state: debugger not yet complete"),
}
}
}
/// Outcome of one `Debugger::run` call: either the wrapped body finished or
/// execution paused on a debug event.
#[derive(Debug)]
pub enum DebugRunResult {
/// The wrapped body's future has resolved; there is nothing left to resume.
Finished,
/// Paused on `DebugEvent::HostcallError`; the event's payload is not
/// carried over.
HostcallError,
/// Paused on `DebugEvent::EpochYield`.
EpochYield,
/// Paused on `DebugEvent::CaughtExceptionThrown`, carrying the exception
/// object from the event.
CaughtExceptionThrown(OwnedRooted<ExnRef>),
/// Paused on `DebugEvent::UncaughtExceptionThrown`, carrying the exception
/// object from the event.
UncaughtExceptionThrown(OwnedRooted<ExnRef>),
/// Paused on `DebugEvent::Trap`, carrying the trap from the event.
Trap(Trap),
/// Paused on `DebugEvent::Breakpoint`.
Breakpoint,
}
#[cfg(test)]
mod test {
    use super::*;
    use wasmtime::*;

    /// Module used by every test: one exported `main` that adds its two
    /// `i32` parameters.
    const ADDER_WAT: &str = r#"
(module
(func (export "main") (param i32 i32) (result i32)
local.get 0
local.get 1
i32.add))
"#;

    /// Builds a guest-debug-enabled engine, instantiates the adder module in
    /// a fresh store, and returns that store with the exported `main`.
    async fn setup_adder() -> Result<(Store<()>, Func)> {
        let _ = env_logger::try_init();
        let mut config = Config::new();
        config.guest_debug(true);
        let engine = Engine::new(&config)?;
        let module = Module::new(&engine, ADDER_WAT)?;
        let mut store = Store::new(&engine, ());
        let instance = Instance::new_async(&mut store, &module, &[]).await?;
        let main = instance.get_func(&mut store, "main").unwrap();
        Ok((store, main))
    }

    /// Single-steps through the first call to `main`, checking the frame's
    /// PC, locals and operand stack at each stop, then turns single-stepping
    /// off, runs to completion and reuses the returned store.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn basic_debugger() -> wasmtime::Result<()> {
        let (store, main) = setup_adder().await?;

        let mut debugger = Debugger::new(store, move |mut store| async move {
            let mut out = [Val::I32(0)];
            store.edit_breakpoints().unwrap().single_step(true).unwrap();

            main.call_async(&mut store, &[Val::I32(1), Val::I32(2)], &mut out[..])
                .await?;
            assert_eq!(out[0].unwrap_i32(), 3);

            main.call_async(&mut store, &[Val::I32(3), Val::I32(4)], &mut out[..])
                .await?;
            assert_eq!(out[0].unwrap_i32(), 7);

            Ok(store)
        });

        // First stop: nothing on the operand stack yet.
        assert!(matches!(debugger.run().await?, DebugRunResult::Breakpoint));
        debugger
            .with_store(|cx| {
                let mut frame = cx.debug_frames().unwrap();
                assert!(!frame.done());
                let (func, pc) = frame.wasm_function_index_and_pc().unwrap();
                assert_eq!(func.as_u32(), 0);
                assert_eq!(pc, 36);
                assert_eq!(frame.num_locals(), 2);
                assert_eq!(frame.num_stacks(), 0);
                assert_eq!(frame.local(0).unwrap_i32(), 1);
                assert_eq!(frame.local(1).unwrap_i32(), 2);
                assert_eq!(frame.move_to_parent(), FrameParentResult::SameActivation);
                assert!(frame.done());
            })
            .await?;

        // Second stop: the first local has been pushed.
        assert!(matches!(debugger.run().await?, DebugRunResult::Breakpoint));
        debugger
            .with_store(|cx| {
                let mut frame = cx.debug_frames().unwrap();
                assert!(!frame.done());
                let (func, pc) = frame.wasm_function_index_and_pc().unwrap();
                assert_eq!(func.as_u32(), 0);
                assert_eq!(pc, 38);
                assert_eq!(frame.num_locals(), 2);
                assert_eq!(frame.num_stacks(), 1);
                assert_eq!(frame.local(0).unwrap_i32(), 1);
                assert_eq!(frame.local(1).unwrap_i32(), 2);
                assert_eq!(frame.stack(0).unwrap_i32(), 1);
                assert_eq!(frame.move_to_parent(), FrameParentResult::SameActivation);
                assert!(frame.done());
            })
            .await?;

        // Third stop: both locals have been pushed.
        assert!(matches!(debugger.run().await?, DebugRunResult::Breakpoint));
        debugger
            .with_store(|cx| {
                let mut frame = cx.debug_frames().unwrap();
                assert!(!frame.done());
                let (func, pc) = frame.wasm_function_index_and_pc().unwrap();
                assert_eq!(func.as_u32(), 0);
                assert_eq!(pc, 40);
                assert_eq!(frame.num_locals(), 2);
                assert_eq!(frame.num_stacks(), 2);
                assert_eq!(frame.local(0).unwrap_i32(), 1);
                assert_eq!(frame.local(1).unwrap_i32(), 2);
                assert_eq!(frame.stack(0).unwrap_i32(), 1);
                assert_eq!(frame.stack(1).unwrap_i32(), 2);
                assert_eq!(frame.move_to_parent(), FrameParentResult::SameActivation);
                assert!(frame.done());
            })
            .await?;

        // Fourth stop: the sum is the only value left on the stack.
        assert!(matches!(debugger.run().await?, DebugRunResult::Breakpoint));
        debugger
            .with_store(|cx| {
                let mut frame = cx.debug_frames().unwrap();
                assert!(!frame.done());
                let (func, pc) = frame.wasm_function_index_and_pc().unwrap();
                assert_eq!(func.as_u32(), 0);
                assert_eq!(pc, 41);
                assert_eq!(frame.num_locals(), 2);
                assert_eq!(frame.num_stacks(), 1);
                assert_eq!(frame.local(0).unwrap_i32(), 1);
                assert_eq!(frame.local(1).unwrap_i32(), 2);
                assert_eq!(frame.stack(0).unwrap_i32(), 3);
                assert_eq!(frame.move_to_parent(), FrameParentResult::SameActivation);
                assert!(frame.done());
            })
            .await?;

        // Turn single-stepping off; the next `run` is then expected to go
        // straight through to completion.
        debugger
            .with_store(|cx| {
                cx.edit_breakpoints().unwrap().single_step(false).unwrap();
            })
            .await?;
        assert!(matches!(debugger.run().await?, DebugRunResult::Finished));
        assert!(debugger.is_complete());

        // The store handed back by the debugger is still usable.
        let mut store = debugger.take_store().await?.unwrap();
        let mut out = [Val::I32(0)];
        main.call_async(&mut store, &[Val::I32(10), Val::I32(20)], &mut out[..])
            .await?;
        assert_eq!(out[0].unwrap_i32(), 30);
        Ok(())
    }

    /// `finish()` runs the body to completion without the caller handling
    /// any of the intermediate single-step events.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn early_finish() -> Result<()> {
        let (store, main) = setup_adder().await?;

        let mut debugger = Debugger::new(store, move |mut store| async move {
            let mut out = [Val::I32(0)];
            store.edit_breakpoints().unwrap().single_step(true).unwrap();
            main.call_async(&mut store, &[Val::I32(1), Val::I32(2)], &mut out[..])
                .await?;
            assert_eq!(out[0].unwrap_i32(), 3);
            Ok(store)
        });

        debugger.finish().await?;
        assert!(debugger.is_complete());
        Ok(())
    }

    /// Dropping the debugger while the body is still paused mid-execution
    /// must not make the test fail.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn drop_debugger_and_store() -> Result<()> {
        let (store, main) = setup_adder().await?;

        let mut debugger = Debugger::new(store, move |mut store| async move {
            let mut out = [Val::I32(0)];
            store.edit_breakpoints().unwrap().single_step(true).unwrap();
            main.call_async(&mut store, &[Val::I32(1), Val::I32(2)], &mut out[..])
                .await?;
            assert_eq!(out[0].unwrap_i32(), 3);
            Ok(store)
        });

        // Advance to the first event only; `debugger` is dropped when it goes
        // out of scope at the end of the test, before the body has finished.
        let _ = debugger.run().await?;
        Ok(())
    }
}