//! Helpers for undoing partial side effects when their larger operation fails.

use core::{fmt, mem, ops};

/// An RAII guard to rollback and undo something on (early) drop.
///
/// Dereferences to its inner `T` and its undo function is given the `T` on
/// drop.
///
/// When all of the changes that need to happen together have happened, you can
/// call `Undo::commit` to disable the guard and commit the associated side
/// effects.
///
/// # Example
///
/// ```
/// use std::cell::Cell;
/// use wasmtime_internal_core::{error::Result, undo::Undo};
///
/// /// Some big ball of state that must always be coherent.
/// pub struct Context {
///     // ...
/// }
///
/// impl Context {
///     /// Perform some incremental mutation to `self`, which might not leave
///     /// it in a valid state unless its whole batch of work is completed.
///     fn do_thing(&mut self, arg: u32) -> Result<()> {
///         # let _ = arg;
///         # todo!()
///         // ...
///     }
///
///     /// Undo the side effects of `self.do_thing(arg)` for when we need to
///     /// roll back mutations.
///     fn undo_thing(&mut self, arg: u32) {
///         # let _ = arg;
///         // ...
///     }
///
///     /// Call `self.do_thing(arg)` for each `arg` in `args`.
///     ///
///     /// However, if any `self.do_thing(arg)` call fails, make sure that
///     /// we roll back to the original state by calling `self.undo_thing(arg)`
///     /// for all the `self.do_thing(arg)` calls that already succeeded. This
///     /// way we never leave `self` in a state where things got half-done.
///     pub fn do_all_or_nothing(&mut self, args: &[u32]) -> Result<()> {
///         // Counter for our progress, so that we know how much work to undo
///         // upon failure.
///         let num_things_done = Cell::new(0);
///
///         // Wrap the `Context` in an `Undo` that rolls back our side effects if
///         // we early-exit this function via `?`-propagation or panic unwinding.
///         let mut ctx = Undo::new(self, |ctx| {
///             for arg in args.iter().take(num_things_done.get()) {
///                 ctx.undo_thing(*arg);
///             }
///         });
///
///         // Do each piece of work!
///         for arg in args {
///             // Note: if this call returns an error that is `?`-propagated or
///             // triggers unwinding by panicking, then the work performed thus
///             // far will be rolled back when `ctx` is dropped.
///             ctx.do_thing(*arg)?;
///
///             // Update how much work has been completed.
///             num_things_done.set(num_things_done.get() + 1);
///         }
///
///         // We completed all of the work, so commit the `Undo` guard and
///         // disable its cleanup function.
///         Undo::commit(ctx);
///
///         Ok(())
///     }
/// }
/// ```
#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \
              to disable"]
pub struct Undo<T, F>
where
    F: FnOnce(T),
{
    // Both fields are `ManuallyDrop` so that they can be moved out by value
    // from `Drop::drop` (which only has `&mut self`) and from `Undo::commit`.
    inner: mem::ManuallyDrop<T>,
    undo: mem::ManuallyDrop<F>,
}

/// Dropping an uncommitted guard runs the undo function on the inner value.
impl<T, F> Drop for Undo<T, F>
where
    F: FnOnce(T),
{
    fn drop(&mut self) {
        // SAFETY: `drop` runs at most once and `self` is not usable
        // afterwards, so each `ManuallyDrop` field is taken exactly once and
        // never touched again. `Undo::commit` wraps the guard in
        // `ManuallyDrop`, so this never runs for a guard whose fields were
        // already moved out there.
        let inner = unsafe { mem::ManuallyDrop::take(&mut self.inner) };
        let undo = unsafe { mem::ManuallyDrop::take(&mut self.undo) };
        undo(inner);
    }
}

/// Formats the inner value; the undo function is opaque and shown as `".."`.
impl<T, F> fmt::Debug for Undo<T, F>
where
    F: FnOnce(T),
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Undo")
            .field("inner", &self.inner)
            .field("undo", &"..")
            .finish()
    }
}

/// Shared access to the guarded inner value.
impl<T, F> ops::Deref for Undo<T, F>
where
    F: FnOnce(T),
{
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

/// Exclusive access to the guarded inner value.
impl<T, F> ops::DerefMut for Undo<T, F>
where
    F: FnOnce(T),
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}

impl<T, F> Undo<T, F>
where
    F: FnOnce(T),
{
    /// Create a new `Undo` guard.
    ///
    /// This guard will wrap the given `inner` object and call `undo(inner)`
    /// when dropped, unless the guard is disabled via `Undo::commit`.
    pub fn new(inner: T, undo: F) -> Self {
        Self {
            inner: mem::ManuallyDrop::new(inner),
            undo: mem::ManuallyDrop::new(undo),
        }
    }

    /// Disable this `Undo` and return its inner value.
    ///
    /// This `Undo`'s cleanup function will never be called. The cleanup
    /// function itself is still dropped, so anything it captured is released.
    ///
    /// This is an associated function rather than a method (call it as
    /// `Undo::commit(guard)`).
    pub fn commit(guard: Self) -> T {
        // Suppress `Undo`'s own `Drop` impl so that the undo function is not
        // invoked.
        let mut guard = mem::ManuallyDrop::new(guard);

        // SAFETY: `guard` is wrapped in `ManuallyDrop`, so `Undo::drop` never
        // runs for it, and neither field is accessed again after being moved
        // out here.
        let (inner, undo) = unsafe {
            (
                mem::ManuallyDrop::take(&mut guard.inner),
                mem::ManuallyDrop::take(&mut guard.undo),
            )
        };

        // Make sure to drop `undo`, even though we aren't calling it, to
        // avoid leaking closed-over `Arc`s, for example.
        //
        // `inner` has already been moved into an owned local at this point,
        // so if `undo`'s destructor panics, `inner` is dropped during
        // unwinding instead of being leaked inside the `ManuallyDrop` guard.
        mem::drop(undo);

        inner
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::error::{Result, ensure};
    use core::{cell::Cell, cmp};
    use std::{panic, string::ToString};

    /// A counter that remembers the highest value it ever reached, so tests
    /// can tell "never incremented" apart from "incremented then rolled back".
    #[derive(Default)]
    struct Counter {
        value: u32,
        max_value_seen: u32,
    }

    impl Counter {
        /// Run `f`, and if it succeeds, increment the counter by one.
        fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            f(self)?;
            self.value += 1;
            self.max_value_seen = cmp::max(self.max_value_seen, self.value);
            Ok(())
        }

        /// Undo one successful `inc`.
        fn dec(&mut self) {
            self.value -= 1;
        }

        /// Increment `n` times, rolling back every completed increment if any
        /// step fails or panics.
        fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            let i = Cell::new(0);

            let mut counter = Undo::new(self, |counter| {
                for _ in 0..i.get() {
                    counter.dec();
                }
            });

            for _ in 0..n {
                counter.inc(&mut f)?;
                i.set(i.get() + 1);
            }

            Undo::commit(counter);
            Ok(())
        }
    }

    #[test]
    fn error_propagation() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |c| {
            ensure!(c.value < 5, "uh oh");
            Ok(())
        });
        assert_eq!(result.unwrap_err().to_string(), "uh oh");
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn panic_unwind() {
        let mut counter = Counter::default();
        let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            counter.inc_n(10, |c| {
                assert!(c.value < 5);
                Ok(())
            })
        }));
        assert!(result.is_err());
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn commit() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |_| Ok(()));
        assert!(result.is_ok());
        assert_eq!(counter.value, 10);
        assert_eq!(counter.max_value_seen, 10);
    }
}