// Source path: crates/polars-stream/src/async_primitives/connector.rs
#![allow(unsafe_op_in_unsafe_fn)]1use std::cell::UnsafeCell;2use std::mem::MaybeUninit;3use std::pin::Pin;4use std::sync::Arc;5use std::sync::atomic::{AtomicU8, Ordering};6use std::task::{Context, Poll, Waker};78use atomic_waker::AtomicWaker;9use pin_project_lite::pin_project;1011pub type Sender<T> = SenderExt<T, ()>;12pub type Receiver<T> = ReceiverExt<T, ()>;1314/// Single-producer, single-consumer capacity-one channel.15pub fn connector<T>() -> (Sender<T>, Receiver<T>) {16let connector = Arc::new(Connector::new(()));17(18Sender {19connector: connector.clone(),20},21Receiver { connector },22)23}2425/// Single-producer, single-consumer capacity-one channel, with a shared common26/// value.27pub fn connector_with<T, S>(shared: S) -> (SenderExt<T, S>, ReceiverExt<T, S>) {28let connector = Arc::new(Connector::new(shared));29(30SenderExt {31connector: connector.clone(),32},33ReceiverExt { connector },34)35}3637/*38For UnsafeCell safety, a sender may only set the FULL_BIT (giving exclusive39access to value to the receiver), and a receiver may only unset the FULL_BIT40(giving exclusive access back to the sender). 
Setting/clearing the FULL_BIT41must be done with a Release ordering, and before reading/writing the value42the FULL_BIT must be checked with an Acquire ordering.4344The exception is when the closed bit is set, at that point the unclosed45end has full exclusive access.46*/4748const FULL_BIT: u8 = 0b1;49const CLOSED_BIT: u8 = 0b10;50const WAITING_BIT: u8 = 0b100;5152#[repr(align(128))]53struct Connector<T, S> {54send_waker: AtomicWaker,55recv_waker: AtomicWaker,56value: UnsafeCell<MaybeUninit<T>>,57state: AtomicU8,58shared: S,59}6061impl<T, S> Connector<T, S> {62fn new(shared: S) -> Self {63Self {64send_waker: AtomicWaker::new(),65recv_waker: AtomicWaker::new(),66value: UnsafeCell::new(MaybeUninit::uninit()),67state: AtomicU8::new(0),68shared,69}70}71}7273pub enum SendError<T> {74Full(T),75Closed(T),76}7778pub enum RecvError {79Empty,80Closed,81}8283// SAFETY: all the send methods may only be called from a single sender at a84// time, and similarly for all the recv methods from a single receiver.85impl<T, S> Connector<T, S> {86unsafe fn poll_send(&self, value: &mut Option<T>, waker: &Waker) -> Poll<Result<(), T>> {87if let Some(v) = value.take() {88let mut state = self.state.load(Ordering::Acquire);89if state & FULL_BIT == FULL_BIT {90self.send_waker.register(waker);91let (Ok(s) | Err(s)) = self.state.compare_exchange(92state,93state | WAITING_BIT,94Ordering::Relaxed,95Ordering::Acquire, // Receiver updated, re-acquire.96);97state = s;98}99100match self.try_send_impl(v, state) {101Ok(()) => {},102Err(SendError::Closed(v)) => return Poll::Ready(Err(v)),103Err(SendError::Full(v)) => {104*value = Some(v);105return Poll::Pending;106},107}108}109110Poll::Ready(Ok(()))111}112113unsafe fn try_send_impl(&self, value: T, state: u8) -> Result<(), SendError<T>> {114if state & CLOSED_BIT == CLOSED_BIT {115return Err(SendError::Closed(value));116}117if state & FULL_BIT == FULL_BIT {118return Err(SendError::Full(value));119}120121unsafe 
{122self.value.get().write(MaybeUninit::new(value));123let state = self.state.swap(FULL_BIT, Ordering::Release);124if state & WAITING_BIT == WAITING_BIT {125self.recv_waker.wake();126}127if state & CLOSED_BIT == CLOSED_BIT {128// SAFETY: no synchronization needed, we are the only one left.129// Restore the closed bit we just overwrote.130self.state.store(CLOSED_BIT, Ordering::Relaxed);131return Err(SendError::Closed(self.value.get().read().assume_init()));132}133}134135Ok(())136}137138unsafe fn poll_recv(&self, waker: &Waker) -> Poll<Result<T, ()>> {139let mut state = self.state.load(Ordering::Acquire);140if state & FULL_BIT == 0 {141self.recv_waker.register(waker);142let (Ok(s) | Err(s)) = self.state.compare_exchange(143state,144state | WAITING_BIT,145Ordering::Relaxed,146Ordering::Acquire, // Sender updated, re-acquire.147);148state = s;149}150151match self.try_recv_impl(state) {152Ok(v) => Poll::Ready(Ok(v)),153Err(RecvError::Empty) => Poll::Pending,154Err(RecvError::Closed) => Poll::Ready(Err(())),155}156}157158unsafe fn try_recv_impl(&self, state: u8) -> Result<T, RecvError> {159if state & FULL_BIT == FULL_BIT {160unsafe {161let ret = self.value.get().read().assume_init();162let state = self.state.swap(0, Ordering::Release);163if state & WAITING_BIT == WAITING_BIT {164self.send_waker.wake();165}166if state & CLOSED_BIT == CLOSED_BIT {167// Restore the closed bit we just overwrote.168self.state.store(CLOSED_BIT, Ordering::Relaxed);169}170return Ok(ret);171}172}173174// Check closed bit last so we do receive any last element sent before175// closing sender.176if state & CLOSED_BIT == CLOSED_BIT {177return Err(RecvError::Closed);178}179180Err(RecvError::Empty)181}182183unsafe fn try_send(&self, value: T) -> Result<(), SendError<T>> {184self.try_send_impl(value, self.state.load(Ordering::Acquire))185}186187unsafe fn try_recv(&self) -> Result<T, RecvError> {188self.try_recv_impl(self.state.load(Ordering::Acquire))189}190191/// # Safety192/// You may not access this 
connector anymore as a sender after this call.193unsafe fn close_send(&self) {194self.state.fetch_or(CLOSED_BIT, Ordering::Relaxed);195self.recv_waker.wake();196}197198/// # Safety199/// You may not access this connector anymore as a receiver after this call.200unsafe fn close_recv(&self) {201let state = self.state.fetch_or(CLOSED_BIT, Ordering::Acquire);202drop(self.try_recv_impl(state));203self.send_waker.wake();204}205}206207pub struct SenderExt<T, S> {208connector: Arc<Connector<T, S>>,209}210211unsafe impl<T: Send, S: Sync> Send for SenderExt<T, S> {}212213impl<T, S> Drop for SenderExt<T, S> {214fn drop(&mut self) {215unsafe { self.connector.close_send() }216}217}218219pub struct ReceiverExt<T, S> {220connector: Arc<Connector<T, S>>,221}222223unsafe impl<T: Send, S: Sync> Send for ReceiverExt<T, S> {}224225impl<T, S> Drop for ReceiverExt<T, S> {226fn drop(&mut self) {227unsafe { self.connector.close_recv() }228}229}230231pin_project! {232pub struct SendFuture<'a, T, S> {233connector: &'a Connector<T, S>,234value: Option<T>,235}236}237238unsafe impl<T: Send, S: Sync> Send for SendFuture<'_, T, S> {}239240impl<T: Send, S: Sync> SenderExt<T, S> {241/// Returns a future that when awaited will send the value to the [`ReceiverExt`].242/// Returns Err(value) if the connector is closed.243#[must_use]244pub fn send(&mut self, value: T) -> SendFuture<'_, T, S> {245SendFuture {246connector: &self.connector,247value: Some(value),248}249}250251#[allow(unused)]252pub fn try_send(&mut self, value: T) -> Result<(), SendError<T>> {253unsafe { self.connector.try_send(value) }254}255256pub fn shared(&self) -> &S {257&self.connector.shared258}259}260261impl<T, S> std::future::Future for SendFuture<'_, T, S> {262type Output = Result<(), T>;263264fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {265assert!(266self.value.is_some(),267"re-poll after Poll::Ready in connector SendFuture"268);269unsafe { self.connector.poll_send(self.project().value, cx.waker()) 
}270}271}272273pin_project! {274pub struct RecvFuture<'a, T, S> {275connector: &'a Connector<T, S>,276done: bool,277}278}279280unsafe impl<T: Send, S: Sync> Send for RecvFuture<'_, T, S> {}281282impl<T: Send, S: Sync> ReceiverExt<T, S> {283/// Returns a future that when awaited will return `Ok(value)` once the284/// value is received, or returns `Err(())` if the [`SenderExt`] was dropped285/// before sending a value.286#[must_use]287pub fn recv(&mut self) -> RecvFuture<'_, T, S> {288RecvFuture {289connector: &self.connector,290done: false,291}292}293294#[allow(unused)]295pub fn try_recv(&mut self) -> Result<T, RecvError> {296unsafe { self.connector.try_recv() }297}298299pub fn shared(&self) -> &S {300&self.connector.shared301}302}303304impl<T, S> std::future::Future for RecvFuture<'_, T, S> {305type Output = Result<T, ()>;306307fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {308assert!(309!self.done,310"re-poll after Poll::Ready in connector SendFuture"311);312unsafe { self.connector.poll_recv(cx.waker()) }313}314}315316317