// Path: blob/main/crates/polars-stream/src/async_primitives/connector.rs
// 6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use std::cell::UnsafeCell;2use std::mem::MaybeUninit;3use std::pin::Pin;4use std::sync::Arc;5use std::sync::atomic::{AtomicU8, Ordering};6use std::task::{Context, Poll, Waker};78use atomic_waker::AtomicWaker;9use pin_project_lite::pin_project;1011/// Single-producer, single-consumer capacity-one channel.12pub fn connector<T>() -> (Sender<T>, Receiver<T>) {13let connector = Arc::new(Connector::default());14(15Sender {16connector: connector.clone(),17},18Receiver { connector },19)20}2122/*23For UnsafeCell safety, a sender may only set the FULL_BIT (giving exclusive24access to value to the receiver), and a receiver may only unset the FULL_BIT25(giving exclusive access back to the sender). Setting/clearing the FULL_BIT26must be done with a Release ordering, and before reading/writing the value27the FULL_BIT must be checked with an Acquire ordering.2829The exception is when the closed bit is set, at that point the unclosed30end has full exclusive access.31*/3233const FULL_BIT: u8 = 0b1;34const CLOSED_BIT: u8 = 0b10;35const WAITING_BIT: u8 = 0b100;3637#[repr(align(64))]38struct Connector<T> {39send_waker: AtomicWaker,40recv_waker: AtomicWaker,41value: UnsafeCell<MaybeUninit<T>>,42state: AtomicU8,43}4445impl<T> Default for Connector<T> {46fn default() -> Self {47Self {48send_waker: AtomicWaker::new(),49recv_waker: AtomicWaker::new(),50value: UnsafeCell::new(MaybeUninit::uninit()),51state: AtomicU8::new(0),52}53}54}5556pub enum SendError<T> {57Full(T),58Closed(T),59}6061pub enum RecvError {62Empty,63Closed,64}6566// SAFETY: all the send methods may only be called from a single sender at a67// time, and similarly for all the recv methods from a single receiver.68impl<T> Connector<T> {69unsafe fn poll_send(&self, value: &mut Option<T>, waker: &Waker) -> Poll<Result<(), T>> {70if let Some(v) = value.take() {71let mut state = self.state.load(Ordering::Acquire);72if state & FULL_BIT == FULL_BIT {73self.send_waker.register(waker);74let (Ok(s) | 
Err(s)) = self.state.compare_exchange(75state,76state | WAITING_BIT,77Ordering::Relaxed,78Ordering::Acquire, // Receiver updated, re-acquire.79);80state = s;81}8283match self.try_send_impl(v, state) {84Ok(()) => {},85Err(SendError::Closed(v)) => return Poll::Ready(Err(v)),86Err(SendError::Full(v)) => {87*value = Some(v);88return Poll::Pending;89},90}91}9293Poll::Ready(Ok(()))94}9596unsafe fn try_send_impl(&self, value: T, state: u8) -> Result<(), SendError<T>> {97if state & CLOSED_BIT == CLOSED_BIT {98return Err(SendError::Closed(value));99}100if state & FULL_BIT == FULL_BIT {101return Err(SendError::Full(value));102}103104unsafe {105self.value.get().write(MaybeUninit::new(value));106let state = self.state.swap(FULL_BIT, Ordering::Release);107if state & WAITING_BIT == WAITING_BIT {108self.recv_waker.wake();109}110if state & CLOSED_BIT == CLOSED_BIT {111// SAFETY: no synchronization needed, we are the only one left.112// Restore the closed bit we just overwrote.113self.state.store(CLOSED_BIT, Ordering::Relaxed);114return Err(SendError::Closed(self.value.get().read().assume_init()));115}116}117118Ok(())119}120121unsafe fn poll_recv(&self, waker: &Waker) -> Poll<Result<T, ()>> {122let mut state = self.state.load(Ordering::Acquire);123if state & FULL_BIT == 0 {124self.recv_waker.register(waker);125let (Ok(s) | Err(s)) = self.state.compare_exchange(126state,127state | WAITING_BIT,128Ordering::Relaxed,129Ordering::Acquire, // Sender updated, re-acquire.130);131state = s;132}133134match self.try_recv_impl(state) {135Ok(v) => Poll::Ready(Ok(v)),136Err(RecvError::Empty) => Poll::Pending,137Err(RecvError::Closed) => Poll::Ready(Err(())),138}139}140141unsafe fn try_recv_impl(&self, state: u8) -> Result<T, RecvError> {142if state & FULL_BIT == FULL_BIT {143unsafe {144let ret = self.value.get().read().assume_init();145let state = self.state.swap(0, Ordering::Release);146if state & WAITING_BIT == WAITING_BIT {147self.send_waker.wake();148}149if state & CLOSED_BIT == CLOSED_BIT 
{150// Restore the closed bit we just overwrote.151self.state.store(CLOSED_BIT, Ordering::Relaxed);152}153return Ok(ret);154}155}156157// Check closed bit last so we do receive any last element sent before158// closing sender.159if state & CLOSED_BIT == CLOSED_BIT {160return Err(RecvError::Closed);161}162163Err(RecvError::Empty)164}165166unsafe fn try_send(&self, value: T) -> Result<(), SendError<T>> {167self.try_send_impl(value, self.state.load(Ordering::Acquire))168}169170unsafe fn try_recv(&self) -> Result<T, RecvError> {171self.try_recv_impl(self.state.load(Ordering::Acquire))172}173174/// # Safety175/// You may not access this connector anymore as a sender after this call.176unsafe fn close_send(&self) {177self.state.fetch_or(CLOSED_BIT, Ordering::Relaxed);178self.recv_waker.wake();179}180181/// # Safety182/// You may not access this connector anymore as a receiver after this call.183unsafe fn close_recv(&self) {184let state = self.state.fetch_or(CLOSED_BIT, Ordering::Acquire);185drop(self.try_recv_impl(state));186self.send_waker.wake();187}188}189190pub struct Sender<T> {191connector: Arc<Connector<T>>,192}193194unsafe impl<T: Send> Send for Sender<T> {}195196impl<T> Drop for Sender<T> {197fn drop(&mut self) {198unsafe { self.connector.close_send() }199}200}201202pub struct Receiver<T> {203connector: Arc<Connector<T>>,204}205206unsafe impl<T: Send> Send for Receiver<T> {}207208impl<T> Drop for Receiver<T> {209fn drop(&mut self) {210unsafe { self.connector.close_recv() }211}212}213214pin_project! 
{215pub struct SendFuture<'a, T> {216connector: &'a Connector<T>,217value: Option<T>,218}219}220221unsafe impl<T: Send> Send for SendFuture<'_, T> {}222223impl<T: Send> Sender<T> {224/// Returns a future that when awaited will send the value to the [`Receiver`].225/// Returns Err(value) if the connector is closed.226#[must_use]227pub fn send(&mut self, value: T) -> SendFuture<'_, T> {228SendFuture {229connector: &self.connector,230value: Some(value),231}232}233234#[allow(unused)]235pub fn try_send(&mut self, value: T) -> Result<(), SendError<T>> {236unsafe { self.connector.try_send(value) }237}238}239240impl<T> std::future::Future for SendFuture<'_, T> {241type Output = Result<(), T>;242243fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {244assert!(245self.value.is_some(),246"re-poll after Poll::Ready in connector SendFuture"247);248unsafe { self.connector.poll_send(self.project().value, cx.waker()) }249}250}251252pin_project! {253pub struct RecvFuture<'a, T> {254connector: &'a Connector<T>,255done: bool,256}257}258259unsafe impl<T: Send> Send for RecvFuture<'_, T> {}260261impl<T: Send> Receiver<T> {262/// Returns a future that when awaited will return `Ok(value)` once the263/// value is received, or returns `Err(())` if the [`Sender`] was dropped264/// before sending a value.265#[must_use]266pub fn recv(&mut self) -> RecvFuture<'_, T> {267RecvFuture {268connector: &self.connector,269done: false,270}271}272273#[allow(unused)]274pub fn try_recv(&mut self) -> Result<T, RecvError> {275unsafe { self.connector.try_recv() }276}277}278279impl<T> std::future::Future for RecvFuture<'_, T> {280type Output = Result<T, ()>;281282fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {283assert!(284!self.done,285"re-poll after Poll::Ready in connector SendFuture"286);287unsafe { self.connector.poll_recv(cx.waker()) }288}289}290291292