use std::cell::UnsafeCell;
use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU8;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use intrusive_collections::intrusive_adapter;
use intrusive_collections::linked_list::LinkedList;
use intrusive_collections::linked_list::LinkedListOps;
use intrusive_collections::DefaultLinkOps;
use intrusive_collections::LinkOps;
use super::super::sync::SpinLock;
/// An intrusive doubly-linked-list link whose membership flag is atomic.
///
/// `is_linked` can be read through a shared reference. The `prev`/`next` pointers live in
/// `UnsafeCell`s and, within this file, are only read or written by [`AtomicLinkOps`].
///
/// NOTE(review): the 128-byte alignment is presumably there to keep each link on its own
/// cache line(s) and avoid false sharing — TODO confirm.
#[repr(align(128))]
pub struct AtomicLink {
    /// Pointer to the preceding link, if any.
    prev: UnsafeCell<Option<NonNull<AtomicLink>>>,
    /// Pointer to the following link, if any.
    next: UnsafeCell<Option<NonNull<AtomicLink>>>,
    /// `true` while the link is claimed by a list (set by `AtomicLinkOps::acquire_link`,
    /// cleared by `AtomicLinkOps::release_link`).
    linked: AtomicBool,
}

impl AtomicLink {
    /// Returns a link that is not claimed by any list and has no neighbours.
    fn new() -> AtomicLink {
        Self {
            prev: UnsafeCell::new(None),
            next: UnsafeCell::new(None),
            linked: AtomicBool::new(false),
        }
    }

    /// Reports whether the link is currently claimed by a list.
    ///
    /// This is a relaxed load: the answer is a snapshot and establishes no ordering with
    /// other memory operations.
    fn is_linked(&self) -> bool {
        let flag = &self.linked;
        flag.load(Ordering::Relaxed)
    }
}
/// Associates [`AtomicLinkOps`] with `AtomicLink`, so intrusive adapters over an
/// `AtomicLink` field obtain their link operations without extra configuration.
impl DefaultLinkOps for AtomicLink {
    const NEW: Self::Ops = AtomicLinkOps;

    type Ops = AtomicLinkOps;
}
// SAFETY: `AtomicLink` is not automatically `Send`/`Sync` only because it stores raw
// `NonNull` pointers inside `UnsafeCell`s. Within this file those cells are read and
// written only through `AtomicLinkOps`. `acquire_link` atomically swaps the `linked` flag,
// so a link can be claimed by at most one list at a time.
// NOTE(review): soundness also relies on each `WaiterList` being mutated under external
// synchronization (presumably a lock held by the primitive that owns the list) — confirm
// at the call sites.
unsafe impl Send for AtomicLink {}
unsafe impl Sync for AtomicLink {}
/// Stateless, zero-sized handle implementing the `intrusive_collections` link operations
/// for [`AtomicLink`].
#[derive(Default, Clone, Copy)]
pub struct AtomicLinkOps;
// Claim/release operations on an `AtomicLink`'s `linked` flag, as required by
// `intrusive_collections`. The clippy `undocumented_unsafe_blocks` lint is silenced for
// this impl instead of attaching a `// SAFETY:` comment to each unsafe item.
// NOTE(review): consider replacing the `allow` with per-item `// SAFETY:` comments.
#[allow(clippy::undocumented_unsafe_blocks)]
unsafe impl LinkOps for AtomicLinkOps {
    /// Links are addressed by raw non-null pointers to the `AtomicLink` field.
    type LinkPtr = NonNull<AtomicLink>;

    /// Tries to claim the link for a list.
    ///
    /// Atomically sets `linked` to `true` and returns `true` only if it was previously
    /// `false`. An already-claimed link therefore cannot be claimed a second time. The
    /// `Acquire` ordering presumably pairs with the `Release` in `release_link`, so a new
    /// owner observes pointer writes made by the previous one.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live `AtomicLink` (it is dereferenced via `as_ref`).
    unsafe fn acquire_link(&mut self, ptr: Self::LinkPtr) -> bool {
        !ptr.as_ref().linked.swap(true, Ordering::Acquire)
    }

    /// Marks the link as no longer claimed by any list (`Release` store of `false`).
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live `AtomicLink` (it is dereferenced via `as_ref`).
    unsafe fn release_link(&mut self, ptr: Self::LinkPtr) {
        ptr.as_ref().linked.store(false, Ordering::Release)
    }
}
// Raw accessors for the `prev`/`next` pointers stored in an `AtomicLink`'s `UnsafeCell`s,
// as required by `intrusive_collections::linked_list`. These accessors perform no
// synchronization of their own. The clippy `undocumented_unsafe_blocks` lint is silenced
// for this impl instead of attaching a `// SAFETY:` comment to each unsafe item.
// NOTE(review): callers must have exclusive access to the list while mutating links;
// presumably guaranteed by a lock around each `WaiterList` — verify at the call sites.
#[allow(clippy::undocumented_unsafe_blocks)]
unsafe impl LinkedListOps for AtomicLinkOps {
    /// Reads the `next` pointer stored in the link at `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live `AtomicLink` with no concurrent writer of `next`.
    unsafe fn next(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> {
        *ptr.as_ref().next.get()
    }

    /// Reads the `prev` pointer stored in the link at `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live `AtomicLink` with no concurrent writer of `prev`.
    unsafe fn prev(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> {
        *ptr.as_ref().prev.get()
    }

    /// Overwrites the `next` pointer of the link at `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live `AtomicLink` with no concurrent reader or writer of `next`.
    unsafe fn set_next(&mut self, ptr: Self::LinkPtr, next: Option<Self::LinkPtr>) {
        *ptr.as_ref().next.get() = next;
    }

    /// Overwrites the `prev` pointer of the link at `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live `AtomicLink` with no concurrent reader or writer of `prev`.
    unsafe fn set_prev(&mut self, ptr: Self::LinkPtr, prev: Option<Self::LinkPtr>) {
        *ptr.as_ref().prev.get() = prev;
    }
}
/// Kind of access requested by a [`Waiter`] (reported by `Waiter::kind`).
///
/// Derives `Debug`/`PartialEq`/`Eq` like `WaitingFor`, so it can be compared and used in
/// assertions and diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Kind {
    /// Presumably shared (reader-style) access — confirm against the lock implementation.
    Shared,
    /// Presumably exclusive (writer-style) access — confirm against the lock implementation.
    Exclusive,
}
/// Wake-up state of a [`Waiter`], stored behind the waiter's `SpinLock`.
///
/// Typical progression: `Init` → `Waiting` (first poll) → `Woken` (`Waiter::wake`) →
/// `Finished` (the future returned `Ready`). `Waiter::reset` returns the state to `Init`.
enum State {
    /// Transient placeholder that `poll` writes while it moves the previous state out.
    /// Any other observer finding it treats it as a bug.
    Processing,
    /// Freshly created or reset; no waker registered yet.
    Init,
    /// A `WaitFuture` was polled and registered this waker.
    Waiting(Waker),
    /// `wake` was called; the next poll completes.
    Woken,
    /// The `WaitFuture` already returned `Poll::Ready(())`.
    Finished,
}
/// Which primitive, if any, a [`Waiter`] is currently waiting on.
///
/// A `Waiter` stores this as its `u8` discriminant. The explicit values below must stay in
/// sync with the decoding in `Waiter::is_waiting_for`.
#[derive(Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum WaitingFor {
    /// Not waiting on a `Mutex` or a `Condvar`.
    None = 0,
    /// Waiting on a mutex.
    Mutex = 1,
    /// Waiting on a condition variable.
    Condvar = 2,
}
/// A waiter that can be queued in a [`WaiterList`] and awaited via [`Waiter::wait`].
///
/// It combines an intrusive list link, a lock-protected wake-up state shared with its
/// [`WaitFuture`], and a cancellation callback that runs when that future is dropped before
/// completing.
pub struct Waiter {
    /// Access kind reported by `Waiter::kind`.
    kind: Kind,
    /// Encoded `WaitingFor` discriminant (see `Waiter::is_waiting_for`).
    waiting_for: AtomicU8,
    /// Wake-up state, shared between `wake`/`reset` and the `WaitFuture`.
    state: SpinLock<State>,
    /// Invoked as `cancel(cancel_data, waiter, woken)` when a `WaitFuture` is dropped
    /// without having returned `Ready`.
    cancel: fn(usize, &Waiter, bool),
    /// Opaque value passed back as the first argument of `cancel`.
    cancel_data: usize,
    /// Intrusive link used to place the waiter in a `WaiterList`.
    link: AtomicLink,
}
impl Waiter {
    /// Creates an unlinked `Waiter` in its initial state.
    ///
    /// * `kind` - value later reported by [`Waiter::kind`].
    /// * `cancel` - callback run when a [`WaitFuture`] for this waiter is dropped before
    ///   returning `Ready`.
    /// * `cancel_data` - opaque value passed back as `cancel`'s first argument.
    /// * `waiting_for` - initial value reported by [`Waiter::is_waiting_for`].
    pub fn new(
        kind: Kind,
        cancel: fn(usize, &Waiter, bool),
        cancel_data: usize,
        waiting_for: WaitingFor,
    ) -> Waiter {
        let encoded = AtomicU8::new(waiting_for as u8);
        Waiter {
            kind,
            cancel,
            cancel_data,
            waiting_for: encoded,
            link: AtomicLink::new(),
            state: SpinLock::new(State::Init),
        }
    }

    /// Returns the access kind this waiter was created with.
    pub fn kind(&self) -> Kind {
        self.kind
    }

    /// Reports whether the waiter is currently claimed by a `WaiterList`.
    pub fn is_linked(&self) -> bool {
        let link = &self.link;
        link.is_linked()
    }

    /// Returns which primitive the waiter is waiting on (`Acquire` load).
    ///
    /// # Panics
    ///
    /// Panics if the stored byte is not a valid `WaitingFor` discriminant.
    pub fn is_waiting_for(&self) -> WaitingFor {
        let raw = self.waiting_for.load(Ordering::Acquire);
        match raw {
            0 => WaitingFor::None,
            1 => WaitingFor::Mutex,
            2 => WaitingFor::Condvar,
            v => panic!("Unknown value for `WaitingFor`: {v}"),
        }
    }

    /// Records which primitive the waiter is waiting on (`Release` store).
    pub fn set_waiting_for(&self, waiting_for: WaitingFor) {
        let encoded = waiting_for as u8;
        self.waiting_for.store(encoded, Ordering::Release);
    }

    /// Returns the waiter to its initial state so it can be reused.
    ///
    /// Stores `waiting_for` and sets the state back to `Init`. Any waker registered by a
    /// previous poll is dropped after the state lock has been released. In debug builds,
    /// asserts that the waiter is not linked.
    pub fn reset(&self, waiting_for: WaitingFor) {
        debug_assert!(!self.is_linked(), "Cannot reset `Waiter` while linked");
        self.set_waiting_for(waiting_for);

        // Swap in `Init` under the lock. The previous state (and any waker it owns) is
        // dropped only after the guard has gone out of scope.
        let previous = {
            let mut guard = self.state.lock();
            mem::replace(&mut *guard, State::Init)
        };
        mem::drop(previous);
    }

    /// Returns a future that resolves once the waiter has been woken.
    pub fn wait(&self) -> WaitFuture<'_> {
        WaitFuture { waiter: self }
    }

    /// Marks the waiter as woken, regardless of its previous state.
    ///
    /// If a future had registered a waker, that waker is woken after the state lock has been
    /// released. In debug builds, asserts that the waiter is unlinked and that
    /// `is_waiting_for()` is `WaitingFor::None`.
    pub fn wake(&self) {
        debug_assert!(!self.is_linked(), "Cannot wake `Waiter` while linked");
        debug_assert_eq!(self.is_waiting_for(), WaitingFor::None);

        // The temporary lock guard is dropped at the end of this statement.
        let previous = mem::replace(&mut *self.state.lock(), State::Woken);
        if let State::Waiting(waker) = previous {
            waker.wake();
        }
    }
}
/// Future returned by [`Waiter::wait`]; resolves to `()` once the waiter has been woken.
///
/// Dropping it before it resolves invokes the waiter's `cancel` callback (see the `Drop`
/// impl for `WaitFuture`). Merely creating and discarding it is therefore not a no-op,
/// which is why it is `#[must_use]`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WaitFuture<'w> {
    /// The waiter whose state this future observes.
    waiter: &'w Waiter,
}
impl Future for WaitFuture<'_> {
    type Output = ();

    /// Polls the waiter.
    ///
    /// * `Init` / `Waiting`: registers (or replaces) the task's waker and returns `Pending`.
    /// * `Woken`: moves to `Finished` and returns `Ready(())`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Ready`, or if the state is found to be
    /// `Processing`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut state = self.waiter.state.lock();

        // `Processing` is a placeholder that lets us move the previous state out of the lock.
        // Every arm below either overwrites it or panics.
        match mem::replace(&mut *state, State::Processing) {
            State::Init => {
                *state = State::Waiting(cx.waker().clone());
                Poll::Pending
            }
            State::Waiting(old_waker) => {
                *state = State::Waiting(cx.waker().clone());
                // Release the lock first so the old waker's drop runs outside it.
                mem::drop(state);
                mem::drop(old_waker);
                Poll::Pending
            }
            State::Woken => {
                *state = State::Finished;
                Poll::Ready(())
            }
            State::Finished => {
                // Restore `Finished` before panicking. If the state were left as `Processing`
                // and this future were dropped while the panic unwinds, `Drop` would panic a
                // second time and abort the process.
                *state = State::Finished;
                panic!("Future polled after returning Poll::Ready");
            }
            State::Processing => {
                panic!("Unexpected waker state");
            }
        }
    }
}
impl Drop for WaitFuture<'_> {
    /// Runs the waiter's `cancel` callback unless the wait already completed.
    ///
    /// The callback receives `true` when the waiter was woken but `poll` never observed it
    /// (state `Woken`), and `false` when it was never woken (`Init` or `Waiting`). Nothing
    /// happens if the future already returned `Ready` (`Finished`).
    ///
    /// # Panics
    ///
    /// Panics if the state is `Processing`.
    fn drop(&mut self) {
        let state = self.waiter.state.lock();

        // Exhaustive on purpose: adding a `State` variant must force a decision here.
        let woken = match *state {
            State::Finished => return,
            State::Processing => panic!("Unexpected waker state"),
            State::Woken => true,
            State::Init | State::Waiting(_) => false,
        };

        // Release the lock before calling out, since the callback receives `&Waiter`.
        mem::drop(state);
        (self.waiter.cancel)(self.waiter.cancel_data, self.waiter, woken);
    }
}
// Adapter that lets `Arc<Waiter>` values live in an intrusive `LinkedList`, threaded through
// each waiter's `link: AtomicLink` field.
intrusive_adapter!(pub WaiterAdapter = Arc<Waiter>: Waiter { link: AtomicLink });
/// Intrusive linked list of `Arc<Waiter>`s, linked through each [`Waiter`]'s `link` field.
pub type WaiterList = LinkedList<WaiterAdapter>;