// Path: blob/main/crates/polars-stream/src/nodes/joins/cross_join.rs
// 6939 views
use std::sync::Arc;12use arrow::array::builder::ShareStrategy;3use polars_core::frame::builder::DataFrameBuilder;4use polars_core::schema::Schema;5use polars_ops::frame::{JoinArgs, MaintainOrderJoin};6use polars_utils::format_pl_smallstr;7use polars_utils::pl_str::PlSmallStr;89use crate::morsel::get_ideal_morsel_size;10use crate::nodes::compute_node_prelude::*;11use crate::nodes::in_memory_sink::InMemorySinkNode;1213pub struct CrossJoinNode {14left_is_build: bool,15left_input_schema: Arc<Schema>,16right_input_schema: Arc<Schema>,17right_rename: Vec<Option<PlSmallStr>>,18state: CrossJoinState,19}2021impl CrossJoinNode {22pub fn new(23left_input_schema: Arc<Schema>,24right_input_schema: Arc<Schema>,25args: &JoinArgs,26) -> Self {27let left_is_build = match args.maintain_order {28MaintainOrderJoin::None => true, // TODO: size estimation.29MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => false,30MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => true,31};32let build_input_schema = if left_is_build {33&left_input_schema34} else {35&right_input_schema36};37let sink_node = InMemorySinkNode::new(build_input_schema.clone());38let right_rename = right_input_schema39.iter_names()40.map(|rname| {41if left_input_schema.contains(rname) {42Some(format_pl_smallstr!("{}{}", rname, args.suffix()))43} else {44None45}46})47.collect();4849Self {50left_is_build,51left_input_schema,52right_input_schema,53right_rename,54state: CrossJoinState::Build(sink_node),55}56}57}5859enum CrossJoinState {60Build(InMemorySinkNode),61Probe(DataFrame),62Done,63}6465impl ComputeNode for CrossJoinNode {66fn name(&self) -> &str {67"cross-join"68}6970fn is_memory_intensive_pipeline_blocker(&self) -> bool {71true72}7374fn update_state(75&mut self,76recv: &mut [PortState],77send: &mut [PortState],78_state: &StreamingExecutionState,79) -> PolarsResult<()> {80assert!(recv.len() == 2 && send.len() == 1);8182let build_idx = if self.left_is_build { 0 } else { 1 };83let probe_idx = 1 - 
build_idx;8485// Are we done?86if send[0] == PortState::Done || recv[probe_idx] == PortState::Done {87self.state = CrossJoinState::Done;88}8990// Transition to build?91if recv[build_idx] == PortState::Done {92if let CrossJoinState::Build(sink_node) = &mut self.state {93let df = sink_node.get_output()?.unwrap();94if df.height() > 0 {95self.state = CrossJoinState::Probe(df);96} else {97self.state = CrossJoinState::Done;98}99}100}101102match &self.state {103CrossJoinState::Build(_) => {104recv[build_idx] = PortState::Ready;105recv[probe_idx] = PortState::Blocked;106send[0] = PortState::Blocked;107},108CrossJoinState::Probe(_) => {109recv[build_idx] = PortState::Done;110core::mem::swap(&mut recv[probe_idx], &mut send[0]);111},112CrossJoinState::Done => {113recv[0] = PortState::Done;114recv[1] = PortState::Done;115send[0] = PortState::Done;116},117}118Ok(())119}120121fn spawn<'env, 's>(122&'env mut self,123scope: &'s TaskScope<'s, 'env>,124recv_ports: &mut [Option<RecvPort<'_>>],125send_ports: &mut [Option<SendPort<'_>>],126state: &'s StreamingExecutionState,127join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,128) {129assert!(recv_ports.len() == 2 && send_ports.len() == 1);130let build_idx = if self.left_is_build { 0 } else { 1 };131let probe_idx = 1 - build_idx;132match &mut self.state {133CrossJoinState::Build(sink_node) => {134assert!(send_ports[0].is_none());135assert!(recv_ports[probe_idx].is_none());136sink_node.spawn(137scope,138&mut recv_ports[build_idx..build_idx + 1],139&mut [],140state,141join_handles,142);143},144CrossJoinState::Probe(build_df) => {145assert!(recv_ports[build_idx].is_none());146let receivers = recv_ports[probe_idx].take().unwrap().parallel();147let senders = send_ports[0].take().unwrap().parallel();148let ideal_morsel_size = get_ideal_morsel_size();149150for (mut recv, mut send) in receivers.into_iter().zip(senders) {151let left_is_build = self.left_is_build;152let left_input_schema = self.left_input_schema.clone();153let 
right_input_schema = self.right_input_schema.clone();154let right_rename = &self.right_rename;155let build_df = &*build_df;156join_handles.push(157scope.spawn_task(TaskPriority::High, async move {158let mut build_repeater = DataFrameBuilder::new(left_input_schema);159let mut probe_repeater = DataFrameBuilder::new(right_input_schema);160if !left_is_build {161core::mem::swap(&mut build_repeater, &mut probe_repeater);162}163let mut cached_build_df_repeated = DataFrame::empty();164165while let Ok(morsel) = recv.recv().await {166let combine =167|build_join_df: DataFrame, probe_join_df: DataFrame| unsafe {168let (mut left_join_df, mut right_join_df);169left_join_df = build_join_df;170right_join_df = probe_join_df;171if !left_is_build {172core::mem::swap(&mut left_join_df, &mut right_join_df);173}174175for (col, opt_rename) in right_join_df176.get_columns_mut()177.iter_mut()178.zip(right_rename)179{180if let Some(rename) = opt_rename {181col.rename(rename.clone());182}183}184185left_join_df186.hstack_mut_unchecked(right_join_df.get_columns());187Morsel::new(188left_join_df,189morsel.seq(),190morsel.source_token().clone(),191)192};193194let probe_df = morsel.df();195if build_df.height() >= ideal_morsel_size {196for probe_offset in 0..probe_df.height() {197let mut build_offset = 0;198while build_offset < build_df.height() {199let height = (build_df.height() - build_offset)200.min(ideal_morsel_size);201let build_join_df =202build_df.slice(build_offset as i64, height);203let probe_join_df =204probe_df.new_from_index(probe_offset, height);205let combined = combine(build_join_df, probe_join_df);206if send.send(combined).await.is_err() {207return Ok(());208}209build_offset += height;210}211}212} else {213let max_build_repeats = ideal_morsel_size / build_df.height();214let mut probe_offset = 0;215while probe_offset < probe_df.height() {216let build_repeats = (probe_df.height() - probe_offset)217.min(max_build_repeats);218let build_height = build_repeats * build_df.height();219if 
build_height > cached_build_df_repeated.height() {220build_repeater.subslice_extend_repeated(221build_df,2220,223build_df.height(),224build_repeats,225ShareStrategy::Never,226);227cached_build_df_repeated =228build_repeater.freeze_reset();229}230let build_join_df =231cached_build_df_repeated.slice(0, build_height);232233probe_repeater.subslice_extend_each_repeated(234probe_df,235probe_offset,236build_repeats,237build_df.height(),238ShareStrategy::Always,239);240let probe_join_df = probe_repeater.freeze_reset();241242let combined = combine(build_join_df, probe_join_df);243if send.send(combined).await.is_err() {244return Ok(());245}246247probe_offset += build_repeats;248}249}250}251Ok(())252}),253);254}255},256CrossJoinState::Done => unreachable!(),257}258}259}260261262