Path: blob/main/crates/polars-stream/src/nodes/joins/asof_join.rs
8480 views
use std::collections::VecDeque;12use polars_core::prelude::*;3use polars_core::utils::Container;4use polars_ops::frame::{AsOfOptions, AsofStrategy, JoinArgs, JoinType};5use polars_utils::format_pl_smallstr;67use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;8use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};9use crate::async_primitives::distributor_channel as dc;10use crate::execute::StreamingExecutionState;11use crate::graph::PortState;12use crate::morsel::{Morsel, MorselSeq, SourceToken};13use crate::nodes::ComputeNode;14use crate::nodes::joins::utils::DataFrameSearchBuffer;15use crate::pipe::{PortReceiver, PortSender, RecvPort, SendPort};1617#[derive(Debug)]18pub struct AsOfJoinSideParams {19pub on: PlSmallStr,20pub tmp_key_col: Option<PlSmallStr>,21}2223impl AsOfJoinSideParams {24fn key_col(&self) -> &PlSmallStr {25self.tmp_key_col.as_ref().unwrap_or(&self.on)26}27}2829#[derive(Debug)]30struct AsOfJoinParams {31left: AsOfJoinSideParams,32right: AsOfJoinSideParams,33args: JoinArgs,34}3536impl AsOfJoinParams {37fn as_of_options(&self) -> &AsOfOptions {38let JoinType::AsOf(ref options) = self.args.how else {39unreachable!("incorrect join type");40};41options42}43}4445#[derive(Debug, Default, PartialEq)]46enum AsOfJoinState {47#[default]48Running,49FlushInputBuffer,50Done,51}5253#[derive(Debug)]54pub struct AsOfJoinNode {55params: AsOfJoinParams,56state: AsOfJoinState,57/// We may need to stash a morsel on the left side whenever we do not58/// have enough data on the right side, but the right side is empty.59/// In these cases, we stash that morsel here.60left_buffer: VecDeque<(DataFrame, MorselSeq)>,61/// Buffer of the live range of right AsOf join rows.62right_buffer: DataFrameSearchBuffer,63}6465impl AsOfJoinNode {66pub fn new(67left_input_schema: SchemaRef,68right_input_schema: SchemaRef,69left_on: PlSmallStr,70right_on: PlSmallStr,71tmp_left_key_col: Option<PlSmallStr>,72tmp_right_key_col: Option<PlSmallStr>,73args: JoinArgs,74) -> Self {75let left_key_col = tmp_left_key_col.as_ref().unwrap_or(&left_on);76let right_key_col = tmp_right_key_col.as_ref().unwrap_or(&right_on);77let left_key_dtype = left_input_schema.get(left_key_col).unwrap();78let right_key_dtype = right_input_schema.get(right_key_col).unwrap();79assert_eq!(left_key_dtype, right_key_dtype);80let left = AsOfJoinSideParams {81on: left_on,82tmp_key_col: tmp_left_key_col,83};84let right = AsOfJoinSideParams {85on: right_on,86tmp_key_col: tmp_right_key_col,87};8889let params = AsOfJoinParams { left, right, args };90AsOfJoinNode {91params,92state: AsOfJoinState::default(),93left_buffer: Default::default(),94right_buffer: DataFrameSearchBuffer::empty_with_schema(right_input_schema),95}96}97}9899impl ComputeNode for AsOfJoinNode {100fn name(&self) -> &str {101"asof-join"102}103104fn update_state(105&mut self,106recv: &mut [PortState],107send: &mut [PortState],108_state: &StreamingExecutionState,109) -> PolarsResult<()> {110assert!(recv.len() == 2 && send.len() == 1);111112if send[0] == PortState::Done {113self.state = AsOfJoinState::Done;114}115116if self.state == AsOfJoinState::Running && recv[0] == PortState::Done {117self.state = AsOfJoinState::FlushInputBuffer;118}119120if self.state == AsOfJoinState::FlushInputBuffer && self.left_buffer.is_empty() {121self.state = AsOfJoinState::Done;122}123124let recv0_blocked = recv[0] == PortState::Blocked;125let recv1_blocked = recv[1] == PortState::Blocked;126let send_blocked = send[0] == PortState::Blocked;127match self.state {128AsOfJoinState::Running => {129recv[0] = PortState::Ready;130recv[1] = PortState::Ready;131send[0] = PortState::Ready;132if recv0_blocked {133recv[1] = PortState::Blocked;134send[0] = PortState::Blocked;135}136if recv1_blocked {137recv[0] = PortState::Blocked;138send[0] = PortState::Blocked;139}140if send_blocked {141recv[0] = PortState::Blocked;142recv[1] = PortState::Blocked;143}144},145AsOfJoinState::FlushInputBuffer => {146recv[0] = PortState::Done;147recv[1] = PortState::Ready;148send[0] = PortState::Ready;149if recv1_blocked {150send[0] = PortState::Blocked;151}152if send_blocked {153recv[1] = PortState::Blocked;154}155},156AsOfJoinState::Done => {157recv.fill(PortState::Done);158send[0] = PortState::Done;159},160}161162Ok(())163}164165fn spawn<'env, 's>(166&'env mut self,167scope: &'s TaskScope<'s, 'env>,168recv_ports: &mut [Option<RecvPort<'_>>],169send_ports: &mut [Option<SendPort<'_>>],170_state: &'s StreamingExecutionState,171join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,172) {173assert!(recv_ports.len() == 2 && send_ports.len() == 1);174175match &self.state {176AsOfJoinState::Running | AsOfJoinState::FlushInputBuffer => {177let params = &self.params;178let recv_left = match self.state {179AsOfJoinState::Running => Some(recv_ports[0].take().unwrap().serial()),180_ => None,181};182let recv_right = recv_ports[1].take().map(RecvPort::serial);183let send = send_ports[0].take().unwrap().parallel();184let (distributor, dist_recv) =185dc::distributor_channel(send.len(), *DEFAULT_DISTRIBUTOR_BUFFER_SIZE);186let left_buffer = &mut self.left_buffer;187let right_buffer = &mut self.right_buffer;188join_handles.push(scope.spawn_task(TaskPriority::High, async move {189distribute_work_task(190recv_left,191recv_right,192distributor,193left_buffer,194right_buffer,195params,196)197.await198}));199200join_handles.extend(dist_recv.into_iter().zip(send).map(|(recv, send)| {201scope.spawn_task(TaskPriority::High, async move {202compute_and_emit_task(recv, send, params).await203})204}));205},206AsOfJoinState::Done => {207unreachable!();208},209}210}211}212213/// Tell the sender to this port to stop, and buffer everything that is still in the pipe.214async fn stop_and_buffer_pipe_contents<F>(port: Option<&mut PortReceiver>, buffer_morsel: &mut F)215where216F: FnMut(DataFrame, MorselSeq),217{218let Some(port) = port else {219return;220};221222while let Ok(morsel) = port.recv().await {223morsel.source_token().stop();224let (df, seq, _, _) = morsel.into_inner();225buffer_morsel(df, seq);226}227}228229async fn distribute_work_task(230mut recv_left: Option<PortReceiver>,231mut recv_right: Option<PortReceiver>,232mut distributor: dc::Sender<(DataFrame, DataFrameSearchBuffer, MorselSeq, SourceToken)>,233left_buffer: &mut VecDeque<(DataFrame, MorselSeq)>,234right_buffer: &mut DataFrameSearchBuffer,235params: &AsOfJoinParams,236) -> PolarsResult<()> {237let source_token = SourceToken::new();238let right_done = recv_right.is_none();239240loop {241if source_token.stop_requested() {242stop_and_buffer_pipe_contents(recv_left.as_mut(), &mut |df, seq| {243left_buffer.push_back((df, seq))244})245.await;246stop_and_buffer_pipe_contents(recv_right.as_mut(), &mut |df, _| {247right_buffer.push_df(df)248})249.await;250return Ok(());251}252253let (left_df, seq, st) = if let Some((df, seq)) = left_buffer.pop_front() {254(df, seq, source_token.clone())255} else if let Some(ref mut recv) = recv_left256&& let Ok(m) = recv.recv().await257{258let (df, seq, st, _) = m.into_inner();259(df, seq, st)260} else {261stop_and_buffer_pipe_contents(recv_right.as_mut(), &mut |df, _| {262right_buffer.push_df(df)263})264.await;265return Ok(());266};267268while need_more_right_side(&left_df, right_buffer, params)? && !right_done {269if let Some(ref mut recv) = recv_right270&& let Ok(morsel_right) = recv.recv().await271{272right_buffer.push_df(morsel_right.into_df());273} else {274// The right pipe is empty at this stage, we will need to wait for275// a new stage and try again.276left_buffer.push_back((left_df, seq));277stop_and_buffer_pipe_contents(recv_left.as_mut(), &mut |df, seq| {278left_buffer.push_back((df, seq))279})280.await;281return Ok(());282}283}284285distributor286.send((left_df.clone(), right_buffer.clone(), seq, st))287.await288.unwrap();289prune_right_side(&left_df, right_buffer, params)?;290}291}292293/// Do we need more values on the right side before we can compute the AsOf join294/// between the right side and the complete left side?295fn need_more_right_side(296left: &DataFrame,297right: &DataFrameSearchBuffer,298params: &AsOfJoinParams,299) -> PolarsResult<bool> {300let options = params.as_of_options();301let left_key = left.column(params.left.key_col())?.as_materialized_series();302if left_key.is_empty() {303return Ok(false);304}305// SAFETY: We just checked that left_key is not empty306let left_last_val = unsafe { left_key.get_unchecked(left_key.len() - 1) };307let right_range_end = match (options.strategy, options.allow_eq) {308(AsofStrategy::Forward, true) => {309right.binary_search(|x| *x >= left_last_val, params.right.key_col(), false)310},311(AsofStrategy::Forward, false) | (AsofStrategy::Backward, true) => {312right.binary_search(|x| *x > left_last_val, params.right.key_col(), false)313},314(AsofStrategy::Backward, false) | (AsofStrategy::Nearest, _) => {315let first_greater =316right.binary_search(|x| *x > left_last_val, params.right.key_col(), false);317if first_greater >= right.height() {318return Ok(true);319}320// In the Backward/Nearest cases, there may be a chunk of consecutive equal321// values following the match value on the left side. In this case, the AsOf322// join is greedy and should until the *end* of that chunk.323324// SAFETY: We just checked that right_range_end is in bounds325let fst_greater_val =326unsafe { right.get_bypass_validity(params.right.key_col(), first_greater, false) };327right.binary_search(|x| *x > fst_greater_val, params.right.key_col(), false)328},329};330Ok(right_range_end >= right.height())331}332333fn prune_right_side(334left: &DataFrame,335right: &mut DataFrameSearchBuffer,336params: &AsOfJoinParams,337) -> PolarsResult<()> {338let left_key = left.column(params.left.key_col())?.as_materialized_series();339if left.len() == 0 {340return Ok(());341}342// SAFETY: We just checked that left_key is not empty343let left_first_val = unsafe { left_key.get_unchecked(0) };344let right_range_start = right345.binary_search(|x| *x >= left_first_val, params.right.key_col(), false)346.saturating_sub(1);347right.split_at(right_range_start);348Ok(())349}350351async fn compute_and_emit_task(352mut dist_recv: dc::Receiver<(DataFrame, DataFrameSearchBuffer, MorselSeq, SourceToken)>,353mut send: PortSender,354params: &AsOfJoinParams,355) -> PolarsResult<()> {356let options = params.as_of_options();357while let Ok((left_df, right_buffer, seq, st)) = dist_recv.recv().await {358let right_df = right_buffer.into_df();359360let left_key = left_df.column(params.left.key_col())?;361let right_key = right_df.column(params.right.key_col())?;362let any_key_is_temporary_col =363params.left.tmp_key_col.is_some() || params.right.tmp_key_col.is_some();364let mut out = polars_ops::frame::AsofJoin::_join_asof(365&left_df,366&right_df,367left_key.as_materialized_series(),368right_key.as_materialized_series(),369options.strategy,370options.tolerance.clone().map(Scalar::into_value),371params.args.suffix.clone(),372None,373any_key_is_temporary_col || params.args.should_coalesce(),374options.allow_eq,375options.check_sortedness,376)?;377378// Drop any temporary key columns that were added379for tmp_key_col in [¶ms.left.tmp_key_col, ¶ms.right.tmp_key_col] {380if let Some(tmp_col) = tmp_key_col381&& out.schema().contains(tmp_col)382{383out.drop_in_place(tmp_col)?;384}385}386387// If the join key passed to _join_asof() was a temporary key column,388// we still need to coalesce the real 'on' columns ourselves.389if any_key_is_temporary_col390&& params.args.should_coalesce()391&& params.left.on == params.right.on392{393let right_on_name = format_pl_smallstr!("{}{}", params.right.on, params.args.suffix());394out.drop_in_place(&right_on_name)?;395}396397let morsel = Morsel::new(out, seq, st);398if send.send(morsel).await.is_err() {399return Ok(());400}401}402Ok(())403}404405406