Path: blob/main/crates/polars-stream/src/async_primitives/linearizer.rs
6939 views
use std::collections::BinaryHeap;12use tokio::sync::mpsc::{Receiver, Sender, channel};34/// Stores the state for which inserter we need to poll.5enum PollState {6NoPoll,7Poll(usize),8PollAll,9}1011struct LinearedItem<T> {12value: T,13sender_id: usize,14}1516impl<T: Ord> PartialEq for LinearedItem<T> {17fn eq(&self, other: &Self) -> bool {18self.value.eq(&other.value)19}20}21impl<T: Ord> Eq for LinearedItem<T> {}22impl<T: Ord> PartialOrd for LinearedItem<T> {23#[allow(clippy::non_canonical_partial_ord_impl)]24fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {25Some(self.value.cmp(&other.value))26}27}28impl<T: Ord> Ord for LinearedItem<T> {29fn cmp(&self, other: &Self) -> std::cmp::Ordering {30self.value.cmp(&other.value)31}32}3334/// Utility to convert the input of `N` senders of ordered data into `1` stream of ordered data.35pub struct Linearizer<T> {36receivers: Vec<Receiver<T>>,37poll_state: PollState,3839heap: BinaryHeap<LinearedItem<T>>,40}4142impl<T: Ord> Linearizer<T> {43pub fn new(num_inserters: usize, buffer_size: usize) -> (Self, Vec<Inserter<T>>) {44let mut receivers = Vec::with_capacity(num_inserters);45let mut inserters = Vec::with_capacity(num_inserters);4647for _ in 0..num_inserters {48// We could perhaps use a bespoke spsc bounded channel here in the49// future, instead of tokio's mpsc channel.50let (sender, receiver) = channel(buffer_size);51receivers.push(receiver);52inserters.push(Inserter { sender });53}54let slf = Self {55receivers,56poll_state: PollState::PollAll,57heap: BinaryHeap::with_capacity(num_inserters),58};59(slf, inserters)60}6162pub fn new_with_maintain_order(63num_inserters: usize,64buffer_size: usize,65maintain_order: bool,66) -> (Self, Vec<Inserter<T>>) {67if maintain_order {68return Self::new(num_inserters, buffer_size);69}7071let (sender, receiver) = channel(buffer_size * num_inserters);72let receivers = vec![receiver];73let inserters = (0..num_inserters)74.map(|_| Inserter {75sender: sender.clone(),76})77.collect();7879let slf = Self {80receivers,81poll_state: PollState::PollAll,82heap: BinaryHeap::with_capacity(1),83};84(slf, inserters)85}8687/// Fetch the next ordered item produced by senders.88///89/// This may wait for at each sender to have sent at least one value before the [`Linearizer`]90/// starts producing.91///92/// If all senders have closed their channels and there are no more buffered values, this93/// returns `None`.94pub async fn get(&mut self) -> Option<T> {95// The idea is that we have exactly one value per inserter in the96// binary heap, and when we take one out we must refill it. This way we97// always ensure we have the value with the highest global order.98let poll_range = match self.poll_state {99PollState::NoPoll => 0..0,100PollState::Poll(i) => i..i + 1,101PollState::PollAll => 0..self.receivers.len(),102};103104for sender_id in poll_range {105// If no value was received from that particular inserter, that106// stream is done and thus we no longer need to consider it for the107// global order.108if let Some(value) = self.receivers[sender_id].recv().await {109self.heap.push(LinearedItem { value, sender_id });110}111}112113if let Some(first_in_merged_streams) = self.heap.pop() {114let LinearedItem { value, sender_id } = first_in_merged_streams;115self.poll_state = PollState::Poll(sender_id);116Some(value)117} else {118self.poll_state = PollState::NoPoll;119None120}121}122}123124pub struct Inserter<T> {125sender: Sender<T>,126}127128impl<T: Ord> Inserter<T> {129pub async fn insert(&mut self, value: T) -> Result<(), T> {130self.sender.send(value).await.map_err(|e| e.0)131}132}133134135