Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bytecodealliance
GitHub Repository: bytecodealliance/wasmtime
Path: blob/main/crates/core/src/undo.rs
3071 views
1
//! Helpers for undoing partial side effects when their larger operation fails.
2
3
use core::{fmt, mem, ops};
4
5
/// An RAII guard to roll back and undo something on (early) drop.
///
/// Dereferences to its inner `T`; when the guard is dropped without having
/// been committed, the `T` is moved into its undo function.
///
/// When all of the changes that need to happen together have happened, you can
/// call [`Undo::commit`] to disable the guard and commit the associated side
/// effects.
///
/// # Example
///
/// ```
/// use std::cell::Cell;
/// use wasmtime_internal_core::{error::Result, undo::Undo};
///
/// /// Some big ball of state that must always be coherent.
/// pub struct Context {
///     // ...
/// }
///
/// impl Context {
///     /// Perform some incremental mutation to `self`, which might not leave
///     /// it in a valid state unless its whole batch of work is completed.
///     fn do_thing(&mut self, arg: u32) -> Result<()> {
///         # let _ = arg;
///         # todo!()
///         // ...
///     }
///
///     /// Undo the side effects of `self.do_thing(arg)` for when we need to
///     /// roll back mutations.
///     fn undo_thing(&mut self, arg: u32) {
///         # let _ = arg;
///         // ...
///     }
///
///     /// Call `self.do_thing(arg)` for each `arg` in `args`.
///     ///
///     /// However, if any `self.do_thing(arg)` call fails, make sure that
///     /// we roll back to the original state by calling `self.undo_thing(arg)`
///     /// for all the `self.do_thing(arg)` calls that already succeeded. This
///     /// way we never leave `self` in a state where things got half-done.
///     pub fn do_all_or_nothing(&mut self, args: &[u32]) -> Result<()> {
///         // Counter for our progress, so that we know how much work to undo
///         // upon failure.
///         let num_things_done = Cell::new(0);
///
///         // Wrap the `Context` in an `Undo` that rolls back our side effects if
///         // we early-exit this function via `?`-propagation or panic unwinding.
///         let mut ctx = Undo::new(self, |ctx| {
///             for arg in args.iter().take(num_things_done.get()) {
///                 ctx.undo_thing(*arg);
///             }
///         });
///
///         // Do each piece of work!
///         for arg in args {
///             // Note: if this call returns an error that is `?`-propagated or
///             // triggers unwinding by panicking, then the work performed thus
///             // far will be rolled back when `ctx` is dropped.
///             ctx.do_thing(*arg)?;
///
///             // Update how much work has been completed.
///             num_things_done.set(num_things_done.get() + 1);
///         }
///
///         // We completed all of the work, so commit the `Undo` guard and
///         // disable its cleanup function.
///         Undo::commit(ctx);
///
///         Ok(())
///     }
/// }
/// ```
#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \
              to disable"]
pub struct Undo<T, F>
// This bound has to live on the struct itself (not just on the impls) because
// `Undo` has a `Drop` impl that needs it, and a `Drop` impl's bounds must match
// those of the type it is implemented for.
where
    F: FnOnce(T),
{
    /// The guarded value. It is moved out exactly once: either handed to
    /// `undo` when the guard is dropped, or returned by `Undo::commit`.
    inner: mem::ManuallyDrop<T>,
    /// The rollback function. It is either called (when the guard is dropped)
    /// or dropped without being called (by `Undo::commit`), never both.
    undo: mem::ManuallyDrop<F>,
}
88
89
impl<T, F> Drop for Undo<T, F>
90
where
91
F: FnOnce(T),
92
{
93
fn drop(&mut self) {
94
// Safety: These `ManuallyDrop` fields will not be used again.
95
let inner = unsafe { mem::ManuallyDrop::take(&mut self.inner) };
96
let undo = unsafe { mem::ManuallyDrop::take(&mut self.undo) };
97
undo(inner);
98
}
99
}
100
101
impl<T, F> fmt::Debug for Undo<T, F>
102
where
103
F: FnOnce(T),
104
T: fmt::Debug,
105
{
106
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107
f.debug_struct("Undo")
108
.field("inner", &self.inner)
109
.field("undo", &"..")
110
.finish()
111
}
112
}
113
114
impl<T, F> ops::Deref for Undo<T, F>
115
where
116
F: FnOnce(T),
117
{
118
type Target = T;
119
120
fn deref(&self) -> &Self::Target {
121
&self.inner
122
}
123
}
124
125
impl<T, F> ops::DerefMut for Undo<T, F>
126
where
127
F: FnOnce(T),
128
{
129
fn deref_mut(&mut self) -> &mut Self::Target {
130
&mut self.inner
131
}
132
}
133
134
impl<T, F> Undo<T, F>
135
where
136
F: FnOnce(T),
137
{
138
/// Create a new `Undo` guard.
139
///
140
/// This guard will wrap the given `inner` object and call `undo(inner)`
141
/// when dropped, unless the guard is disabled via `Undo::commit`.
142
pub fn new(inner: T, undo: F) -> Self {
143
Self {
144
inner: mem::ManuallyDrop::new(inner),
145
undo: mem::ManuallyDrop::new(undo),
146
}
147
}
148
149
/// Disable this `Undo` and return its inner value.
150
///
151
/// This `Undo`'s cleanup function will never be called.
152
pub fn commit(guard: Self) -> T {
153
let mut guard = mem::ManuallyDrop::new(guard);
154
155
// Safety: These `ManuallyDrop` fields will not be used again.
156
unsafe {
157
// Make sure to drop `undo`, even though we aren't calling it, to
158
// avoid leaking closed-over `Arc`s, for example.
159
mem::ManuallyDrop::drop(&mut guard.undo);
160
161
mem::ManuallyDrop::take(&mut guard.inner)
162
}
163
}
164
}
165
166
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::error::{Result, ensure};
    use core::{cell::Cell, cmp};
    use std::{panic, string::ToString, sync::Arc};

    /// A toy piece of state whose mutations (`inc`) can be rolled back (`dec`).
    #[derive(Default)]
    struct Counter {
        /// Current value; should be back at its starting point after a rollback.
        value: u32,
        /// High-water mark of `value`, showing how far work got before rollback.
        max_value_seen: u32,
    }

    impl Counter {
        /// Increment `value`, but only after the check `f` succeeds.
        fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            f(self)?;
            self.value += 1;
            self.max_value_seen = cmp::max(self.max_value_seen, self.value);
            Ok(())
        }

        /// Reverse one successful `inc`.
        fn dec(&mut self) {
            self.value -= 1;
        }

        /// Perform `n` increments all-or-nothing. An `Undo` guard rolls back
        /// the successful increments if any increment errors or panics.
        fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            let i = Cell::new(0);

            let mut counter = Undo::new(self, |counter| {
                for _ in 0..i.get() {
                    counter.dec();
                }
            });

            for _ in 0..n {
                counter.inc(&mut f)?;
                i.set(i.get() + 1);
            }

            Undo::commit(counter);
            Ok(())
        }
    }

    #[test]
    fn error_propagation() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |c| {
            ensure!(c.value < 5, "uh oh");
            Ok(())
        });
        assert_eq!(result.unwrap_err().to_string(), "uh oh");
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn panic_unwind() {
        let mut counter = Counter::default();
        let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            counter.inc_n(10, |c| {
                assert!(c.value < 5);
                Ok(())
            })
        }));
        assert!(result.is_err());
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn commit() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |_| Ok(()));
        assert!(result.is_ok());
        assert_eq!(counter.value, 10);
        assert_eq!(counter.max_value_seen, 10);
    }

    #[test]
    fn commit_returns_inner_without_running_undo() {
        let ran = Cell::new(false);
        let guard = Undo::new(42_u32, |_| ran.set(true));
        assert_eq!(Undo::commit(guard), 42);
        assert!(!ran.get());
    }

    #[test]
    fn commit_drops_undo_closure() {
        let shared = Arc::new(());
        let captured = Arc::clone(&shared);
        let guard = Undo::new((), move |()| drop(captured));
        assert_eq!(Arc::strong_count(&shared), 2);
        Undo::commit(guard);
        // The closure, and the `Arc` it captured, were dropped rather than leaked.
        assert_eq!(Arc::strong_count(&shared), 1);
    }

    #[test]
    fn drop_passes_inner_to_undo() {
        let seen = Cell::new(0_u32);
        {
            let _guard = Undo::new(7_u32, |v| seen.set(v));
        }
        assert_eq!(seen.get(), 7);
    }
}
245
246