#![no_std]
#[cfg(any(feature = "std", unix, windows))]
#[macro_use]
extern crate std;
extern crate alloc;
use alloc::boxed::Box;
use anyhow::Error;
use core::cell::Cell;
use core::marker::PhantomData;
use core::ops::Range;
// Select the platform backend and alias it as `imp`. The branches are tried
// in order:
// - a build without the `std` feature always uses `nostd`, even on unix/windows;
// - a Miri build uses `miri`;
// - otherwise Windows or Unix is chosen.
// Any other target with `std` enabled is rejected at compile time.
cfg_if::cfg_if! {
if #[cfg(not(feature = "std"))] {
mod nostd;
use nostd as imp;
} else if #[cfg(miri)] {
mod miri;
use miri as imp;
} else if #[cfg(windows)] {
mod windows;
use windows as imp;
} else if #[cfg(unix)] {
mod unix;
use unix as imp;
} else {
compile_error!("fibers are not supported on this platform");
}
}
// Compiled only for unix targets and for builds without the `std` feature.
// NOTE(review): presumably this holds the architecture-specific stack-switching
// routines that the unix and nostd backends rely on. Confirm against the module.
#[cfg(any(unix, not(feature = "std")))]
pub(crate) mod stackswitch;
/// A stack on which a [`Fiber`] executes.
///
/// This is a thin wrapper around the stack type of the platform backend
/// (`imp`) chosen by the `cfg_if!` block.
pub struct FiberStack(imp::FiberStack);
fn _assert_send_sync() {
fn _assert_send<T: Send>() {}
fn _assert_sync<T: Sync>() {}
_assert_send::<FiberStack>();
_assert_sync::<FiberStack>();
}
/// Crate-wide `Result` alias. The error type defaults to the active
/// backend's `imp::Error`. Other code overrides `E`; for example,
/// [`Fiber::resume`] uses `Result<Return, Yield>`.
pub type Result<T, E = imp::Error> = core::result::Result<T, E>;
impl FiberStack {
pub fn new(size: usize, zeroed: bool) -> Result<Self> {
Ok(Self(imp::FiberStack::new(size, zeroed)?))
}
pub fn from_custom(custom: Box<dyn RuntimeFiberStack>) -> Result<Self> {
Ok(Self(imp::FiberStack::from_custom(custom)?))
}
pub unsafe fn from_raw_parts(bottom: *mut u8, guard_size: usize, len: usize) -> Result<Self> {
Ok(Self(unsafe {
imp::FiberStack::from_raw_parts(bottom, guard_size, len)?
}))
}
pub fn top(&self) -> Option<*mut u8> {
self.0.top()
}
pub fn range(&self) -> Option<Range<usize>> {
self.0.range()
}
pub fn is_from_raw_parts(&self) -> bool {
self.0.is_from_raw_parts()
}
pub fn guard_range(&self) -> Option<Range<*mut u8>> {
self.0.guard_range()
}
}
/// Creates custom [`RuntimeFiberStack`]s, which can then be turned into a
/// [`FiberStack`] with [`FiberStack::from_custom`].
///
/// # Safety
///
/// NOTE(review): the precise contract is not spelled out in this file.
/// Implementors presumably must return stacks whose `top`, `range` and
/// `guard_range` accurately describe memory that is valid to run a fiber on.
/// Confirm this against the backend that consumes these stacks.
pub unsafe trait RuntimeFiberStackCreator: Send + Sync {
/// Creates a new stack of `size` (presumably bytes).
///
/// `zeroed` is passed through from the caller's request, as with
/// [`FiberStack::new`].
fn new_stack(&self, size: usize, zeroed: bool) -> Result<Box<dyn RuntimeFiberStack>, Error>;
}
/// A fiber stack supplied from outside this crate. Wrap one with
/// [`FiberStack::from_custom`].
///
/// # Safety
///
/// NOTE(review): the exact contract is defined by the backend consuming the
/// stack. Implementors presumably must report addresses (via `top`, `range`
/// and `guard_range`) that describe memory valid to run a fiber on for the
/// lifetime of the stack. Confirm before relying on this.
pub unsafe trait RuntimeFiberStack: Send + Sync {
/// Pointer to the top of the stack. NOTE(review): presumably the highest
/// address, since stacks typically grow downward; confirm.
fn top(&self) -> *mut u8;
/// Address range covered by the stack.
fn range(&self) -> Range<usize>;
/// Address range of the stack's guard region.
fn guard_range(&self) -> Range<*mut u8>;
}
/// A function running on its own [`FiberStack`]. It can suspend itself and be
/// resumed by its caller.
///
/// Type parameters:
/// - `Resume` is the value passed in on each [`Fiber::resume`].
/// - `Yield` is the value handed out when the fiber suspends.
/// - `Return` is the type of the function's final result.
/// - `'a` bounds the lifetime of data the fiber's function may borrow.
pub struct Fiber<'a, Resume, Yield, Return> {
// `Some` from construction until `into_stack` moves it out.
stack: Option<FiberStack>,
// Backend-specific fiber state.
inner: imp::Fiber,
// `true` while a `resume` is in progress and once the fiber has finished;
// reset to `false` when the fiber yields.
done: Cell<bool>,
// Ties `'a` and the type parameters to the struct without storing values.
_phantom: PhantomData<&'a (Resume, Yield, Return)>,
}
/// Handle given to a fiber's function, letting it suspend itself with a
/// `Yield` value and receive the next `Resume` value.
pub struct Suspend<Resume, Yield, Return> {
// Backend-specific handle used to switch back to the caller.
inner: imp::Suspend,
// Carries the type parameters without storing values.
_phantom: PhantomData<(Resume, Yield, Return)>,
}
/// Message slot shared between [`Fiber::resume`] and the running fiber; it
/// carries values across each stack switch.
enum RunResult<Resume, Yield, Return> {
/// Not constructed in this file. NOTE(review): presumably the backend
/// stores this in the slot while the fiber is running; confirm against `imp`.
Executing,
/// Value passed into the fiber by [`Fiber::resume`].
Resuming(Resume),
/// Value the fiber suspended with via [`Suspend::suspend`].
Yield(Yield),
/// The fiber's function returned this value.
Returned(Return),
/// Payload of a panic caught inside the fiber; re-raised by `resume`.
#[cfg(feature = "std")]
Panicked(Box<dyn core::any::Any + Send>),
}
impl<'a, Resume, Yield, Return> Fiber<'a, Resume, Yield, Return> {
    /// Creates a fiber that will run `func` on `stack`.
    ///
    /// `func` does not start until the first [`Fiber::resume`]. The value
    /// passed to that call becomes `func`'s first argument. The [`Suspend`]
    /// handle lets `func` yield values back to the caller.
    ///
    /// # Errors
    ///
    /// Returns the backend's error if the fiber cannot be set up on `stack`.
    pub fn new(
        stack: FiberStack,
        func: impl FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return + 'a,
    ) -> Result<Self> {
        let inner = imp::Fiber::new(&stack.0, func)?;
        Ok(Self {
            stack: Some(stack),
            inner,
            done: Cell::new(false),
            _phantom: PhantomData,
        })
    }

    /// Resumes the fiber, passing `val` into it, and runs it until it
    /// suspends or finishes.
    ///
    /// Returns:
    /// - `Ok(ret)` when the fiber's function returns `ret`. The fiber is then
    ///   [`done`](Fiber::done).
    /// - `Err(y)` when the fiber suspends itself via [`Suspend::suspend`]
    ///   with `y`. The fiber may then be resumed again.
    ///
    /// # Panics
    ///
    /// Panics if the fiber has already finished. With the `std` feature, a
    /// panic inside the fiber is caught on the fiber's stack and re-raised
    /// here via `std::panic::resume_unwind`. The fiber is then considered done.
    pub fn resume(&self, val: Resume) -> Result<Return, Yield> {
        // Mark the fiber done before switching in. Only a yield clears the
        // flag, so both a normal return and a propagated panic leave it finished.
        assert!(!self.done.replace(true), "cannot resume a finished fiber");
        let result = Cell::new(RunResult::Resuming(val));
        self.inner.resume(&self.stack().0, &result);
        match result.into_inner() {
            // The fiber is expected to replace the slot with its outcome
            // before switching back. Seeing either of these would be a backend bug.
            RunResult::Resuming(_) | RunResult::Executing => unreachable!(),
            RunResult::Yield(y) => {
                self.done.set(false);
                Err(y)
            }
            RunResult::Returned(r) => Ok(r),
            #[cfg(feature = "std")]
            RunResult::Panicked(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Returns whether the fiber has finished, either by returning or by
    /// panicking.
    pub fn done(&self) -> bool {
        self.done.get()
    }

    /// Returns the stack this fiber runs on.
    pub fn stack(&self) -> &FiberStack {
        self.stack
            .as_ref()
            .expect("fiber stack is only taken by `into_stack`, which consumes the fiber")
    }

    /// Consumes a finished fiber and returns the stack it ran on.
    ///
    /// # Panics
    ///
    /// Panics if the fiber has not finished.
    pub fn into_stack(mut self) -> FiberStack {
        assert!(self.done(), "cannot take the stack of an unfinished fiber");
        self.stack
            .take()
            .expect("fiber stack is present until `into_stack` takes it")
    }
}
impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
    /// Suspends the running fiber, handing `value` to the pending
    /// [`Fiber::resume`] call, which observes it as `Err(value)`.
    ///
    /// When the fiber is next resumed, this returns the value passed to that
    /// [`Fiber::resume`] call.
    pub fn suspend(&mut self, value: Yield) -> Resume {
        let yielded = RunResult::Yield(value);
        self.inner.switch::<Resume, Yield, Return>(yielded)
    }

    /// Body executed on the fiber's stack: runs `func` with the first resume
    /// value, then reports the outcome through `exit`.
    ///
    /// With the `std` feature, a panic in `func` is caught here and reported
    /// as `RunResult::Panicked` so `Fiber::resume` can re-raise it on the
    /// caller's stack.
    fn execute(
        inner: imp::Suspend,
        initial: Resume,
        func: impl FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return,
    ) {
        let mut suspend = Suspend {
            inner,
            _phantom: PhantomData,
        };

        #[cfg(feature = "std")]
        let outcome = {
            use std::panic::{self, AssertUnwindSafe};
            panic::catch_unwind(AssertUnwindSafe(|| func(initial, &mut suspend)))
                .map_or_else(RunResult::Panicked, RunResult::Returned)
        };

        #[cfg(not(feature = "std"))]
        let outcome = RunResult::Returned(func(initial, &mut suspend));

        suspend.inner.exit::<Resume, Yield, Return>(outcome);
    }
}
/// Hands the backend fiber state to `imp::Fiber::drop` (presumably to release
/// it). Debug builds assert that the fiber ran to completion first.
impl<A, B, C> Drop for Fiber<'_, A, B, C> {
fn drop(&mut self) {
debug_assert!(self.done.get(), "fiber dropped without finishing");
// SAFETY: NOTE(review): the backend's requirements for `drop` are not
// visible here. This is its only call site in this file. It runs exactly
// once, from `Drop`, with the same type parameters the fiber was created
// with. Confirm that this satisfies `imp::Fiber::drop`'s contract.
unsafe {
self.inner.drop::<A, B, C>();
}
}
}
#[cfg(test)]
mod tests {
    use super::{Fiber, FiberStack};
    use alloc::string::ToString;
    use std::cell::Cell;
    use std::rc::Rc;

    /// Creates a non-zeroed stack of `size` for a test fiber.
    fn fiber_stack(size: usize) -> FiberStack {
        FiberStack::new(size, false).unwrap()
    }

    /// Fibers still run to completion when given degenerate 0- and 1-sized
    /// stack requests.
    #[test]
    fn small_stacks() {
        Fiber::<(), (), ()>::new(fiber_stack(0), |_, _| {})
            .unwrap()
            .resume(())
            .unwrap();
        Fiber::<(), (), ()>::new(fiber_stack(1), |_, _| {})
            .unwrap()
            .resume(())
            .unwrap();
    }

    /// The fiber's function does not run until the first `resume`, and does
    /// run during it.
    #[test]
    fn smoke() {
        let hit = Rc::new(Cell::new(false));
        let hit2 = hit.clone();
        let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |_, _| {
            hit2.set(true);
        })
        .unwrap();
        assert!(!hit.get());
        fiber.resume(()).unwrap();
        assert!(hit.get());
    }

    /// Each `suspend` surfaces as `Err` from `resume`, and the final return
    /// surfaces as `Ok`.
    #[test]
    fn suspend_and_resume() {
        let hit = Rc::new(Cell::new(false));
        let hit2 = hit.clone();
        let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |_, s| {
            s.suspend(());
            hit2.set(true);
            s.suspend(());
        })
        .unwrap();
        assert!(!hit.get());
        assert!(fiber.resume(()).is_err());
        assert!(!hit.get());
        assert!(fiber.resume(()).is_err());
        assert!(hit.get());
        assert!(fiber.resume(()).is_ok());
        assert!(hit.get());
    }

    /// A backtrace captured inside the fiber reaches the host frame
    /// `look_for_me`. The check is skipped on platforms where this is not
    /// supported.
    #[test]
    fn backtrace_traces_to_host() {
        #[inline(never)]
        fn look_for_me() {
            run_test();
        }

        fn assert_contains_host() {
            let trace = backtrace::Backtrace::new();
            println!("{trace:?}");
            assert!(
                trace
                    .frames()
                    .iter()
                    .flat_map(|f| f.symbols())
                    .filter_map(|s| Some(s.name()?.to_string()))
                    .any(|s| s.contains("look_for_me"))
                    || cfg!(windows)
                    || cfg!(all(target_os = "macos", target_arch = "aarch64"))
                    || cfg!(target_arch = "arm")
                    || cfg!(asan)
                    || cfg!(miri)
            );
        }

        fn run_test() {
            let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |(), s| {
                assert_contains_host();
                s.suspend(());
                assert_contains_host();
                s.suspend(());
                assert_contains_host();
            })
            .unwrap();
            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_ok());
        }

        look_for_me();
    }

    /// A panic inside the fiber is re-raised from `resume`, and values the
    /// fiber's closure owns are dropped during unwinding.
    #[test]
    #[cfg(feature = "std")]
    fn panics_propagated() {
        use std::panic::{self, AssertUnwindSafe};

        let a = Rc::new(Cell::new(false));
        let b = SetOnDrop(a.clone());
        let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |(), _s| {
            // Mentioning `b` makes the `move` closure take ownership of it,
            // so its destructor runs as the panic unwinds.
            let _ = &b;
            panic!();
        })
        .unwrap();
        assert!(panic::catch_unwind(AssertUnwindSafe(|| fiber.resume(()))).is_err());
        assert!(a.get());

        struct SetOnDrop(Rc<Cell<bool>>);

        impl Drop for SetOnDrop {
            fn drop(&mut self) {
                self.0.set(true);
            }
        }
    }

    /// Values flow both ways:
    /// - the first `resume` argument becomes the function's input;
    /// - `suspend`'s argument comes out of `resume` as `Err`;
    /// - the next `resume` argument is `suspend`'s return value.
    #[test]
    fn suspend_and_resume_values() {
        let fiber = Fiber::new(fiber_stack(1024 * 1024), move |first, s| {
            assert_eq!(first, 2.0);
            assert_eq!(s.suspend(4), 3.0);
            "hello".to_string()
        })
        .unwrap();
        assert_eq!(fiber.resume(2.0), Err(4));
        assert_eq!(fiber.resume(3.0), Ok("hello".to_string()));
    }
}