Path: blob/main/crates/polars-stream/src/nodes/merge_sorted.rs
8479 views
use std::collections::VecDeque;12use polars_core::prelude::ChunkCompareIneq;3use polars_ops::frame::_merge_sorted_dfs;45use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;6use crate::async_primitives::distributor_channel::distributor_channel;7use crate::morsel::{SourceToken, get_ideal_morsel_size};8use crate::nodes::compute_node_prelude::*;910/// Performs `merge_sorted` with the last column being regarded as the key column. This key column11/// is also popped in the send pipe.12pub struct MergeSortedNode {13seq: MorselSeq,1415starting_nulls: bool,1617// Not yet merged buffers.18left_unmerged: VecDeque<DataFrame>,19right_unmerged: VecDeque<DataFrame>,20}2122impl MergeSortedNode {23pub fn new() -> Self {24Self {25seq: MorselSeq::default(),2627starting_nulls: false,2829left_unmerged: VecDeque::new(),30right_unmerged: VecDeque::new(),31}32}33}3435/// Find a part amongst both unmerged buffers which is mergeable.36///37/// This returns `None` if there is nothing mergeable at this point.38fn find_mergeable(39left_unmerged: &mut VecDeque<DataFrame>,40right_unmerged: &mut VecDeque<DataFrame>,4142is_first: bool,43starting_nulls: &mut bool,44) -> PolarsResult<Option<(DataFrame, DataFrame)>> {45fn first_non_empty(vd: &mut VecDeque<DataFrame>) -> Option<DataFrame> {46let mut df = vd.pop_front()?;47while df.height() == 0 {48df = vd.pop_front()?;49}50Some(df)51}5253loop {54let (mut left, mut right) = match (55first_non_empty(left_unmerged),56first_non_empty(right_unmerged),57) {58(Some(l), Some(r)) => (l, r),59(Some(l), None) => {60left_unmerged.push_front(l);61return Ok(None);62},63(None, Some(r)) => {64right_unmerged.push_front(r);65return Ok(None);66},67(None, None) => return Ok(None),68};6970let left_key = left.columns().last().unwrap();71let right_key = right.columns().last().unwrap();7273let left_null_count = left_key.null_count();74let right_null_count = right_key.null_count();7576let has_nulls = left_null_count > 0 || right_null_count > 0;7778// If we are on the first morsel we need to decide whether we have79// nulls first or not.80if is_first81&& has_nulls82&& (left_key.head(Some(1)).has_nulls() || right_key.head(Some(1)).has_nulls())83{84*starting_nulls = true;85}8687// For both left and right, find row index of the minimum of the maxima88// of the left and right key columns. We can safely merge until this89// point.90let mut left_cutoff = left.height();91let mut right_cutoff = right.height();9293let left_key_last = left_key.tail(Some(1));94let right_key_last = right_key.tail(Some(1));9596// We already made sure we had data to work with.97assert!(!left_key_last.is_empty());98assert!(!right_key_last.is_empty());99100if has_nulls {101if *starting_nulls {102// If there are starting nulls do those first, then repeat103// without the nulls.104left_cutoff = left_null_count;105right_cutoff = right_null_count;106} else {107// If there are ending nulls then first do things without the108// nulls and then repeat with only the nulls the nulls.109let left_is_all_nulls = left_null_count == left.height();110let right_is_all_nulls = right_null_count == right.height();111112match (left_is_all_nulls, right_is_all_nulls) {113(false, false) => {114let left_nulls;115let right_nulls;116(left, left_nulls) =117left.split_at((left.height() - left_null_count) as i64);118(right, right_nulls) =119right.split_at((right.height() - right_null_count) as i64);120121left_unmerged.push_front(left_nulls);122left_unmerged.push_front(left);123right_unmerged.push_front(right_nulls);124right_unmerged.push_front(right);125continue;126},127(true, false) => left_cutoff = 0,128(false, true) => right_cutoff = 0,129(true, true) => {},130}131}132} else if left_key_last.lt(&right_key_last)?.all() {133// @TODO: This is essentially search sorted, but that does not134// support categoricals at moment.135let gt_mask = right_key.gt(&left_key_last)?;136right_cutoff = gt_mask.downcast_as_array().values().leading_zeros();137} else if left_key_last.gt(&right_key_last)?.all() {138// @TODO: This is essentially search sorted, but that does not139// support categoricals at moment.140let gt_mask = left_key.gt(&right_key_last)?;141left_cutoff = gt_mask.downcast_as_array().values().leading_zeros();142}143144let left_mergeable: DataFrame;145let right_mergeable: DataFrame;146(left_mergeable, left) = left.split_at(left_cutoff as i64);147(right_mergeable, right) = right.split_at(right_cutoff as i64);148149if left.height() > 0 {150left_unmerged.push_front(left);151}152if right.height() > 0 {153right_unmerged.push_front(right);154}155156return Ok(Some((left_mergeable, right_mergeable)));157}158}159160fn remove_key_column(df: &mut DataFrame) {161// SAFETY:162// - We only pop so height stays same.163// - We only pop so no new name collisions.164// - We clear schema afterwards.165unsafe { df.columns_mut().pop().unwrap() };166}167168impl ComputeNode for MergeSortedNode {169fn name(&self) -> &str {170"merge-sorted"171}172173fn update_state(174&mut self,175recv: &mut [PortState],176send: &mut [PortState],177_state: &StreamingExecutionState,178) -> PolarsResult<()> {179assert_eq!(send.len(), 1);180assert_eq!(recv.len(), 2);181182// Abstraction: we merge buffer state with port state so we can map183// to one three possible 'effective' states:184// no data now (_blocked); data available (); or no data anymore (_done)185let left_done = recv[0] == PortState::Done && self.left_unmerged.is_empty();186let right_done = recv[1] == PortState::Done && self.right_unmerged.is_empty();187188// We're done as soon as one side is done.189if send[0] == PortState::Done || (left_done && right_done) {190recv[0] = PortState::Done;191recv[1] = PortState::Done;192send[0] = PortState::Done;193return Ok(());194}195196// Each port is ready to proceed unless one of the other ports is effectively197// blocked. For example:198// - [Blocked with empty buffer, Ready] [Ready] returns [Ready, Blocked] [Blocked]199// - [Blocked with non-empty buffer, Ready] [Ready] returns [Ready, Ready, Ready]200let send_blocked = send[0] == PortState::Blocked;201let left_blocked = recv[0] == PortState::Blocked && self.left_unmerged.is_empty();202let right_blocked = recv[1] == PortState::Blocked && self.right_unmerged.is_empty();203send[0] = if left_blocked || right_blocked {204PortState::Blocked205} else {206PortState::Ready207};208recv[0] = if send_blocked || right_blocked {209PortState::Blocked210} else {211PortState::Ready212};213recv[1] = if send_blocked || left_blocked {214PortState::Blocked215} else {216PortState::Ready217};218219Ok(())220}221222fn spawn<'env, 's>(223&'env mut self,224scope: &'s TaskScope<'s, 'env>,225recv_ports: &mut [Option<RecvPort<'_>>],226send_ports: &mut [Option<SendPort<'_>>],227_state: &'s StreamingExecutionState,228join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,229) {230assert_eq!(recv_ports.len(), 2);231assert_eq!(send_ports.len(), 1);232233let send = send_ports[0].take().unwrap().parallel();234235let seq = &mut self.seq;236let starting_nulls = &mut self.starting_nulls;237let left_unmerged = &mut self.left_unmerged;238let right_unmerged = &mut self.right_unmerged;239240match (recv_ports[0].take(), recv_ports[1].take()) {241// If we do not need to merge or flush anymore, just start passing the port in242// parallel.243(Some(port), None) | (None, Some(port))244if left_unmerged.is_empty() && right_unmerged.is_empty() =>245{246let recv = port.parallel();247let inner_handles = recv248.into_iter()249.zip(send)250.map(|(mut recv, mut send)| {251let morsel_offset = *seq;252scope.spawn_task(TaskPriority::High, async move {253let mut max_seq = morsel_offset;254while let Ok(mut morsel) = recv.recv().await {255// Ensure the morsel sequence id stream is monotone non-decreasing.256let seq = morsel.seq().offset_by(morsel_offset);257max_seq = max_seq.max(seq);258259remove_key_column(morsel.df_mut());260261morsel.set_seq(seq);262if send.send(morsel).await.is_err() {263break;264}265}266max_seq267})268})269.collect::<Vec<_>>();270271join_handles.push(scope.spawn_task(TaskPriority::High, async move {272// Update our global maximum.273for handle in inner_handles {274*seq = (*seq).max(handle.await);275}276Ok(())277}));278},279280// This is the base case. Either:281// - Both streams are still open and we still need to merge.282// - One or both streams are closed stream is closed and we still have some buffered283// data.284(left, right) => {285async fn buffer_unmerged(286port: &mut PortReceiver,287unmerged: &mut VecDeque<DataFrame>,288) {289// If a stop was requested, we need to buffer the remaining290// morsels and trigger a phase transition.291292while let Ok(morsel) = port.recv().await {293// Request the port stop producing morsels.294morsel.source_token().stop();295// Buffer all the morsels that were already produced.296unmerged.push_back(morsel.into_df());297}298}299300let (mut distributor, dist_recv) =301distributor_channel(send.len(), *DEFAULT_DISTRIBUTOR_BUFFER_SIZE);302303let mut left = left.map(|p| p.serial());304let mut right = right.map(|p| p.serial());305306join_handles.push(scope.spawn_task(TaskPriority::Low, async move {307let source_token = SourceToken::new();308309// While we can still load data for the empty side.310while (left.is_some() || right.is_some())311&& !(left.is_none() && left_unmerged.is_empty())312&& !(right.is_none() && right_unmerged.is_empty())313{314// If we have morsels from both input ports, find until where we can merge315// them and send that on to be merged.316while let Some((left_mergeable, right_mergeable)) = find_mergeable(317left_unmerged,318right_unmerged,319seq.to_u64() == 0,320starting_nulls,321)? {322let left_mergeable =323Morsel::new(left_mergeable, *seq, source_token.clone());324*seq = seq.successor();325326if distributor327.send((left_mergeable, right_mergeable))328.await329.is_err()330{331return Ok(());332};333}334335if source_token.stop_requested() {336// Request that a port stops producing morsels and buffers all the337// remaining morsels.338if let Some(p) = &mut left {339buffer_unmerged(p, left_unmerged).await;340}341if let Some(p) = &mut right {342buffer_unmerged(p, right_unmerged).await;343}344break;345}346347assert!(left_unmerged.is_empty() || right_unmerged.is_empty());348let (empty_port, empty_unmerged) = match (349left_unmerged.is_empty(),350right_unmerged.is_empty(),351left.as_mut(),352right.as_mut(),353) {354(true, _, Some(left), _) => (left, &mut *left_unmerged),355(_, true, _, Some(right)) => (right, &mut *right_unmerged),356357// If the port that is empty is closed, we don't need to merge anymore.358_ => break,359};360361// Try to get a new morsel from the empty side.362let Ok(m) = empty_port.recv().await else {363if let Some(p) = &mut left {364buffer_unmerged(p, left_unmerged).await;365}366if let Some(p) = &mut right {367buffer_unmerged(p, right_unmerged).await;368}369break;370};371empty_unmerged.push_back(m.into_df());372}373374// Clear out buffers until we cannot anymore. This helps allows us to go to the375// parallel case faster.376while let Some((left_mergeable, right_mergeable)) = find_mergeable(377left_unmerged,378right_unmerged,379seq.to_u64() == 0,380starting_nulls,381)? {382let left_mergeable =383Morsel::new(left_mergeable, *seq, source_token.clone());384*seq = seq.successor();385386if distributor387.send((left_mergeable, right_mergeable))388.await389.is_err()390{391return Ok(());392};393}394395// If one of the ports is done and does not have buffered data anymore, we396// flush the data on the other side. After this point, this node just pipes397// data through.398let pass = if left.is_none() && left_unmerged.is_empty() {399Some((right.as_mut(), &mut *right_unmerged))400} else if right.is_none() && right_unmerged.is_empty() {401Some((left.as_mut(), &mut *left_unmerged))402} else {403None404};405if let Some((pass_port, pass_unmerged)) = pass {406for df in std::mem::take(pass_unmerged) {407let m = Morsel::new(df, *seq, source_token.clone());408*seq = seq.successor();409if distributor.send((m, DataFrame::empty())).await.is_err() {410return Ok(());411}412}413414// Start passing on the port that is still open.415if let Some(pass_port) = pass_port {416let Ok(mut m) = pass_port.recv().await else {417return Ok(());418};419if source_token.stop_requested() {420m.source_token().stop();421}422m.set_seq(*seq);423*seq = seq.successor();424if distributor.send((m, DataFrame::empty())).await.is_err() {425return Ok(());426}427428while let Ok(mut m) = pass_port.recv().await {429m.set_seq(*seq);430*seq = seq.successor();431if distributor.send((m, DataFrame::empty())).await.is_err() {432return Ok(());433}434}435}436}437438Ok(())439}));440441// Task that actually merges the two dataframes. Since this merge might be very442// expensive, this is split over several tasks.443join_handles.extend(dist_recv.into_iter().zip(send).map(|(mut recv, mut send)| {444let ideal_morsel_size = get_ideal_morsel_size();445scope.spawn_task(TaskPriority::High, async move {446while let Ok((mut left, mut right)) = recv.recv().await {447// When we are flushing the buffer, we will just send one morsel from448// the input. We don't want to mess with the source token or wait group449// and just pass it on.450if right.shape_has_zero() {451remove_key_column(left.df_mut());452453if send.send(left).await.is_err() {454return Ok(());455}456continue;457}458459let (mut left, seq, source_token, wg) = left.into_inner();460assert!(wg.is_none());461462let left_s = left463.columns()464.last()465.unwrap()466.as_materialized_series()467.clone();468let right_s = right469.columns()470.last()471.unwrap()472.as_materialized_series()473.clone();474475remove_key_column(&mut left);476remove_key_column(&mut right);477478let merged =479_merge_sorted_dfs(&left, &right, &left_s, &right_s, false)?;480481if ideal_morsel_size > 1 && merged.height() > ideal_morsel_size {482// The merged dataframe will have at most doubled in size from the483// input so we can divide by half.484let (m1, m2) = merged.split_at((merged.height() / 2) as i64);485486// MorselSeq have to be monotonely non-decreasing so we can487// pass the same sequence token twice.488let morsel = Morsel::new(m1, seq, source_token.clone());489if send.send(morsel).await.is_err() {490break;491}492let morsel = Morsel::new(m2, seq, source_token.clone());493if send.send(morsel).await.is_err() {494break;495}496} else {497let morsel = Morsel::new(merged, seq, source_token.clone());498if send.send(morsel).await.is_err() {499break;500}501}502}503504Ok(())505})506}));507},508}509}510}511512513