// Path: blob/main/crates/polars-stream/src/async_primitives/distributor_channel.rs
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

use crossbeam_utils::CachePadded;
use rand::prelude::*;

use super::task_parker::TaskParker;

/// Single-producer multi-consumer FIFO channel.
///
/// Each [`Receiver`] has an internal buffer of `bufsize`. Thus it is possible
/// that when one [`Sender`] is exhausted some other receivers still have data
/// available.
///
/// The FIFO order is only guaranteed per receiver. That is, each receiver is
/// guaranteed to see a subset of the data sent by the sender in the order the
/// sender sent it in, but not necessarily contiguously.
///
/// When one or more receivers are closed no attempt is made to avoid filling
/// those receivers' buffers. The values in the buffer of a closed receiver are
/// lost forever, they're not redistributed among the others, and simply
/// dropped when the channel is dropped.
///
/// Implementation: one power-of-two-sized ring buffer per receiver. The ring
/// is indexed by a monotonically increasing `write_head`/`read_head` pair that
/// is reduced modulo the capacity via `mask`; occupancy is the wrapping
/// difference of the two heads, compared against `bufsize` (not the rounded-up
/// allocation size), so exactly `bufsize` slots are usable.
// NOTE(review): `bufsize == 0` yields capacity 1 but a usable length of 0, so
// `try_send` can never succeed and `send` would park forever — presumably
// callers always pass a positive bufsize; confirm at call sites.
pub fn distributor_channel<T>(
    num_receivers: usize,
    bufsize: usize,
) -> (Sender<T>, Vec<Receiver<T>>) {
    // Round up so `idx & mask` can replace `idx % capacity`.
    let capacity = bufsize.next_power_of_two();
    let receivers = (0..num_receivers)
        .map(|_| {
            // CachePadded keeps each receiver's hot atomics on their own
            // cache line, avoiding false sharing between consumers.
            CachePadded::new(ReceiverSlot {
                closed: AtomicBool::new(false),
                read_head: AtomicUsize::new(0),
                parker: TaskParker::default(),
                // Slots start uninitialized; only the [read_head, write_head)
                // range ever holds live values.
                data: (0..capacity)
                    .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
                    .collect(),
            })
        })
        .collect();
    let inner = Arc::new(DistributorInner {
        send_closed: AtomicBool::new(false),
        send_parker: TaskParker::default(),
        write_heads: (0..num_receivers).map(|_| AtomicUsize::new(0)).collect(),
        receivers,

        bufsize,
        mask: capacity - 1,
    });

    let receivers = (0..num_receivers)
        .map(|index| Receiver {
            inner: inner.clone(),
            index,
        })
        .collect();

    let sender = Sender {
        inner,
        round_robin_idx: 0,
        // Cheap non-crypto RNG; only used to spread load across receivers.
        rng: SmallRng::from_rng(&mut rand::rng()),
    };

    (sender, receivers)
}

/// Error from a single `try_send` attempt; both variants hand the value back
/// so the sender can retry it on another receiver.
pub enum SendError<T> {
    Full(T),
    Closed(T),
}

/// Error from a single `try_recv` attempt.
pub enum RecvError {
    Empty,
    Closed,
}

/// Per-receiver state. The sender writes into `data` and the owning receiver
/// reads from it; `read_head` is the consumer's progress counter.
struct ReceiverSlot<T> {
    // Set by Receiver::drop; tells the sender this slot will never drain.
    closed: AtomicBool,
    // Monotonic count of values consumed (not reduced by the mask).
    read_head: AtomicUsize,
    // Parked on by the receiver when its ring is empty; unparked by the sender.
    parker: TaskParker,
    // Ring storage, length = bufsize.next_power_of_two(). Interior mutability
    // + MaybeUninit because slots outside [read_head, write_head) are
    // uninitialized and written without a lock.
    data: Box<[UnsafeCell<MaybeUninit<T>>]>,
}

/// State shared between the sender and all receivers via `Arc`.
struct DistributorInner<T> {
    // Set by Sender::drop; receivers return Closed once drained.
    send_closed: AtomicBool,
    // Parked on by the sender when every open receiver is full; unparked by
    // receivers after consuming and by Receiver::drop.
    send_parker: TaskParker,
    // Monotonic per-receiver producer counters. Kept outside ReceiverSlot so
    // the sender's writes don't bounce the receivers' cache lines.
    write_heads: Vec<AtomicUsize>,
    receivers: Vec<CachePadded<ReceiverSlot<T>>>,

    // Usable slots per ring (the capacity bound checked in try_send).
    bufsize: usize,
    // capacity - 1, where capacity is a power of two; used to wrap indices.
    mask: usize,
}

impl<T> DistributorInner<T> {
    // Map a monotonic head counter onto a physical ring index.
    fn reduce_index(&self, idx: usize) -> usize {
        idx & self.mask
    }
}

pub struct Sender<T> {
    inner: Arc<DistributorInner<T>>,
    // Rotates on every send so the fast path visits receivers fairly.
    round_robin_idx: usize,
    rng: SmallRng,
}

pub struct Receiver<T> {
    inner: Arc<DistributorInner<T>>,
    // This receiver's slot in `inner.receivers` / `inner.write_heads`.
    index: usize,
}

// SAFETY: the UnsafeCell ring is the only non-Sync state; the send/recv
// protocol below ensures a slot is accessed by at most one side at a time,
// so handing the endpoints to other threads is sound for T: Send.
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Send for Receiver<T> {}

impl<T: Send> Sender<T> {
    /// Note: This intentionally takes `&mut` to ensure it is only accessed in a single-threaded
    /// manner.
    ///
    /// Sends `value` to some receiver, preferring less-loaded ones; parks
    /// until space frees up. Returns `Err(value)` only when every receiver
    /// has been closed.
    pub async fn send(&mut self, mut value: T) -> Result<(), T> {
        let num_receivers = self.inner.receivers.len();
        loop {
            // Fast-path.
            self.round_robin_idx += 1;
            if self.round_robin_idx >= num_receivers {
                self.round_robin_idx -= num_receivers;
            }

            // Power-of-k-choices: compare the round-robin candidate against 4
            // random candidates and pick the one with the shortest queue.
            let mut hungriest_idx = self.round_robin_idx;
            let mut shortest_len = self.upper_bound_len(self.round_robin_idx);
            for _ in 0..4 {
                // Multiply-shift maps a uniform u32 onto [0, num_receivers)
                // without a modulo (fixed-point scaling).
                let idx = ((self.rng.random::<u32>() as u64 * num_receivers as u64) >> 32) as usize;
                let len = self.upper_bound_len(idx);
                if len < shortest_len {
                    shortest_len = len;
                    hungriest_idx = idx;
                }
            }

            // SAFETY: &mut self guarantees single-threaded access (try_send's
            // contract).
            match unsafe { self.try_send(hungriest_idx, value) } {
                Ok(()) => return Ok(()),
                Err(SendError::Full(v)) => value = v,
                Err(SendError::Closed(v)) => value = v,
            }

            // Do one proper search before parking.
            // The park token must be acquired *before* the sweep: if a
            // receiver frees space and unparks after our failed attempt, the
            // token is already armed and park.await returns immediately.
            let park = self.inner.send_parker.park();

            // Try all receivers, starting at a random index.
            let mut idx = ((self.rng.random::<u32>() as u64 * num_receivers as u64) >> 32) as usize;
            let mut all_closed = true;
            for _ in 0..num_receivers {
                // SAFETY: &mut self guarantees single-threaded access.
                match unsafe { self.try_send(idx, value) } {
                    Ok(()) => return Ok(()),
                    Err(SendError::Full(v)) => {
                        // At least one open receiver exists; it may drain.
                        all_closed = false;
                        value = v;
                    },
                    Err(SendError::Closed(v)) => value = v,
                }

                idx += 1;
                if idx >= num_receivers {
                    idx -= num_receivers;
                }
            }

            if all_closed {
                // No receiver will ever drain; hand the value back.
                return Err(value);
            }

            park.await;
        }
    }

    // Returns the upper bound on the length of the queue of the given receiver.
    // It is an upper bound because racy reads can reduce it in the meantime.
    fn upper_bound_len(&self, recv_idx: usize) -> usize {
        let read_head = self.inner.receivers[recv_idx]
            .read_head
            .load(Ordering::SeqCst);
        // Relaxed is fine: only this (single) sender ever writes write_heads.
        let write_head = self.inner.write_heads[recv_idx].load(Ordering::Relaxed);
        write_head.wrapping_sub(read_head)
    }

    /// # Safety
    /// May only be called from one thread at a time.
    unsafe fn try_send(&self, recv_idx: usize, value: T) -> Result<(), SendError<T>> {
        let read_head = self.inner.receivers[recv_idx]
            .read_head
            .load(Ordering::SeqCst);
        let write_head = self.inner.write_heads[recv_idx].load(Ordering::Relaxed);
        let len = write_head.wrapping_sub(read_head);
        if len < self.inner.bufsize {
            let idx = self.inner.reduce_index(write_head);
            // SAFETY: len < bufsize means slot `idx` is outside the live
            // [read_head, write_head) range, so the receiver is not reading
            // it; the caller guarantees no concurrent sender. The SeqCst
            // store publishes the value before the receiver can observe the
            // advanced write_head.
            unsafe {
                self.inner.receivers[recv_idx].data[idx]
                    .get()
                    .write(MaybeUninit::new(value));
                self.inner.write_heads[recv_idx]
                    .store(write_head.wrapping_add(1), Ordering::SeqCst);
            }
            // Wake the receiver in case it parked on an empty ring.
            self.inner.receivers[recv_idx].parker.unpark();
            Ok(())
        } else if self.inner.receivers[recv_idx].closed.load(Ordering::SeqCst) {
            Err(SendError::Closed(value))
        } else {
            Err(SendError::Full(value))
        }
    }
}

impl<T: Send> Receiver<T> {
    /// Note: This intentionally takes `&mut` to ensure it is only accessed in a single-threaded
    /// manner.
    ///
    /// Receives the next value for this receiver, parking while its ring is
    /// empty. Returns `Err(())` once the sender is dropped and the ring is
    /// fully drained.
    pub async fn recv(&mut self) -> Result<T, ()> {
        loop {
            // Fast-path.
            // SAFETY: &mut self guarantees single-threaded access (try_recv's
            // contract).
            match unsafe { self.try_recv() } {
                Ok(v) => return Ok(v),
                Err(RecvError::Closed) => return Err(()),
                Err(RecvError::Empty) => {},
            }

            // Try again, threatening to park if there's still nothing.
            // Acquiring the token first closes the race where the sender
            // pushes + unparks between our check and the park.
            let park = self.inner.receivers[self.index].parker.park();
            // SAFETY: as above.
            match unsafe { self.try_recv() } {
                Ok(v) => return Ok(v),
                Err(RecvError::Closed) => return Err(()),
                Err(RecvError::Empty) => {},
            }
            park.await;
        }
    }

    /// # Safety
    /// May only be called from one thread at a time.
    unsafe fn try_recv(&self) -> Result<T, RecvError> {
        loop {
            // Relaxed is fine for read_head: only this (single) receiver
            // writes it.
            let read_head = self.inner.receivers[self.index]
                .read_head
                .load(Ordering::Relaxed);
            let write_head = self.inner.write_heads[self.index].load(Ordering::SeqCst);
            if read_head != write_head {
                let idx = self.inner.reduce_index(read_head);
                let read;
                // SAFETY: read_head != write_head means slot `idx` holds an
                // initialized value published by the sender's SeqCst store;
                // advancing read_head afterwards releases the slot back to
                // the sender. The caller guarantees no concurrent receiver on
                // this slot.
                unsafe {
                    let ptr = self.inner.receivers[self.index].data[idx].get();
                    read = ptr.read().assume_init();
                    self.inner.receivers[self.index]
                        .read_head
                        .store(read_head.wrapping_add(1), Ordering::SeqCst);
                }
                // A slot just freed up; wake the sender if it was parked.
                self.inner.send_parker.unpark();
                return Ok(read);
            } else if self.inner.send_closed.load(Ordering::SeqCst) {
                // Check write head again, sender could've sent something right
                // before closing. We can do this relaxed because we'll read it
                // again in the next iteration with SeqCst if it's a new value.
                if write_head == self.inner.write_heads[self.index].load(Ordering::Relaxed) {
                    return Err(RecvError::Closed);
                }
            } else {
                return Err(RecvError::Empty);
            }
        }
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        // Mark closed first, then wake every receiver so none stays parked
        // waiting for data that will never come.
        self.inner.send_closed.store(true, Ordering::SeqCst);
        for recv in &self.inner.receivers {
            recv.parker.unpark();
        }
    }
}

impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        // Mark this slot closed, then wake the sender so it can stop trying
        // to fill it (and detect the all-closed case).
        self.inner.receivers[self.index]
            .closed
            .store(true, Ordering::SeqCst);
        self.inner.send_parker.unpark();
    }
}

impl<T> Drop for DistributorInner<T> {
    // Drop any values still sitting in the rings (e.g. in closed receivers'
    // buffers) once the last endpoint releases the Arc.
    fn drop(&mut self) {
        for r in 0..self.receivers.len() {
            // We have exclusive access, so we only need to atomically load once.
            let write_head = self.write_heads[r].load(Ordering::SeqCst);
            let mut read_head = self.receivers[r].read_head.load(Ordering::Relaxed);
            while read_head != write_head {
                let idx = self.reduce_index(read_head);
                // SAFETY: &mut self gives exclusive access, and slots in
                // [read_head, write_head) are initialized and unconsumed.
                unsafe {
                    (*self.receivers[r].data[idx].get()).assume_init_drop();
                }
                read_head = read_head.wrapping_add(1);
            }
        }
    }
}