Path: blob/main/crates/polars-stream/src/nodes/merge_sorted.rs
6939 views
use std::collections::VecDeque;12use polars_core::prelude::ChunkCompareIneq;3use polars_ops::frame::_merge_sorted_dfs;45use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;6use crate::async_primitives::connector::Receiver;7use crate::async_primitives::distributor_channel::distributor_channel;8use crate::morsel::{SourceToken, get_ideal_morsel_size};9use crate::nodes::compute_node_prelude::*;1011/// Performs `merge_sorted` with the last column being regarded as the key column. This key column12/// is also popped in the send pipe.13pub struct MergeSortedNode {14seq: MorselSeq,1516starting_nulls: bool,1718// Not yet merged buffers.19left_unmerged: VecDeque<DataFrame>,20right_unmerged: VecDeque<DataFrame>,21}2223impl MergeSortedNode {24pub fn new() -> Self {25Self {26seq: MorselSeq::default(),2728starting_nulls: false,2930left_unmerged: VecDeque::new(),31right_unmerged: VecDeque::new(),32}33}34}3536/// Find a part amongst both unmerged buffers which is mergeable.37///38/// This returns `None` if there is nothing mergeable at this point.39fn find_mergeable(40left_unmerged: &mut VecDeque<DataFrame>,41right_unmerged: &mut VecDeque<DataFrame>,4243is_first: bool,44starting_nulls: &mut bool,45) -> PolarsResult<Option<(DataFrame, DataFrame)>> {46fn first_non_empty(vd: &mut VecDeque<DataFrame>) -> Option<DataFrame> {47let mut df = vd.pop_front()?;48while df.height() == 0 {49df = vd.pop_front()?;50}51Some(df)52}5354loop {55let (mut left, mut right) = match (56first_non_empty(left_unmerged),57first_non_empty(right_unmerged),58) {59(Some(l), Some(r)) => (l, r),60(Some(l), None) => {61left_unmerged.push_front(l);62return Ok(None);63},64(None, Some(r)) => {65right_unmerged.push_front(r);66return Ok(None);67},68(None, None) => return Ok(None),69};7071let left_key = left.get_columns().last().unwrap();72let right_key = right.get_columns().last().unwrap();7374let left_null_count = left_key.null_count();75let right_null_count = right_key.null_count();7677let has_nulls = left_null_count > 0 || right_null_count > 0;7879// If we are on the first morsel we need to decide whether we have80// nulls first or not.81if is_first82&& has_nulls83&& (left_key.head(Some(1)).has_nulls() || right_key.head(Some(1)).has_nulls())84{85*starting_nulls = true;86}8788// For both left and right, find row index of the minimum of the maxima89// of the left and right key columns. We can safely merge until this90// point.91let mut left_cutoff = left.height();92let mut right_cutoff = right.height();9394let left_key_last = left_key.tail(Some(1));95let right_key_last = right_key.tail(Some(1));9697// We already made sure we had data to work with.98assert!(!left_key_last.is_empty());99assert!(!right_key_last.is_empty());100101if has_nulls {102if *starting_nulls {103// If there are starting nulls do those first, then repeat104// without the nulls.105left_cutoff = left_null_count;106right_cutoff = right_null_count;107} else {108// If there are ending nulls then first do things without the109// nulls and then repeat with only the nulls the nulls.110let left_is_all_nulls = left_null_count == left.height();111let right_is_all_nulls = right_null_count == right.height();112113match (left_is_all_nulls, right_is_all_nulls) {114(false, false) => {115let left_nulls;116let right_nulls;117(left, left_nulls) =118left.split_at((left.height() - left_null_count) as i64);119(right, right_nulls) =120right.split_at((right.height() - right_null_count) as i64);121122left_unmerged.push_front(left_nulls);123left_unmerged.push_front(left);124right_unmerged.push_front(right_nulls);125right_unmerged.push_front(right);126continue;127},128(true, false) => left_cutoff = 0,129(false, true) => right_cutoff = 0,130(true, true) => {},131}132}133} else if left_key_last.lt(&right_key_last)?.all() {134// @TODO: This is essentially search sorted, but that does not135// support categoricals at moment.136let gt_mask = right_key.gt(&left_key_last)?;137right_cutoff = gt_mask.downcast_as_array().values().leading_zeros();138} else if left_key_last.gt(&right_key_last)?.all() {139// @TODO: This is essentially search sorted, but that does not140// support categoricals at moment.141let gt_mask = left_key.gt(&right_key_last)?;142left_cutoff = gt_mask.downcast_as_array().values().leading_zeros();143}144145let left_mergeable: DataFrame;146let right_mergeable: DataFrame;147(left_mergeable, left) = left.split_at(left_cutoff as i64);148(right_mergeable, right) = right.split_at(right_cutoff as i64);149150if !left.is_empty() {151left_unmerged.push_front(left);152}153if !right.is_empty() {154right_unmerged.push_front(right);155}156157return Ok(Some((left_mergeable, right_mergeable)));158}159}160161fn remove_key_column(df: &mut DataFrame) {162// SAFETY:163// - We only pop so height stays same.164// - We only pop so no new name collisions.165// - We clear schema afterwards.166unsafe { df.get_columns_mut().pop().unwrap() };167df.clear_schema();168}169170impl ComputeNode for MergeSortedNode {171fn name(&self) -> &str {172"merge-sorted"173}174175fn update_state(176&mut self,177recv: &mut [PortState],178send: &mut [PortState],179_state: &StreamingExecutionState,180) -> PolarsResult<()> {181assert_eq!(send.len(), 1);182assert_eq!(recv.len(), 2);183184// Abstraction: we merge buffer state with port state so we can map185// to one three possible 'effective' states:186// no data now (_blocked); data available (); or no data anymore (_done)187let left_done = recv[0] == PortState::Done && self.left_unmerged.is_empty();188let right_done = recv[1] == PortState::Done && self.right_unmerged.is_empty();189190// We're done as soon as one side is done.191if send[0] == PortState::Done || (left_done && right_done) {192recv[0] = PortState::Done;193recv[1] = PortState::Done;194send[0] = PortState::Done;195return Ok(());196}197198// Each port is ready to proceed unless one of the other ports is effectively199// blocked. For example:200// - [Blocked with empty buffer, Ready] [Ready] returns [Ready, Blocked] [Blocked]201// - [Blocked with non-empty buffer, Ready] [Ready] returns [Ready, Ready, Ready]202let send_blocked = send[0] == PortState::Blocked;203let left_blocked = recv[0] == PortState::Blocked && self.left_unmerged.is_empty();204let right_blocked = recv[1] == PortState::Blocked && self.right_unmerged.is_empty();205send[0] = if left_blocked || right_blocked {206PortState::Blocked207} else {208PortState::Ready209};210recv[0] = if send_blocked || right_blocked {211PortState::Blocked212} else {213PortState::Ready214};215recv[1] = if send_blocked || left_blocked {216PortState::Blocked217} else {218PortState::Ready219};220221Ok(())222}223224fn spawn<'env, 's>(225&'env mut self,226scope: &'s TaskScope<'s, 'env>,227recv_ports: &mut [Option<RecvPort<'_>>],228send_ports: &mut [Option<SendPort<'_>>],229_state: &'s StreamingExecutionState,230join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,231) {232assert_eq!(recv_ports.len(), 2);233assert_eq!(send_ports.len(), 1);234235let send = send_ports[0].take().unwrap().parallel();236237let seq = &mut self.seq;238let starting_nulls = &mut self.starting_nulls;239let left_unmerged = &mut self.left_unmerged;240let right_unmerged = &mut self.right_unmerged;241242match (recv_ports[0].take(), recv_ports[1].take()) {243// If we do not need to merge or flush anymore, just start passing the port in244// parallel.245(Some(port), None) | (None, Some(port))246if left_unmerged.is_empty() && right_unmerged.is_empty() =>247{248let recv = port.parallel();249let inner_handles = recv250.into_iter()251.zip(send)252.map(|(mut recv, mut send)| {253let morsel_offset = *seq;254scope.spawn_task(TaskPriority::High, async move {255let mut max_seq = morsel_offset;256while let Ok(mut morsel) = recv.recv().await {257// Ensure the morsel sequence id stream is monotone non-decreasing.258let seq = morsel.seq().offset_by(morsel_offset);259max_seq = max_seq.max(seq);260261remove_key_column(morsel.df_mut());262263morsel.set_seq(seq);264if send.send(morsel).await.is_err() {265break;266}267}268max_seq269})270})271.collect::<Vec<_>>();272273join_handles.push(scope.spawn_task(TaskPriority::High, async move {274// Update our global maximum.275for handle in inner_handles {276*seq = (*seq).max(handle.await);277}278Ok(())279}));280},281282// This is the base case. Either:283// - Both streams are still open and we still need to merge.284// - One or both streams are closed stream is closed and we still have some buffered285// data.286(left, right) => {287async fn buffer_unmerged(288port: &mut Receiver<Morsel>,289unmerged: &mut VecDeque<DataFrame>,290) {291// If a stop was requested, we need to buffer the remaining292// morsels and trigger a phase transition.293let Ok(morsel) = port.recv().await else {294return;295};296297// Request the port stop producing morsels.298morsel.source_token().stop();299300// Buffer all the morsels that were already produced.301unmerged.push_back(morsel.into_df());302while let Ok(morsel) = port.recv().await {303unmerged.push_back(morsel.into_df());304}305}306307let (mut distributor, dist_recv) =308distributor_channel(send.len(), *DEFAULT_DISTRIBUTOR_BUFFER_SIZE);309310let mut left = left.map(|p| p.serial());311let mut right = right.map(|p| p.serial());312313join_handles.push(scope.spawn_task(TaskPriority::Low, async move {314let source_token = SourceToken::new();315316// While we can still load data for the empty side.317while (left.is_some() || right.is_some())318&& !(left.is_none() && left_unmerged.is_empty())319&& !(right.is_none() && right_unmerged.is_empty())320{321// If we have morsels from both input ports, find until where we can merge322// them and send that on to be merged.323while let Some((left_mergeable, right_mergeable)) = find_mergeable(324left_unmerged,325right_unmerged,326seq.to_u64() == 0,327starting_nulls,328)? {329let left_mergeable =330Morsel::new(left_mergeable, *seq, source_token.clone());331*seq = seq.successor();332333if distributor334.send((left_mergeable, right_mergeable))335.await336.is_err()337{338return Ok(());339};340}341342if source_token.stop_requested() {343// Request that a port stops producing morsels and buffers all the344// remaining morsels.345if let Some(p) = &mut left {346buffer_unmerged(p, left_unmerged).await;347}348if let Some(p) = &mut right {349buffer_unmerged(p, right_unmerged).await;350}351break;352}353354assert!(left_unmerged.is_empty() || right_unmerged.is_empty());355let (empty_port, empty_unmerged) = match (356left_unmerged.is_empty(),357right_unmerged.is_empty(),358left.as_mut(),359right.as_mut(),360) {361(true, _, Some(left), _) => (left, &mut *left_unmerged),362(_, true, _, Some(right)) => (right, &mut *right_unmerged),363364// If the port that is empty is closed, we don't need to merge anymore.365_ => break,366};367368// Try to get a new morsel from the empty side.369let Ok(m) = empty_port.recv().await else {370if let Some(p) = &mut left {371buffer_unmerged(p, left_unmerged).await;372}373if let Some(p) = &mut right {374buffer_unmerged(p, right_unmerged).await;375}376break;377};378empty_unmerged.push_back(m.into_df());379}380381// Clear out buffers until we cannot anymore. This helps allows us to go to the382// parallel case faster.383while let Some((left_mergeable, right_mergeable)) = find_mergeable(384left_unmerged,385right_unmerged,386seq.to_u64() == 0,387starting_nulls,388)? {389let left_mergeable =390Morsel::new(left_mergeable, *seq, source_token.clone());391*seq = seq.successor();392393if distributor394.send((left_mergeable, right_mergeable))395.await396.is_err()397{398return Ok(());399};400}401402// If one of the ports is done and does not have buffered data anymore, we403// flush the data on the other side. After this point, this node just pipes404// data through.405let pass = if left.is_none() && left_unmerged.is_empty() {406Some((right.as_mut(), &mut *right_unmerged))407} else if right.is_none() && right_unmerged.is_empty() {408Some((left.as_mut(), &mut *left_unmerged))409} else {410None411};412if let Some((pass_port, pass_unmerged)) = pass {413for df in std::mem::take(pass_unmerged) {414let m = Morsel::new(df, *seq, source_token.clone());415*seq = seq.successor();416if distributor.send((m, DataFrame::empty())).await.is_err() {417return Ok(());418}419}420421// Start passing on the port that is port that is still open.422if let Some(pass_port) = pass_port {423let Ok(mut m) = pass_port.recv().await else {424return Ok(());425};426if source_token.stop_requested() {427m.source_token().stop();428}429m.set_seq(*seq);430*seq = seq.successor();431if distributor.send((m, DataFrame::empty())).await.is_err() {432return Ok(());433}434435while let Ok(mut m) = pass_port.recv().await {436m.set_seq(*seq);437*seq = seq.successor();438if distributor.send((m, DataFrame::empty())).await.is_err() {439return Ok(());440}441}442}443}444445Ok(())446}));447448// Task that actually merges the two dataframes. Since this merge might be very449// expensive, this is split over several tasks.450join_handles.extend(dist_recv.into_iter().zip(send).map(|(mut recv, mut send)| {451let ideal_morsel_size = get_ideal_morsel_size();452scope.spawn_task(TaskPriority::High, async move {453while let Ok((mut left, mut right)) = recv.recv().await {454// When we are flushing the buffer, we will just send one morsel from455// the input. We don't want to mess with the source token or wait group456// and just pass it on.457if right.is_empty() {458remove_key_column(left.df_mut());459460if send.send(left).await.is_err() {461return Ok(());462}463continue;464}465466let (mut left, seq, source_token, wg) = left.into_inner();467assert!(wg.is_none());468469let left_s = left470.get_columns()471.last()472.unwrap()473.as_materialized_series()474.clone();475let right_s = right476.get_columns()477.last()478.unwrap()479.as_materialized_series()480.clone();481482remove_key_column(&mut left);483remove_key_column(&mut right);484485let merged =486_merge_sorted_dfs(&left, &right, &left_s, &right_s, false)?;487488if ideal_morsel_size > 1 && merged.height() > ideal_morsel_size {489// The merged dataframe will have at most doubled in size from the490// input so we can divide by half.491let (m1, m2) = merged.split_at((merged.height() / 2) as i64);492493// MorselSeq have to be monotonely non-decreasing so we can494// pass the same sequence token twice.495let morsel = Morsel::new(m1, seq, source_token.clone());496if send.send(morsel).await.is_err() {497break;498}499let morsel = Morsel::new(m2, seq, source_token.clone());500if send.send(morsel).await.is_err() {501break;502}503} else {504let morsel = Morsel::new(merged, seq, source_token.clone());505if send.send(morsel).await.is_err() {506break;507}508}509}510511Ok(())512})513}));514},515}516}517}518519520