use crate::stackswitch::*;
use crate::{RunResult, RuntimeFiberStack};
use std::boxed::Box;
use std::cell::Cell;
use std::io;
use std::ops::Range;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Stack-creation errors on unix are plain OS errors.
pub type Error = io::Error;
/// A fiber call stack: the usable memory region plus whatever storage backs it.
pub struct FiberStack {
    // Lowest address of the usable stack region (just past any guard region).
    base: BasePtr,
    // Length in bytes of the usable region starting at `base`.
    len: usize,
    // Owns (or merely records, for unmanaged stacks) the underlying allocation.
    storage: FiberStackStorage,
}
/// Wrapper around the raw base pointer so `FiberStack` can be `Send`/`Sync`.
struct BasePtr(*mut u8);
// SAFETY: `BasePtr` is only an address; the memory it points at is owned and
// managed through `FiberStackStorage`, so sharing the pointer value itself
// across threads is sound.
unsafe impl Send for BasePtr {}
unsafe impl Sync for BasePtr {}
/// How the memory behind a `FiberStack` is owned.
enum FiberStackStorage {
    // Mapping created (and unmapped on drop) by this module; its first page is
    // a guard page.
    Mmap(MmapFiberStack),
    // Caller-provided memory from `from_raw_parts`; payload is the guard size
    // in bytes below `base`. Nothing is freed on drop for this variant.
    Unmanaged(usize),
    // Embedder-provided stack implementing `RuntimeFiberStack`.
    Custom(Box<dyn RuntimeFiberStack>),
}
/// Returns the host's memory page size in bytes, caching the result.
///
/// The first call queries `sysconf(_SC_PAGESIZE)` and stores the value in a
/// process-wide atomic; subsequent calls read the cache without a syscall.
/// Concurrent first calls may race and each perform the syscall, but all of
/// them store and return the same value, so the race is benign.
fn host_page_size() -> usize {
    static PAGE_SIZE: AtomicUsize = AtomicUsize::new(0);
    let cached = PAGE_SIZE.load(Ordering::Relaxed);
    if cached != 0 {
        return cached;
    }
    // Zero means "not queried yet" — a real page size is never zero.
    let size: usize = unsafe { libc::sysconf(libc::_SC_PAGESIZE).try_into().unwrap() };
    assert!(size != 0);
    PAGE_SIZE.store(size, Ordering::Relaxed);
    size
}
impl FiberStack {
    /// Creates a new stack with (at least) `size` usable bytes.
    ///
    /// `zeroed` is accepted for interface parity but not consulted here:
    /// freshly mapped anonymous memory is zero-filled by the OS already.
    ///
    /// # Errors
    ///
    /// Returns any error from the underlying allocation (mmap/mprotect, or
    /// the sanitizer-aware stack pool when built with AddressSanitizer).
    pub fn new(size: usize, zeroed: bool) -> io::Result<Self> {
        let page_size = host_page_size();
        let _ = zeroed;
        // ASan builds must draw stacks from the sanitizer-aware pool so the
        // fake-stack machinery stays consistent across fiber switches.
        if cfg!(asan) {
            return Self::from_custom(asan::new_fiber_stack(size)?);
        }
        let stack = MmapFiberStack::new(size)?;
        // The usable stack starts one page above the mapping base; the first
        // page of the mapping is the inaccessible guard page.
        Ok(FiberStack {
            base: BasePtr(stack.mapping_base.wrapping_byte_add(page_size)),
            len: stack.mapping_len - page_size,
            storage: FiberStackStorage::Mmap(stack),
        })
    }
    /// Wraps caller-owned memory as a fiber stack.
    ///
    /// `base` points at the start of the allocation, the first `guard_size`
    /// bytes of which are treated as a guard region below the usable stack;
    /// `len` is the usable length. The memory is NOT freed when the returned
    /// `FiberStack` is dropped.
    ///
    /// # Safety
    ///
    /// The `guard_size + len` bytes at `base` must be a valid allocation that
    /// outlives the returned stack, with the usable portion read/write.
    pub unsafe fn from_raw_parts(base: *mut u8, guard_size: usize, len: usize) -> io::Result<Self> {
        // ASan cannot track arbitrary caller memory, so substitute a stack of
        // the same usable size from the sanitizer-aware pool instead.
        if cfg!(asan) {
            return Self::from_custom(asan::new_fiber_stack(len)?);
        }
        Ok(FiberStack {
            base: BasePtr(unsafe { base.add(guard_size) }),
            len,
            storage: FiberStackStorage::Unmanaged(guard_size),
        })
    }
    /// Whether this stack came from `from_raw_parts`, meaning the memory is
    /// owned by the caller rather than this type.
    pub fn is_from_raw_parts(&self) -> bool {
        matches!(self.storage, FiberStackStorage::Unmanaged(_))
    }
    /// Wraps an embedder-provided `RuntimeFiberStack`.
    ///
    /// # Panics
    ///
    /// Panics if the custom stack's reported range is not page-aligned at
    /// both ends.
    pub fn from_custom(custom: Box<dyn RuntimeFiberStack>) -> io::Result<Self> {
        let range = custom.range();
        let page_size = host_page_size();
        let start_ptr = range.start as *mut u8;
        assert!(
            start_ptr.align_offset(page_size) == 0,
            "expected fiber stack base ({start_ptr:?}) to be page aligned ({page_size:#x})",
        );
        let end_ptr = range.end as *const u8;
        assert!(
            end_ptr.align_offset(page_size) == 0,
            "expected fiber stack end ({end_ptr:?}) to be page aligned ({page_size:#x})",
        );
        Ok(FiberStack {
            base: BasePtr(start_ptr),
            len: range.len(),
            storage: FiberStackStorage::Custom(custom),
        })
    }
    /// Highest address of the usable stack; stacks grow downward from here.
    /// (Always `Some` on unix — `Option` presumably for cross-platform
    /// interface parity.)
    pub fn top(&self) -> Option<*mut u8> {
        Some(self.base.0.wrapping_byte_add(self.len))
    }
    /// Usable stack region as a numeric address range.
    pub fn range(&self) -> Option<Range<usize>> {
        let base = self.base.0 as usize;
        Some(base..base + self.len)
    }
    /// The guard region immediately below the usable stack, if known.
    pub fn guard_range(&self) -> Option<Range<*mut u8>> {
        match &self.storage {
            // The caller told us the guard size; it sits directly below `base`.
            FiberStackStorage::Unmanaged(guard_size) => unsafe {
                let start = self.base.0.sub(*guard_size);
                Some(start..self.base.0)
            },
            // For our own mappings the guard is the first page of the mmap.
            FiberStackStorage::Mmap(mmap) => Some(mmap.mapping_base..self.base.0),
            FiberStackStorage::Custom(custom) => Some(custom.guard_range()),
        }
    }
}
/// An anonymous mmap'd region: one leading guard page followed by read/write
/// stack pages. The whole mapping is unmapped on drop.
struct MmapFiberStack {
    // Address returned by mmap — the start of the guard page.
    mapping_base: *mut u8,
    // Total mapping length in bytes, guard page included.
    mapping_len: usize,
}
// SAFETY: the mapping is plain memory owned exclusively by this value; the
// raw pointer is just its address, so moving or sharing the handle across
// threads is sound.
unsafe impl Send for MmapFiberStack {}
unsafe impl Sync for MmapFiberStack {}
impl MmapFiberStack {
    /// Allocates an anonymous mapping sized to hold a stack of (at least)
    /// `size` bytes plus one leading guard page.
    ///
    /// `size` is rounded up to a whole number of pages; a zero request
    /// becomes one page. The first page of the mapping stays inaccessible
    /// (`PROT_NONE`) to act as a guard page; the remainder is made
    /// readable and writable.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if the requested size overflows `usize` when
    /// rounded up, or the OS error if `mmap`/`mprotect` fails.
    fn new(size: usize) -> io::Result<Self> {
        let page_size = host_page_size();
        // Checked arithmetic: the previous mask-based rounding could wrap
        // silently in release builds for pathological `size` values.
        let size = if size == 0 {
            page_size
        } else {
            size.checked_next_multiple_of(page_size)
                .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))?
        };
        let mmap_len = size
            .checked_add(page_size)
            .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))?;
        unsafe {
            // Reserve the whole region with no access rights, then enable
            // read/write on everything past the first (guard) page.
            let mmap = rustix::mm::mmap_anonymous(
                ptr::null_mut(),
                mmap_len,
                rustix::mm::ProtFlags::empty(),
                rustix::mm::MapFlags::PRIVATE,
            )?;
            rustix::mm::mprotect(
                mmap.byte_add(page_size),
                size,
                rustix::mm::MprotectFlags::READ | rustix::mm::MprotectFlags::WRITE,
            )?;
            Ok(MmapFiberStack {
                mapping_base: mmap.cast(),
                mapping_len: mmap_len,
            })
        }
    }
}
impl Drop for MmapFiberStack {
    /// Releases the entire reservation, guard page included.
    fn drop(&mut self) {
        // SAFETY: `mapping_base`/`mapping_len` describe exactly the region
        // returned by `mmap_anonymous` in `MmapFiberStack::new`, and it is
        // unmapped only once, here.
        let result = unsafe { rustix::mm::munmap(self.mapping_base.cast(), self.mapping_len) };
        // munmap on a valid mapping should never fail; surface bugs in debug.
        debug_assert!(result.is_ok());
    }
}
/// Handle to an initialized fiber. Carries no data: all fiber state lives on
/// the fiber's own stack, set up by `Fiber::new`.
pub struct Fiber;
/// Suspension context handed to code running on a fiber.
pub struct Suspend {
    // Top (highest address) of the fiber's stack; the word just below it
    // holds a pointer to the shared `RunResult` cell (see `Fiber::resume`).
    top_of_stack: *mut u8,
    // Sanitizer bookkeeping for the stack we switch back to (a no-op
    // zero-sized type when ASan is disabled).
    previous: asan::PreviousStack,
}
/// Entry point every fiber begins on, invoked by the assembly trampoline set
/// up in `Fiber::new`.
///
/// `arg0` is the raw `Box<F>` closure pointer registered via
/// `wasmtime_fiber_init`; `top_of_stack` is the stack top passed there. The
/// call into `super::Suspend::execute` is expected not to return normally —
/// the fiber exits via a final stack switch.
extern "C" fn fiber_start<F, A, B, C>(arg0: *mut u8, top_of_stack: *mut u8)
where
    F: FnOnce(A, &mut super::Suspend<A, B, C>) -> C,
{
    unsafe {
        // Complete the sanitizer side of the switch that brought us here.
        let previous = asan::fiber_start_complete();
        let inner = Suspend {
            top_of_stack,
            previous,
        };
        // The first `Resuming` value was placed in the result cell by the
        // resumer before switching to us.
        let initial = inner.take_resume::<A, B, C>();
        super::Suspend::<A, B, C>::execute(inner, initial, Box::from_raw(arg0.cast::<F>()))
    }
}
impl Fiber {
    /// Prepares `stack` so the first switch onto it invokes `func`.
    ///
    /// The closure is boxed and its raw pointer handed to the assembly
    /// initializer; `fiber_start` reclaims the box when the fiber first runs.
    ///
    /// # Errors
    ///
    /// Fails when fibers are unsupported on the host architecture.
    pub fn new<F, A, B, C>(stack: &FiberStack, func: F) -> io::Result<Self>
    where
        F: FnOnce(A, &mut super::Suspend<A, B, C>) -> C,
    {
        if !SUPPORTED_ARCH {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "fibers not supported on this host architecture",
            ));
        }
        unsafe {
            let data = Box::into_raw(Box::new(func)).cast();
            wasmtime_fiber_init(stack.top().unwrap(), fiber_start::<F, A, B, C>, data);
        }
        Ok(Self)
    }
    /// Switches execution onto the fiber, communicating through `result`.
    ///
    /// A pointer to the `result` cell is written into the word immediately
    /// below the stack top — `Suspend::result_location` reads it from there
    /// inside the fiber — and cleared again after the fiber yields back.
    pub(crate) fn resume<A, B, C>(&self, stack: &FiberStack, result: &Cell<RunResult<A, B, C>>) {
        unsafe {
            let addr = stack.top().unwrap().cast::<usize>().offset(-1);
            addr.write(result as *const _ as usize);
            asan::fiber_switch(
                stack.top().unwrap(),
                false,
                &mut asan::PreviousStack::new(stack),
            );
            // Null the slot so a stale cell pointer can't be misread later.
            addr.write(0);
        }
    }
    /// No-op on unix: there are no per-fiber resources outside the stack.
    pub(crate) unsafe fn drop<A, B, C>(&mut self) {}
}
impl Suspend {
    /// Publishes `result` to the resumer and switches back to it; returns the
    /// value supplied to the next `resume` once this fiber runs again.
    pub(crate) fn switch<A, B, C>(&mut self, result: RunResult<A, B, C>) -> A {
        unsafe {
            // Exhaustive match so a new `RunResult` variant forces an explicit
            // decision about whether it terminates the fiber (which lets the
            // sanitizer, when enabled, retire this stack's state).
            let is_finishing = match &result {
                RunResult::Returned(_) | RunResult::Panicked(_) => true,
                RunResult::Executing | RunResult::Resuming(_) | RunResult::Yield(_) => false,
            };
            (*self.result_location::<A, B, C>()).set(result);
            asan::fiber_switch(self.top_of_stack, is_finishing, &mut self.previous);
            self.take_resume::<A, B, C>()
        }
    }
    /// Final switch out of a finished fiber; control must never come back.
    pub(crate) fn exit<A, B, C>(&mut self, result: RunResult<A, B, C>) {
        self.switch(result);
        unreachable!()
    }
    /// Takes the pending `Resuming` value out of the result cell, leaving
    /// `Executing` behind. Panics on any other state, which would indicate a
    /// protocol violation between resumer and fiber.
    unsafe fn take_resume<A, B, C>(&self) -> A {
        unsafe {
            match (*self.result_location::<A, B, C>()).replace(RunResult::Executing) {
                RunResult::Resuming(val) => val,
                _ => panic!("not in resuming state"),
            }
        }
    }
    /// Reads the pointer to the shared result cell from the word just below
    /// the stack top, where `Fiber::resume` stored it.
    unsafe fn result_location<A, B, C>(&self) -> *const Cell<RunResult<A, B, C>> {
        unsafe {
            let ret = self.top_of_stack.cast::<*const u8>().offset(-1).read();
            assert!(!ret.is_null());
            ret.cast()
        }
    }
}
#[cfg(asan)]
mod asan {
    //! AddressSanitizer integration: brackets every stack switch with the
    //! sanitizer's fiber-annotation hooks, and pools fiber stacks so retired
    //! stack memory is never handed back to the allocator while ASan may
    //! still reference it.
    use super::{FiberStack, MmapFiberStack, RuntimeFiberStack, host_page_size};
    // NOTE(review): `alloc::` paths here while the rest of the file uses
    // `std::` — presumably `extern crate alloc` is declared elsewhere in this
    // crate; confirm.
    use alloc::boxed::Box;
    use alloc::vec::Vec;
    use std::mem::ManuallyDrop;
    use std::ops::Range;
    use std::sync::Mutex;
    /// Bounds of the stack being switched away from, as exchanged with the
    /// sanitizer fiber hooks.
    pub struct PreviousStack {
        // Lowest address of that stack.
        bottom: *const u8,
        // Size in bytes reported to the sanitizer.
        size: usize,
    }
    impl PreviousStack {
        pub fn new(stack: &FiberStack) -> PreviousStack {
            let range = stack.range().unwrap();
            PreviousStack {
                bottom: range.start as *const u8,
                // Exclude two pointer-sized words at the top of the stack —
                // presumably the control words written during init/resume;
                // TODO confirm exact layout against the assembly trampoline.
                size: range.len() - 2 * std::mem::size_of::<*const u8>(),
            }
        }
    }
    impl Default for PreviousStack {
        /// Placeholder used before the first switch completes; the sanitizer
        /// hooks overwrite both fields.
        fn default() -> PreviousStack {
            PreviousStack {
                bottom: std::ptr::null(),
                size: 0,
            }
        }
    }
    /// Performs a fiber switch with the required ASan bracketing calls.
    ///
    /// `is_finishing` tells ASan the current stack is being retired; `prev`
    /// is updated to describe the stack we end up running on afterwards.
    pub unsafe fn fiber_switch(
        top_of_stack: *mut u8,
        is_finishing: bool,
        prev: &mut PreviousStack,
    ) {
        assert!(super::SUPPORTED_ARCH);
        let mut private_asan_pointer = std::ptr::null_mut();
        // Passing no save slot (null) tells ASan the departing stack's
        // fake-stack state need not be preserved.
        let private_asan_pointer_ref = if is_finishing {
            None
        } else {
            Some(&mut private_asan_pointer)
        };
        unsafe {
            __sanitizer_start_switch_fiber(private_asan_pointer_ref, prev.bottom, prev.size);
            super::wasmtime_fiber_switch(top_of_stack);
            __sanitizer_finish_switch_fiber(private_asan_pointer, &mut prev.bottom, &mut prev.size);
        }
    }
    /// First thing executed on a brand-new fiber: completes the sanitizer
    /// side of the switch that started it and captures where we came from.
    pub unsafe fn fiber_start_complete() -> PreviousStack {
        let mut ret = PreviousStack::default();
        unsafe {
            __sanitizer_finish_switch_fiber(std::ptr::null_mut(), &mut ret.bottom, &mut ret.size);
        }
        ret
    }
    // ASan's fiber-annotation API (declared in compiler-rt's
    // asan_interface.h).
    unsafe extern "C" {
        fn __sanitizer_start_switch_fiber(
            private_asan_pointer_save: Option<&mut *mut u8>,
            bottom: *const u8,
            size: usize,
        );
        fn __sanitizer_finish_switch_fiber(
            private_asan_pointer: *mut u8,
            bottom_old: &mut *const u8,
            size_old: &mut usize,
        );
    }
    // Pool of retired stack mappings; entries are reused rather than unmapped
    // so addresses ASan recorded stay valid for the process lifetime.
    static FIBER_STACKS: Mutex<Vec<MmapFiberStack>> = Mutex::new(Vec::new());
    /// Hands out a pooled stack large enough for `size` bytes, allocating a
    /// fresh mapping only when the pool has nothing big enough.
    pub fn new_fiber_stack(size: usize) -> std::io::Result<Box<dyn RuntimeFiberStack>> {
        let page_size = host_page_size();
        // One extra page to account for the guard page in the mapping.
        let needed_size = size + page_size;
        let mut stacks = FIBER_STACKS.lock().unwrap();
        let stack = match stacks.iter().position(|i| needed_size <= i.mapping_len) {
            Some(i) => stacks.remove(i),
            None => MmapFiberStack::new(size)?,
        };
        let stack = AsanFiberStack {
            mmap: ManuallyDrop::new(stack),
        };
        Ok(Box::new(stack))
    }
    /// Pool-aware stack wrapper: on drop the mapping is returned to
    /// `FIBER_STACKS` instead of being unmapped.
    struct AsanFiberStack {
        // `ManuallyDrop` so `MmapFiberStack::drop` (munmap) never runs; see
        // the `Drop` impl below.
        mmap: ManuallyDrop<MmapFiberStack>,
    }
    unsafe impl RuntimeFiberStack for AsanFiberStack {
        fn top(&self) -> *mut u8 {
            self.mmap
                .mapping_base
                .wrapping_byte_add(self.mmap.mapping_len)
        }
        fn range(&self) -> Range<usize> {
            // Skip the guard page at the bottom of the mapping.
            let base = self.mmap.mapping_base as usize;
            let end = base + self.mmap.mapping_len;
            base + host_page_size()..end
        }
        fn guard_range(&self) -> Range<*mut u8> {
            self.mmap.mapping_base..self.mmap.mapping_base.wrapping_add(host_page_size())
        }
    }
    impl Drop for AsanFiberStack {
        fn drop(&mut self) {
            // Move the mapping out without running its destructor and park it
            // in the pool for reuse.
            let stack = unsafe { ManuallyDrop::take(&mut self.mmap) };
            FIBER_STACKS.lock().unwrap().push(stack);
        }
    }
}
#[cfg(not(asan))]
mod asan_disabled {
    //! No-op stand-ins for the AddressSanitizer hooks, used when ASan is not
    //! enabled; keeps the rest of this file free of `cfg(asan)` branches.
    use super::{FiberStack, RuntimeFiberStack};
    use std::boxed::Box;

    /// Zero-sized placeholder for the stack bookkeeping ASan would track.
    #[derive(Default)]
    pub struct PreviousStack;

    impl PreviousStack {
        /// Nothing to record without the sanitizer; the stack is ignored.
        #[inline]
        pub fn new(_stack: &FiberStack) -> PreviousStack {
            PreviousStack
        }
    }

    /// Switches directly to the fiber whose stack top is `stack_top`, with no
    /// sanitizer notification.
    pub unsafe fn fiber_switch(stack_top: *mut u8, _is_finishing: bool, _prev: &mut PreviousStack) {
        assert!(super::SUPPORTED_ARCH);
        unsafe { super::wasmtime_fiber_switch(stack_top) }
    }

    /// Nothing to finish; report the empty placeholder.
    #[inline]
    pub unsafe fn fiber_start_complete() -> PreviousStack {
        PreviousStack
    }

    /// Sanitizer-pooled stacks are only meaningful in the ASan build.
    pub fn new_fiber_stack(_size: usize) -> std::io::Result<Box<dyn RuntimeFiberStack>> {
        unimplemented!()
    }
}
#[cfg(not(asan))]
use asan_disabled as asan;