// Path: blob/main/crates/polars-stream/src/nodes/joins/cross_join.rs
// 8479 views
use std::sync::Arc;12use arrow::array::builder::ShareStrategy;3use polars_core::frame::builder::DataFrameBuilder;4use polars_core::schema::Schema;5use polars_error::polars_warn;6use polars_ops::frame::{JoinArgs, JoinBuildSide, MaintainOrderJoin};7use polars_utils::format_pl_smallstr;8use polars_utils::pl_str::PlSmallStr;910use crate::morsel::get_ideal_morsel_size;11use crate::nodes::compute_node_prelude::*;12use crate::nodes::in_memory_sink::InMemorySinkNode;1314pub struct CrossJoinNode {15left_is_build: bool,16left_input_schema: Arc<Schema>,17right_input_schema: Arc<Schema>,18right_rename: Vec<Option<PlSmallStr>>,19state: CrossJoinState,20}2122impl CrossJoinNode {23pub fn new(24left_input_schema: Arc<Schema>,25right_input_schema: Arc<Schema>,26args: &JoinArgs,27) -> Self {28let left_is_build = match args.maintain_order {29MaintainOrderJoin::None => match args.build_side {30// TODO: size estimation.31None | Some(JoinBuildSide::PreferLeft) | Some(JoinBuildSide::ForceLeft) => true,32Some(JoinBuildSide::PreferRight) | Some(JoinBuildSide::ForceRight) => false,33},34MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => {35if args.build_side == Some(JoinBuildSide::ForceLeft) {36polars_warn!("can't force left build-side with left-maintaining cross-join");37}38false39},40MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => {41if args.build_side == Some(JoinBuildSide::ForceRight) {42polars_warn!("can't force right build-side with right-maintaining cross-join");43}44true45},46};47let build_input_schema = if left_is_build {48&left_input_schema49} else {50&right_input_schema51};52let sink_node = InMemorySinkNode::new(build_input_schema.clone());53let right_rename = right_input_schema54.iter_names()55.map(|rname| {56if left_input_schema.contains(rname) {57Some(format_pl_smallstr!("{}{}", rname, args.suffix()))58} else {59None60}61})62.collect();6364Self {65left_is_build,66left_input_schema,67right_input_schema,68right_rename,69state: 
CrossJoinState::Build(sink_node),70}71}72}7374enum CrossJoinState {75Build(InMemorySinkNode),76Probe(DataFrame),77Done,78}7980impl ComputeNode for CrossJoinNode {81fn name(&self) -> &str {82"cross-join"83}8485fn is_memory_intensive_pipeline_blocker(&self) -> bool {86true87}8889fn update_state(90&mut self,91recv: &mut [PortState],92send: &mut [PortState],93_state: &StreamingExecutionState,94) -> PolarsResult<()> {95assert!(recv.len() == 2 && send.len() == 1);9697let build_idx = if self.left_is_build { 0 } else { 1 };98let probe_idx = 1 - build_idx;99100// Are we done?101if send[0] == PortState::Done || recv[probe_idx] == PortState::Done {102self.state = CrossJoinState::Done;103}104105// Transition to build?106if recv[build_idx] == PortState::Done {107if let CrossJoinState::Build(sink_node) = &mut self.state {108let df = sink_node.get_output()?.unwrap();109if df.height() > 0 {110self.state = CrossJoinState::Probe(df);111} else {112self.state = CrossJoinState::Done;113}114}115}116117match &self.state {118CrossJoinState::Build(_) => {119recv[build_idx] = PortState::Ready;120recv[probe_idx] = PortState::Blocked;121send[0] = PortState::Blocked;122},123CrossJoinState::Probe(_) => {124recv[build_idx] = PortState::Done;125core::mem::swap(&mut recv[probe_idx], &mut send[0]);126},127CrossJoinState::Done => {128recv[0] = PortState::Done;129recv[1] = PortState::Done;130send[0] = PortState::Done;131},132}133Ok(())134}135136fn spawn<'env, 's>(137&'env mut self,138scope: &'s TaskScope<'s, 'env>,139recv_ports: &mut [Option<RecvPort<'_>>],140send_ports: &mut [Option<SendPort<'_>>],141state: &'s StreamingExecutionState,142join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,143) {144assert!(recv_ports.len() == 2 && send_ports.len() == 1);145let build_idx = if self.left_is_build { 0 } else { 1 };146let probe_idx = 1 - build_idx;147match &mut self.state {148CrossJoinState::Build(sink_node) => 
{149assert!(send_ports[0].is_none());150assert!(recv_ports[probe_idx].is_none());151sink_node.spawn(152scope,153&mut recv_ports[build_idx..build_idx + 1],154&mut [],155state,156join_handles,157);158},159CrossJoinState::Probe(build_df) => {160assert!(recv_ports[build_idx].is_none());161let receivers = recv_ports[probe_idx].take().unwrap().parallel();162let senders = send_ports[0].take().unwrap().parallel();163let ideal_morsel_size = get_ideal_morsel_size();164165for (mut recv, mut send) in receivers.into_iter().zip(senders) {166let left_is_build = self.left_is_build;167let left_input_schema = self.left_input_schema.clone();168let right_input_schema = self.right_input_schema.clone();169let right_rename = &self.right_rename;170let build_df = &*build_df;171join_handles.push(172scope.spawn_task(TaskPriority::High, async move {173let mut build_repeater = DataFrameBuilder::new(left_input_schema);174let mut probe_repeater = DataFrameBuilder::new(right_input_schema);175if !left_is_build {176core::mem::swap(&mut build_repeater, &mut probe_repeater);177}178let mut cached_build_df_repeated = DataFrame::empty();179180while let Ok(morsel) = recv.recv().await {181let combine =182|build_join_df: DataFrame, probe_join_df: DataFrame| unsafe {183let (mut left_join_df, mut right_join_df);184left_join_df = build_join_df;185right_join_df = probe_join_df;186if !left_is_build {187core::mem::swap(&mut left_join_df, &mut right_join_df);188}189190for (col, opt_rename) in191right_join_df.columns_mut().iter_mut().zip(right_rename)192{193if let Some(rename) = opt_rename {194col.rename(rename.clone());195}196}197198left_join_df.hstack_mut_unchecked(right_join_df.columns());199Morsel::new(200left_join_df,201morsel.seq(),202morsel.source_token().clone(),203)204};205206let probe_df = morsel.df();207if build_df.height() >= ideal_morsel_size {208for probe_offset in 0..probe_df.height() {209let mut build_offset = 0;210while build_offset < build_df.height() {211let height = (build_df.height() - 
build_offset)212.min(ideal_morsel_size);213let build_join_df =214build_df.slice(build_offset as i64, height);215let probe_join_df =216probe_df.new_from_index(probe_offset, height);217let combined = combine(build_join_df, probe_join_df);218if send.send(combined).await.is_err() {219return Ok(());220}221build_offset += height;222}223}224} else {225let max_build_repeats = ideal_morsel_size / build_df.height();226let mut probe_offset = 0;227while probe_offset < probe_df.height() {228let build_repeats = (probe_df.height() - probe_offset)229.min(max_build_repeats);230let build_height = build_repeats * build_df.height();231if build_height > cached_build_df_repeated.height() {232build_repeater.subslice_extend_repeated(233build_df,2340,235build_df.height(),236build_repeats,237ShareStrategy::Never,238);239cached_build_df_repeated =240build_repeater.freeze_reset();241}242let build_join_df =243cached_build_df_repeated.slice(0, build_height);244245probe_repeater.subslice_extend_each_repeated(246probe_df,247probe_offset,248build_repeats,249build_df.height(),250ShareStrategy::Always,251);252let probe_join_df = probe_repeater.freeze_reset();253254let combined = combine(build_join_df, probe_join_df);255if send.send(combined).await.is_err() {256return Ok(());257}258259probe_offset += build_repeats;260}261}262}263Ok(())264}),265);266}267},268CrossJoinState::Done => unreachable!(),269}270}271}272273274