// Path: crates/polars-stream/src/nodes/joins/equi_join.rs
// (source mirrored from blob/main; non-code scrape metadata removed)
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use arrow::array::builder::ShareStrategy;
use polars_core::frame::builder::DataFrameBuilder;
use polars_core::prelude::*;
use polars_core::schema::{Schema, SchemaExt};
use polars_core::{POOL, config};
use polars_expr::hash_keys::HashKeys;
use polars_expr::idx_table::{IdxTable, new_idx_table};
use polars_io::pl_async::get_runtime;
use polars_ops::frame::{JoinArgs, JoinBuildSide, JoinType, MaintainOrderJoin};
use polars_ops::series::coalesce_columns;
use polars_utils::cardinality_sketch::CardinalitySketch;
use polars_utils::hashing::HashPartitioner;
use polars_utils::itertools::Itertools;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::priority::Priority;
use polars_utils::relaxed_cell::RelaxedCell;
use polars_utils::sparse_init_vec::SparseInitVec;
use polars_utils::{IdxSize, format_pl_smallstr};
use rayon::prelude::*;

use super::{BufferedStream, JOIN_SAMPLE_LIMIT, LOPSIDED_SAMPLE_FACTOR};
use crate::async_executor;
use crate::async_primitives::wait_group::WaitGroup;
use crate::expression::StreamExpr;
use crate::morsel::{SourceToken, get_ideal_morsel_size};
use crate::nodes::compute_node_prelude::*;
use crate::nodes::in_memory_source::InMemorySourceNode;

/// Configuration shared by all phases of the streaming equi-join.
///
/// `left_is_build` is `None` while we are still sampling both inputs to pick
/// a build side, and `Some(_)` once the choice is fixed.
struct EquiJoinParams {
    left_is_build: Option<bool>,
    preserve_order_build: bool,
    preserve_order_probe: bool,
    left_key_schema: Arc<Schema>,
    left_key_selectors: Vec<StreamExpr>,
    #[allow(dead_code)]
    right_key_schema: Arc<Schema>,
    right_key_selectors: Vec<StreamExpr>,
    // Per input column: `Some(output_name)` to keep it in the payload, `None` to drop it.
    left_payload_select: Vec<Option<PlSmallStr>>,
    right_payload_select: Vec<Option<PlSmallStr>>,
    left_payload_schema: Arc<Schema>,
    right_payload_schema: Arc<Schema>,
    args: JoinArgs,
    random_state: PlRandomState,
}

impl EquiJoinParams {
    /// Should we emit unmatched rows from the build side?
    fn emit_unmatched_build(&self) -> bool {
        if self.left_is_build.unwrap() {
            self.args.how == JoinType::Left || self.args.how == JoinType::Full
        } else {
            self.args.how == JoinType::Right || self.args.how == JoinType::Full
        }
    }

    /// Should we emit unmatched rows from the probe side?
    fn emit_unmatched_probe(&self) -> bool {
        if self.left_is_build.unwrap() {
            self.args.how == JoinType::Right || self.args.how == JoinType::Full
        } else {
            self.args.how == JoinType::Left || self.args.how == JoinType::Full
        }
    }
}

/// A payload selector contains for each column whether that column should be
/// included in the payload, and if yes with what name.
fn compute_payload_selector(
    this: &Schema,
    other: &Schema,
    this_key_schema: &Schema,
    other_key_schema: &Schema,
    is_left: bool,
    args: &JoinArgs,
) -> PolarsResult<Vec<Option<PlSmallStr>>> {
    let should_coalesce = args.should_coalesce();

    this.iter_names()
        .map(|c| {
            // The loop exists only so `break` can jump to the suffixing
            // fallthrough below; it never iterates more than once.
            #[expect(clippy::never_loop)]
            loop {
                let selector = if args.how == JoinType::Right {
                    if is_left {
                        if should_coalesce && this_key_schema.contains(c) {
                            // Coalesced to RHS output key.
                            None
                        } else {
                            Some(c.clone())
                        }
                    } else if !other.contains(c) || (should_coalesce && other_key_schema.contains(c)) {
                        Some(c.clone())
                    } else {
                        break;
                    }
                } else if should_coalesce && this_key_schema.contains(c) {
                    if is_left {
                        Some(c.clone())
                    } else if args.how == JoinType::Full {
                        // We must keep the right-hand side keycols around for
                        // coalescing.
                        let key_idx = this_key_schema.index_of(c).unwrap();
                        let name = format_pl_smallstr!("__POLARS_COALESCE_KEYCOL_{key_idx}");
                        Some(name)
                    } else {
                        None
                    }
                } else if !other.contains(c) || is_left {
                    Some(c.clone())
                } else {
                    break;
                };

                return Ok(selector);
            }

            // Name collision with the other side: apply the join suffix, and
            // error out if even the suffixed name already exists.
            let suffixed = format_pl_smallstr!("{}{}", c, args.suffix());
            if other.contains(&suffixed) {
                polars_bail!(Duplicate: "column with name '{suffixed}' already exists\n\n\
                You may want to try:\n\
                - renaming the column prior to joining\n\
                - using the `suffix` parameter to specify a suffix different to the default one ('_right')")
            }

            Ok(Some(suffixed))
        })
        .collect()
}

/// Fixes names and does coalescing of columns post-join.
fn postprocess_join(df: DataFrame, params: &EquiJoinParams) -> DataFrame {
    if params.args.how == JoinType::Full && params.args.should_coalesce() {
        // TODO: don't do string-based column lookups for each dataframe, pre-compute coalesce indices.
        let new_cols = df
            .columns()
            .iter()
            .filter_map(|c| {
                if let Some(key_idx) = params.left_key_schema.index_of(c.name()) {
                    let other = df
                        .column(&format_pl_smallstr!("__POLARS_COALESCE_KEYCOL_{key_idx}"))
                        .unwrap();
                    return Some(coalesce_columns(&[c.clone(), other.clone()]).unwrap());
                }

                // Temporary RHS keycols were only kept for the coalesce above.
                if c.name().starts_with("__POLARS_COALESCE_KEYCOL") {
                    return None;
                }

                Some(c.clone())
            })
            .collect();

        unsafe { DataFrame::new_unchecked(df.height(), new_cols) }
    } else {
        df
    }
}

/// Applies a payload selector to a schema: keeps the selected fields, renamed
/// to their selected output names.
fn select_schema(schema: &Schema, selector: &[Option<PlSmallStr>]) -> Schema {
    schema
        .iter_fields()
        .zip(selector)
        .filter_map(|(f, name)| Some(f.with_name(name.clone()?)))
        .collect()
}

/// Evaluates the key selectors against `df` and hashes the resulting key
/// columns into `HashKeys`.
async fn select_keys(
    df: &DataFrame,
    key_selectors: &[StreamExpr],
    params: &EquiJoinParams,
    state: &ExecutionState,
) -> PolarsResult<HashKeys> {
    let mut key_columns = Vec::new();
    for selector in key_selectors {
        key_columns.push(selector.evaluate(df, state).await?.into_column());
    }
    let keys = unsafe { DataFrame::new_unchecked_with_broadcast(df.height(), key_columns)? };
    Ok(HashKeys::from_df(
        &keys,
        params.random_state.clone(),
        params.args.nulls_equal,
        false,
    ))
}

/// Applies a payload selector to a dataframe: keeps the selected columns,
/// renamed to their selected output names.
fn select_payload(df: DataFrame, selector: &[Option<PlSmallStr>]) -> DataFrame {
    let height = df.height();
    let new_cols = df
        .into_columns()
        .into_iter()
        .zip(selector)
        .filter_map(|(c, name)| Some(c.with_name(name.clone()?)))
        .collect();

    unsafe { DataFrame::new_unchecked(height, new_cols) }
}

/// Estimates the average number of distinct keys per row over (at most) the
/// first `JOIN_SAMPLE_LIMIT` rows of `morsels`, using cardinality sketches
/// computed in parallel on the rayon pool.
fn estimate_cardinality(
    morsels: &[Morsel],
    key_selectors: &[StreamExpr],
    params: &EquiJoinParams,
    state: &ExecutionState,
) -> PolarsResult<f64> {
    let sample_limit = *JOIN_SAMPLE_LIMIT;
    if morsels.is_empty() || sample_limit == 0 {
        return Ok(0.0);
    }

    // Find how many morsels we need to reach the sample limit, and how much
    // of the last morsel should be sliced off to not exceed it.
    let mut total_height = 0;
    let mut to_process_end = 0;
    while to_process_end < morsels.len() && total_height < sample_limit {
        total_height += morsels[to_process_end].df().height();
        to_process_end += 1;
    }
    let last_morsel_idx = to_process_end - 1;
    let last_morsel_len = morsels[last_morsel_idx].df().height();
    let last_morsel_slice = last_morsel_len - total_height.saturating_sub(sample_limit);
    let runtime = get_runtime();

    POOL.install(|| {
        let sample_cardinality = morsels[..to_process_end]
            .par_iter()
            .enumerate()
            .try_fold(
                CardinalitySketch::new,
                |mut sketch, (morsel_idx, morsel)| {
                    let sliced;
                    let df = if morsel_idx == last_morsel_idx {
                        sliced = morsel.df().slice(0, last_morsel_slice);
                        &sliced
                    } else {
                        morsel.df()
                    };
                    let hash_keys =
                        runtime.block_on(select_keys(df, key_selectors, params, state))?;
                    hash_keys.sketch_cardinality(&mut sketch);
                    PolarsResult::Ok(sketch)
                },
            )
            .map(|sketch| PolarsResult::Ok(sketch?.estimate()))
            .try_reduce_with(|a, b| Ok(a + b))
            .unwrap()?;
        Ok(sample_cardinality as f64 / total_height.min(sample_limit) as f64)
    })
}

/// Buffered morsels from both inputs, collected while we decide which side
/// should be the build side.
#[derive(Default)]
struct SampleState {
    left: Vec<Morsel>,
    left_len: usize,
    right: Vec<Morsel>,
    right_len: usize,
}

impl SampleState {
    /// Buffers morsels from one input, asking the source to stop once this
    /// side is saturated or lopsidedly larger than the other side.
    async fn sink(
        mut recv: PortReceiver,
        morsels: &mut Vec<Morsel>,
        len: &mut usize,
        this_final_len: Arc<RelaxedCell<usize>>,
        other_final_len: Arc<RelaxedCell<usize>>,
    ) -> PolarsResult<()> {
        while let Ok(mut morsel) = recv.recv().await {
            *len += morsel.df().height();
            if *len >= *JOIN_SAMPLE_LIMIT
                || *len
                    >= other_final_len
                        .load()
                        .saturating_mul(LOPSIDED_SAMPLE_FACTOR)
            {
                morsel.source_token().stop();
            }

            drop(morsel.take_consume_token());
            morsels.push(morsel);
        }
        this_final_len.store(*len);
        Ok(())
    }

    /// Once sampling can stop, picks the build side and converts the sampled
    /// morsels into an initialized `BuildState`. Returns `None` while more
    /// sampling is still needed.
    fn try_transition_to_build(
        &mut self,
        recv: &[PortState],
        params: &mut EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<Option<BuildState>> {
        let left_saturated = self.left_len >= *JOIN_SAMPLE_LIMIT;
        let right_saturated = self.right_len >= *JOIN_SAMPLE_LIMIT;
        let left_done = recv[0] == PortState::Done || left_saturated;
        let right_done = recv[1] == PortState::Done || right_saturated;
        #[expect(clippy::nonminimal_bool)]
        let stop_sampling = (left_done && right_done)
            || (left_done && self.right_len >= LOPSIDED_SAMPLE_FACTOR * self.left_len)
            || (right_done && self.left_len >= LOPSIDED_SAMPLE_FACTOR * self.right_len);
        if !stop_sampling {
            return Ok(None);
        }

        if config::verbose() {
            eprintln!(
                "choosing build side, sample lengths are: {} vs. {}",
                self.left_len, self.right_len
            );
        }

        let estimate_cardinalities = || {
            let left_cardinality = estimate_cardinality(
                &self.left,
                &params.left_key_selectors,
                params,
                &state.in_memory_exec_state,
            )?;
            let right_cardinality = estimate_cardinality(
                &self.right,
                &params.right_key_selectors,
                params,
                &state.in_memory_exec_state,
            )?;
            if config::verbose() {
                eprintln!(
                    "estimated cardinalities are: {left_cardinality} vs. {right_cardinality}"
                );
            }
            PolarsResult::Ok((left_cardinality, right_cardinality))
        };

        let left_is_build = match (left_saturated, right_saturated) {
            // Don't bother estimating cardinality, just choose smaller side as
            // we have everything in-memory anyway.
            (false, false) => self.left_len < self.right_len,

            // Choose the unsaturated side, the saturated side could be
            // arbitrarily big.
            (false, true) => true,
            (true, false) => false,

            (true, true) => {
                match params.args.build_side {
                    Some(JoinBuildSide::PreferLeft) => true,
                    Some(JoinBuildSide::PreferRight) => false,
                    // Forced build sides are resolved in EquiJoinNode::new and
                    // never reach the sampling state.
                    Some(JoinBuildSide::ForceLeft | JoinBuildSide::ForceRight) => unreachable!(),
                    None => {
                        // Estimate cardinality and choose smaller.
                        let (lc, rc) = estimate_cardinalities()?;
                        lc < rc
                    },
                }
            },
        };

        if config::verbose() {
            eprintln!(
                "build side chosen: {}",
                if left_is_build { "left" } else { "right" }
            );
        }

        // Transition to building state.
        params.left_is_build = Some(left_is_build);
        let mut sampled_build_morsels =
            BufferedStream::new(core::mem::take(&mut self.left), MorselSeq::default());
        let mut sampled_probe_morsels =
            BufferedStream::new(core::mem::take(&mut self.right), MorselSeq::default());
        if !left_is_build {
            core::mem::swap(&mut sampled_build_morsels, &mut sampled_probe_morsels);
        }

        let partitioner = HashPartitioner::new(state.num_pipelines, 0);
        let mut build_state = BuildState::new(
            state.num_pipelines,
            state.num_pipelines,
            sampled_probe_morsels,
        );

        // Simulate the sample build morsels flowing into the build side.
        if !sampled_build_morsels.is_empty() {
            crate::async_executor::task_scope(|scope| {
                let mut join_handles = Vec::new();
                let receivers = sampled_build_morsels
                    .reinsert(state.num_pipelines, None, scope, &mut join_handles)
                    .unwrap();

                for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        BuildState::partition_and_sink(
                            recv,
                            local_builder,
                            partitioner.clone(),
                            params,
                            state,
                        ),
                    ));
                }

                polars_io::pl_async::get_runtime().block_on(async move {
                    for handle in join_handles {
                        handle.await?;
                    }
                    PolarsResult::Ok(())
                })
            })?;
        }

        Ok(Some(build_state))
    }
}

/// Per-pipeline accumulator for the build phase.
#[derive(Default)]
struct LocalBuilder {
    // The complete list of morsels and their computed hashes seen by this builder.
    morsels: Vec<(MorselSeq, DataFrame, HashKeys)>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // morsel_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i]
    // for partition p, where start, stop are:
    // let start = morsel_idxs_offsets[i * num_partitions + p];
    // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p];
    morsel_idxs_values_per_p: Vec<Vec<IdxSize>>,
    morsel_idxs_offsets_per_p: Vec<usize>,
}

/// State of the build phase: one `LocalBuilder` per pipeline, plus any probe
/// morsels that were already buffered during sampling.
struct BuildState {
    local_builders: Vec<LocalBuilder>,
    sampled_probe_morsels: BufferedStream,
}

impl BuildState {
    fn new(
        num_pipelines: usize,
        num_partitions: usize,
        sampled_probe_morsels: BufferedStream,
    ) -> Self {
        let local_builders = (0..num_pipelines)
            .map(|_| LocalBuilder {
                morsels: Vec::new(),
                sketch_per_p: vec![CardinalitySketch::default(); num_partitions],
                morsel_idxs_values_per_p: vec![Vec::new(); num_partitions],
                morsel_idxs_offsets_per_p: vec![0; num_partitions],
            })
            .collect();
        Self {
            local_builders,
            sampled_probe_morsels,
        }
    }

    /// Consumes build-side morsels from `recv`, hashing their keys and
    /// recording per-partition row indices into `local`.
    async fn partition_and_sink(
        mut recv: PortReceiver,
        local: &mut LocalBuilder,
        partitioner: HashPartitioner,
        params: &EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        let track_unmatchable = params.emit_unmatched_build();
        let (key_selectors, payload_selector);
        if params.left_is_build.unwrap() {
            payload_selector = &params.left_payload_select;
            key_selectors = &params.left_key_selectors;
        } else {
            payload_selector = &params.right_payload_select;
            key_selectors = &params.right_key_selectors;
        };

        while let Ok(morsel) = recv.recv().await {
            // Compute hashed keys and payload. We must rechunk the payload for
            // later gathers.
            let hash_keys = select_keys(
                morsel.df(),
                key_selectors,
                params,
                &state.in_memory_exec_state,
            )
            .await?;
            let mut payload = select_payload(morsel.df().clone(), payload_selector);
            payload.rechunk_mut();

            hash_keys.gen_idxs_per_partition(
                &partitioner,
                &mut local.morsel_idxs_values_per_p,
                &mut local.sketch_per_p,
                track_unmatchable,
            );

            local
                .morsel_idxs_offsets_per_p
                .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len()));
            local.morsels.push((morsel.seq(), payload, hash_keys));
        }
        Ok(())
    }

    /// Builds the per-partition hash tables while preserving the original
    /// morsel order, linearizing the per-pipeline morsel lists with a k-way
    /// merge on morsel sequence numbers.
    fn finalize_ordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState {
        let track_unmatchable = params.emit_unmatched_build();
        let payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        let num_partitions = self.local_builders[0].sketch_per_p.len();
        let local_builders = &self.local_builders;
        let probe_tables: SparseInitVec<ProbeTable> = SparseInitVec::with_capacity(num_partitions);

        POOL.scope(|s| {
            for p in 0..num_partitions {
                let probe_tables = &probe_tables;
                s.spawn(move |_| {
                    // TODO: every thread does an identical linearize, we can do a single parallel one.
                    let mut kmerge = BinaryHeap::with_capacity(local_builders.len());
                    let mut cur_idx_per_loc = vec![0; local_builders.len()];

                    // Compute cardinality estimate and total amount of
                    // payload for this partition, and initialize k-way merge.
                    let mut sketch = CardinalitySketch::new();
                    let mut payload_rows = 0;
                    for (l_idx, l) in local_builders.iter().enumerate() {
                        let Some((seq, _, _)) = l.morsels.first() else {
                            continue;
                        };
                        kmerge.push(Priority(Reverse(seq), l_idx));

                        sketch.combine(&l.sketch_per_p[p]);
                        let offsets_len = l.morsel_idxs_offsets_per_p.len();
                        payload_rows +=
                            l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p];
                    }

                    // Allocate hash table and payload builder.
                    let mut p_table = table.new_empty();
                    p_table.reserve(sketch.estimate() * 5 / 4);
                    let mut p_payload = DataFrameBuilder::new(payload_schema.clone());
                    p_payload.reserve(payload_rows);

                    let mut p_seq_ids = Vec::new();
                    if track_unmatchable {
                        p_seq_ids.reserve(payload_rows);
                    }

                    // Linearize and build.
                    unsafe {
                        let mut norm_seq_id = 0 as IdxSize;
                        while let Some(Priority(Reverse(_seq), l_idx)) = kmerge.pop() {
                            let l = local_builders.get_unchecked(l_idx);
                            let idx_in_l = *cur_idx_per_loc.get_unchecked(l_idx);
                            *cur_idx_per_loc.get_unchecked_mut(l_idx) += 1;
                            if let Some((next_seq, _, _)) = l.morsels.get(idx_in_l + 1) {
                                kmerge.push(Priority(Reverse(next_seq), l_idx));
                            }

                            let (_mseq, payload, keys) = l.morsels.get_unchecked(idx_in_l);
                            let p_morsel_idxs_start =
                                l.morsel_idxs_offsets_per_p[idx_in_l * num_partitions + p];
                            let p_morsel_idxs_stop =
                                l.morsel_idxs_offsets_per_p[(idx_in_l + 1) * num_partitions + p];
                            let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                [p_morsel_idxs_start..p_morsel_idxs_stop];
                            p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable);
                            p_payload.gather_extend(payload, p_morsel_idxs, ShareStrategy::Never);

                            if track_unmatchable {
                                p_seq_ids.resize(p_payload.len(), norm_seq_id);
                                norm_seq_id += 1;
                            }
                        }
                    }

                    probe_tables
                        .try_set(
                            p,
                            ProbeTable {
                                hash_table: p_table,
                                payload: p_payload.freeze(),
                                seq_ids: p_seq_ids,
                            },
                        )
                        .ok()
                        .unwrap();
                });
            }
        });

        ProbeState {
            table_per_partition: probe_tables.try_assume_init().ok().unwrap(),
            max_seq_sent: MorselSeq::default(),
            sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),
            unordered_morsel_seq: AtomicU64::new(0),
        }
    }

    /// Builds the per-partition hash tables without order guarantees,
    /// dropping consumed morsels eagerly (via a shared drop queue) to reduce
    /// peak memory usage.
    fn finalize_unordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState {
        let track_unmatchable = params.emit_unmatched_build();
        let payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        // To reduce maximum memory usage we want to drop the morsels
        // as soon as they're processed, so we move into Arcs. The drops might
        // also be expensive, so instead of directly dropping we put that on
        // a work queue.
        let morsels_per_local_builder = self
            .local_builders
            .iter_mut()
            .map(|b| Arc::new(core::mem::take(&mut b.morsels)))
            .collect_vec();
        let (morsel_drop_q_send, morsel_drop_q_recv) =
            async_channel::bounded(morsels_per_local_builder.len());
        let num_partitions = self.local_builders[0].sketch_per_p.len();
        let local_builders = &self.local_builders;
        let probe_tables: SparseInitVec<ProbeTable> = SparseInitVec::with_capacity(num_partitions);

        async_executor::task_scope(|s| {
            // Wrap in outer Arc to move to each thread, performing the
            // expensive clone on that thread.
            let arc_morsels_per_local_builder = Arc::new(morsels_per_local_builder);
            let mut join_handles = Vec::new();
            for p in 0..num_partitions {
                let arc_morsels_per_local_builder = Arc::clone(&arc_morsels_per_local_builder);
                let morsel_drop_q_send = morsel_drop_q_send.clone();
                let morsel_drop_q_recv = morsel_drop_q_recv.clone();
                let probe_tables = &probe_tables;
                join_handles.push(s.spawn_task(TaskPriority::High, async move {
                    // Extract from outer arc and drop outer arc.
                    let morsels_per_local_builder =
                        Arc::unwrap_or_clone(arc_morsels_per_local_builder);

                    // Compute cardinality estimate and total amount of
                    // payload for this partition.
                    let mut sketch = CardinalitySketch::new();
                    let mut payload_rows = 0;
                    for l in local_builders {
                        sketch.combine(&l.sketch_per_p[p]);
                        let offsets_len = l.morsel_idxs_offsets_per_p.len();
                        payload_rows +=
                            l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p];
                    }

                    // Allocate hash table and payload builder.
                    let mut p_table = table.new_empty();
                    p_table.reserve(sketch.estimate() * 5 / 4);
                    let mut p_payload = DataFrameBuilder::new(payload_schema.clone());
                    p_payload.reserve(payload_rows);

                    // Build.
                    let mut skip_drop_attempt = false;
                    for (l, l_morsels) in local_builders.iter().zip(morsels_per_local_builder) {
                        // Try to help with dropping the processed morsels.
                        if !skip_drop_attempt {
                            drop(morsel_drop_q_recv.try_recv());
                        }

                        for (i, morsel) in l_morsels.iter().enumerate() {
                            let (_mseq, payload, keys) = morsel;
                            unsafe {
                                let p_morsel_idxs_start =
                                    l.morsel_idxs_offsets_per_p[i * num_partitions + p];
                                let p_morsel_idxs_stop =
                                    l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                    [p_morsel_idxs_start..p_morsel_idxs_stop];
                                p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable);
                                p_payload.gather_extend(
                                    payload,
                                    p_morsel_idxs,
                                    ShareStrategy::Never,
                                );
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_morsels) {
                            // If we're the last thread to process this set of morsels we're probably
                            // falling behind the rest, since the drop can be quite expensive we skip
                            // a drop attempt hoping someone else will pick up the slack.
                            drop(morsel_drop_q_send.try_send(l));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // We're done, help others out by doing drops.
                    drop(morsel_drop_q_send); // So we don't deadlock trying to receive from ourselves.
                    while let Ok(l_morsels) = morsel_drop_q_recv.recv().await {
                        drop(l_morsels);
                    }

                    probe_tables
                        .try_set(
                            p,
                            ProbeTable {
                                hash_table: p_table,
                                payload: p_payload.freeze(),
                                seq_ids: Vec::new(),
                            },
                        )
                        .ok()
                        .unwrap();
                }));
            }

            // Drop outer arc after spawning each thread so the inner arcs
            // can get dropped as soon as they're processed. We also have to
            // drop the drop queue sender so we don't deadlock waiting for it
            // to end.
            drop(arc_morsels_per_local_builder);
            drop(morsel_drop_q_send);

            polars_io::pl_async::get_runtime().block_on(async move {
                for handle in join_handles {
                    handle.await;
                }
            });
        });

        ProbeState {
            table_per_partition: probe_tables.try_assume_init().ok().unwrap(),
            max_seq_sent: MorselSeq::default(),
            sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),
            unordered_morsel_seq: AtomicU64::new(0),
        }
    }
}

/// One partition of the finalized build side: the hash table plus the gathered
/// payload rows (and, for ordered unmatched emission, their sequence ids).
struct ProbeTable {
    hash_table: Box<dyn IdxTable>,
    payload: DataFrame,
    seq_ids: Vec<IdxSize>,
}

struct ProbeState {
    table_per_partition: Vec<ProbeTable>,
    max_seq_sent: MorselSeq,
    sampled_probe_morsels: BufferedStream,

    // For unordered joins we relabel output morsels to speed up the linearizer.
    unordered_morsel_seq: AtomicU64,
}

impl ProbeState {
    /// Returns the max morsel sequence sent.
    async fn partition_and_probe(
        mut recv: PortReceiver,
        mut send: PortSender,
        partitions: &[ProbeTable],
        unordered_morsel_seq: &AtomicU64,
        partitioner: HashPartitioner,
        params: &EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<MorselSeq> {
        // TODO: shuffle after partitioning and keep probe tables thread-local.
        let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()];
        let mut probe_partitions = Vec::new();
        let mut materialized_idxsize_range = Vec::new();
        let mut table_match = Vec::new();
        let mut probe_match = Vec::new();
        let mut max_seq = MorselSeq::default();

        let probe_limit = get_ideal_morsel_size() as IdxSize;
        let mark_matches = params.emit_unmatched_build();
        let emit_unmatched = params.emit_unmatched_probe();

        let (key_selectors, payload_selector, build_payload_schema, probe_payload_schema);
        if params.left_is_build.unwrap() {
            key_selectors = &params.right_key_selectors;
            payload_selector = &params.right_payload_select;
            build_payload_schema = &params.left_payload_schema;
            probe_payload_schema = &params.right_payload_schema;
        } else {
            key_selectors = &params.left_key_selectors;
            payload_selector = &params.left_payload_select;
            build_payload_schema = &params.right_payload_schema;
            probe_payload_schema = &params.left_payload_schema;
        };

        let mut build_out = DataFrameBuilder::new(build_payload_schema.clone());
        let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone());

        // A simple estimate used to size reserves.
        let mut selectivity_estimate = 1.0;
        let mut selectivity_estimate_confidence = 0.0;

        while let Ok(morsel) = recv.recv().await {
            // Compute hashed keys and payload.
            let (df, in_seq, src_token, wait_token) = morsel.into_inner();

            let df_height = df.height();
            if df_height == 0 {
                continue;
            }

            let hash_keys =
                select_keys(&df, key_selectors, params, &state.in_memory_exec_state).await?;
            let mut payload = select_payload(df, payload_selector);
            let mut payload_rechunked = false; // We don't eagerly rechunk because there might be no matches.
            let mut total_matches = 0;

            // Use selectivity estimate to reserve for morsel builders.
            let max_match_per_key_est = (selectivity_estimate * 1.2) as usize + 16;
            let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize)
                .min(probe_limit as usize);
            build_out.reserve(out_est_size + max_match_per_key_est);

            unsafe {
                // Freezes the two builders into one output morsel, stitching
                // build/probe payloads in left/right order.
                let mut new_morsel =
                    |build: &mut DataFrameBuilder, probe: &mut DataFrameBuilder| {
                        let mut build_df = build.freeze_reset();
                        let mut probe_df = probe.freeze_reset();
                        let out_df = if params.left_is_build.unwrap() {
                            build_df.hstack_mut_unchecked(probe_df.columns());
                            build_df
                        } else {
                            probe_df.hstack_mut_unchecked(build_df.columns());
                            probe_df
                        };
                        let out_df = postprocess_join(out_df, params);
                        let out_seq = if params.preserve_order_probe {
                            in_seq
                        } else {
                            MorselSeq::new(unordered_morsel_seq.fetch_add(1, Ordering::Relaxed))
                        };
                        max_seq = out_seq;
                        Morsel::new(out_df, out_seq, src_token.clone())
                    };

                if params.preserve_order_probe {
                    // To preserve the order we can't do bulk probes per partition and must follow
                    // the order of the probe morsel. We can still group probes that are
                    // consecutively on the same partition.
                    probe_partitions.clear();
                    hash_keys.gen_partitions(&partitioner, &mut probe_partitions, emit_unmatched);

                    let mut probe_group_start = 0;
                    while probe_group_start < probe_partitions.len() {
                        let p_idx = probe_partitions[probe_group_start];
                        let mut probe_group_end = probe_group_start + 1;
                        while probe_partitions.get(probe_group_end) == Some(&p_idx) {
                            probe_group_end += 1;
                        }
                        let Some(p) = partitions.get(p_idx as usize) else {
                            probe_group_start = probe_group_end;
                            continue;
                        };

                        materialized_idxsize_range.extend(
                            materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize,
                        );

                        while probe_group_start < probe_group_end {
                            let matches_before_limit = probe_limit - probe_match.len() as IdxSize;
                            table_match.clear();
                            probe_group_start += p.hash_table.probe_subset(
                                &hash_keys,
                                &materialized_idxsize_range[probe_group_start..probe_group_end],
                                &mut table_match,
                                &mut probe_match,
                                mark_matches,
                                emit_unmatched,
                                matches_before_limit,
                            ) as usize;

                            if emit_unmatched {
                                build_out.opt_gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            } else {
                                build_out.gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            };

                            if probe_match.len() >= probe_limit as usize
                                || probe_group_start == probe_partitions.len()
                            {
                                if !payload_rechunked {
                                    payload.rechunk_mut();
                                    payload_rechunked = true;
                                }
                                probe_out.gather_extend(
                                    &payload,
                                    &probe_match,
                                    ShareStrategy::Always,
                                );
                                let out_len = probe_match.len();
                                probe_match.clear();
                                let out_morsel = new_morsel(&mut build_out, &mut probe_out);
                                if send.send(out_morsel).await.is_err() {
                                    return Ok(max_seq);
                                }
                                if probe_group_end != probe_partitions.len() {
                                    // We had enough matches to need a mid-partition flush, let's assume there are a lot of
                                    // matches and just do a large reserve.
                                    let old_est = probe_limit as usize + max_match_per_key_est;
                                    build_out.reserve(old_est.max(out_len + 16));
                                }
                            }
                        }
                    }
                } else {
                    // Partition and probe the tables.
                    for p in partition_idxs.iter_mut() {
                        p.clear();
                    }
                    hash_keys.gen_idxs_per_partition(
                        &partitioner,
                        &mut partition_idxs,
                        &mut [],
                        emit_unmatched,
                    );

                    for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) {
                        let mut offset = 0;
                        while offset < idxs_in_p.len() {
                            let matches_before_limit = probe_limit - probe_match.len() as IdxSize;
                            table_match.clear();
                            offset += p.hash_table.probe_subset(
                                &hash_keys,
                                &idxs_in_p[offset..],
                                &mut table_match,
                                &mut probe_match,
                                mark_matches,
                                emit_unmatched,
                                matches_before_limit,
                            ) as usize;

                            if table_match.is_empty() {
                                continue;
                            }
                            total_matches += table_match.len();

                            if emit_unmatched {
                                build_out.opt_gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            } else {
                                build_out.gather_extend(
                                    &p.payload,
                                    &table_match,
                                    ShareStrategy::Always,
                                );
                            };

                            if probe_match.len() >= probe_limit as usize {
                                if !payload_rechunked {
                                    payload.rechunk_mut();
                                    payload_rechunked = true;
                                }
                                probe_out.gather_extend(
                                    &payload,
                                    &probe_match,
                                    ShareStrategy::Always,
                                );
                                let out_len = probe_match.len();
                                probe_match.clear();
                                let out_morsel = new_morsel(&mut build_out, &mut probe_out);
                                if send.send(out_morsel).await.is_err() {
                                    return Ok(max_seq);
                                }
                                // We had enough matches to need a mid-partition flush, let's assume there are a lot of
                                // matches and just do a large reserve.
                                let old_est = probe_limit as usize + max_match_per_key_est;
                                build_out.reserve(old_est.max(out_len + 16));
                            }
                        }
                    }
                }

                // Flush any remaining partial output for this input morsel.
                if !probe_match.is_empty() {
                    if !payload_rechunked {
                        payload.rechunk_mut();
                    }
                    probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always);
                    probe_match.clear();
                    let out_morsel = new_morsel(&mut build_out, &mut probe_out);
                    if send.send(out_morsel).await.is_err() {
                        return Ok(max_seq);
                    }
                }
            }

            drop(wait_token);

            // Move selectivity estimate a bit towards latest value. Allows rapid changes at first.
            // TODO: implement something more re-usable and robust.
            selectivity_estimate = selectivity_estimate_confidence * selectivity_estimate
                + (1.0 - selectivity_estimate_confidence)
                    * (total_matches as f64 / df_height as f64);
            selectivity_estimate_confidence = (selectivity_estimate_confidence + 0.1).min(0.8);
        }

        Ok(max_seq)
    }

    /// Gathers all unmatched build-side rows, ordered by their original
    /// sequence ids, into a single dataframe padded with nulls on the probe
    /// side.
    fn ordered_unmatched(&mut self, params: &EquiJoinParams) -> DataFrame {
        // TODO: parallelize this operator.

        let build_payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        let mut unmarked_idxs = Vec::new();
        let mut linearized_idxs = Vec::new();

        for (p_idx, p) in self.table_per_partition.iter().enumerate_idx() {
            p.hash_table
                .unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX);
            linearized_idxs.extend(
                unmarked_idxs
                    .iter()
                    .map(|i| (unsafe { *p.seq_ids.get_unchecked(*i as usize) }, p_idx, *i)),
            );
        }

        linearized_idxs.sort_by_key(|(seq_id, _, _)| *seq_id);

        unsafe {
            let mut build_out = DataFrameBuilder::new(build_payload_schema.clone());
            build_out.reserve(linearized_idxs.len());

            // Group indices from the same partition.
            let mut group_start = 0;
            let mut gather_idxs = Vec::new();
            while group_start < linearized_idxs.len() {
                gather_idxs.clear();

                let (_seq, p_idx, idx_in_p) = linearized_idxs[group_start];
                gather_idxs.push(idx_in_p);
                let mut group_end = group_start + 1;
                while group_end < linearized_idxs.len() && linearized_idxs[group_end].1 == p_idx {
                    gather_idxs.push(linearized_idxs[group_end].2);
                    group_end += 1;
                }

                build_out.gather_extend(
                    &self.table_per_partition[p_idx as usize].payload,
                    &gather_idxs,
                    ShareStrategy::Never, // Don't keep entire table alive for unmatched indices.
                );

                group_start = group_end;
            }

            let mut build_df = build_out.freeze();
            let out_df = if params.left_is_build.unwrap() {
                let probe_df =
                    DataFrame::full_null(&params.right_payload_schema, build_df.height());
                build_df.hstack_mut_unchecked(probe_df.columns());
                build_df
            } else {
                let mut probe_df =
                    DataFrame::full_null(&params.left_payload_schema, build_df.height());
                probe_df.hstack_mut_unchecked(build_df.columns());
                probe_df
            };
            postprocess_join(out_df, params)
        }
    }
}

impl Drop for ProbeState {
    fn drop(&mut self) {
        POOL.install(|| {
            // Parallel drop as the state might be quite big.
            self.table_per_partition.par_drain(..).for_each(drop);
        })
    }
}

/// Cursor over the finalized partitions used to stream out unmatched
/// build-side rows after probing is done.
struct EmitUnmatchedState {
    partitions: Vec<ProbeTable>,
    active_partition_idx: usize,
    offset_in_active_p: usize,
    morsel_seq: MorselSeq,
}

impl EmitUnmatchedState {
    /// Streams unmatched build-side rows (null-padded on the probe side) in
    /// ideally-sized morsels until all partitions are exhausted, the receiver
    /// disconnects, or the source token requests a stop.
    async fn emit_unmatched(
        &mut self,
        mut send: PortSender,
        params: &EquiJoinParams,
        num_pipelines: usize,
    ) -> PolarsResult<()> {
        let total_len: usize = self
            .partitions
            .iter()
            .map(|p| p.hash_table.num_keys() as usize)
            .sum();
        let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1);
        let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines);
        let morsel_size = total_len.div_ceil(morsel_count).max(1);

        let wait_group = WaitGroup::default();
        let source_token = SourceToken::new();
        let mut unmarked_idxs = Vec::new();
        while let Some(p) = self.partitions.get(self.active_partition_idx) {
            loop {
                // Generate a chunk of unmarked key indices.
                self.offset_in_active_p += p.hash_table.unmarked_keys(
                    &mut unmarked_idxs,
                    self.offset_in_active_p as IdxSize,
                    morsel_size as IdxSize,
                ) as usize;
                if unmarked_idxs.is_empty() {
                    break;
                }

                // Gather and create full-null counterpart.
                let out_df = unsafe {
                    let mut build_df = p.payload.take_slice_unchecked_impl(&unmarked_idxs, false);
                    let len = build_df.height();
                    if params.left_is_build.unwrap() {
                        let probe_df = DataFrame::full_null(&params.right_payload_schema, len);
                        build_df.hstack_mut_unchecked(probe_df.columns());
                        build_df
                    } else {
                        let mut probe_df = DataFrame::full_null(&params.left_payload_schema, len);
                        probe_df.hstack_mut_unchecked(build_df.columns());
                        probe_df
                    }
                };
                let out_df = postprocess_join(out_df, params);

                // Send and wait until consume token is consumed.
                let mut morsel = Morsel::new(out_df, self.morsel_seq, source_token.clone());
                self.morsel_seq = self.morsel_seq.successor();
                morsel.set_consume_token(wait_group.token());
                if send.send(morsel).await.is_err() {
                    return Ok(());
                }

                wait_group.wait().await;
                if source_token.stop_requested() {
                    return Ok(());
                }
            }

            self.active_partition_idx += 1;
            self.offset_in_active_p = 0;
        }

        Ok(())
    }
}

/// The phase the join node is currently in.
enum EquiJoinState {
    Sample(SampleState),
    Build(BuildState),
    Probe(ProbeState),
    EmitUnmatchedBuild(EmitUnmatchedState),
    EmitUnmatchedBuildInOrder(InMemorySourceNode),
    Done,
}

pub struct EquiJoinNode {
    state: EquiJoinState,
    params: EquiJoinParams,
    table: Box<dyn IdxTable>,
}

impl EquiJoinNode {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        left_input_schema: Arc<Schema>,
        right_input_schema: Arc<Schema>,
        left_key_schema: Arc<Schema>,
        right_key_schema: Arc<Schema>,
        unique_key_schema: Arc<Schema>,
        left_key_selectors: Vec<StreamExpr>,
        right_key_selectors: Vec<StreamExpr>,
        args: JoinArgs,
        num_pipelines: usize,
    ) -> PolarsResult<Self> {
        // Decide the build side up-front where possible; `None` means we will
        // sample both inputs first (see SampleState::try_transition_to_build).
        let left_is_build = match args.maintain_order {
            MaintainOrderJoin::None => match args.build_side {
                Some(JoinBuildSide::ForceLeft) => Some(true),
                Some(JoinBuildSide::ForceRight) => Some(false),
                Some(JoinBuildSide::PreferLeft) | Some(JoinBuildSide::PreferRight) | None => {
                    if *JOIN_SAMPLE_LIMIT == 0 {
                        Some(args.build_side != Some(JoinBuildSide::PreferRight))
                    } else {
                        None
                    }
                },
            },
            MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => {
                if args.build_side == Some(JoinBuildSide::ForceLeft) {
                    polars_warn!("can't force left build-side with left-maintaining cross-join");
                }
                Some(false)
            },
            MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => {
                if args.build_side == Some(JoinBuildSide::ForceRight) {
                    polars_warn!("can't force right build-side with right-maintaining cross-join");
                }
                Some(true)
            },
        };

        let preserve_order_probe = args.maintain_order != MaintainOrderJoin::None;
        let preserve_order_build = matches!(
            args.maintain_order,
            MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft
        );

        let left_payload_select = compute_payload_selector(
            &left_input_schema,
            &right_input_schema,
            &left_key_schema,
            &right_key_schema,
            true,
            &args,
        )?;
        let right_payload_select = compute_payload_selector(
            &right_input_schema,
            &left_input_schema,
            &right_key_schema,
            &left_key_schema,
            false,
            &args,
        )?;

        let state = if left_is_build.is_some() {
            EquiJoinState::Build(BuildState::new(
                num_pipelines,
                num_pipelines,
                BufferedStream::default(),
            ))
        } else {
            EquiJoinState::Sample(SampleState::default())
        };

        let left_payload_schema = Arc::new(select_schema(&left_input_schema, &left_payload_select));
        let right_payload_schema =
            Arc::new(select_schema(&right_input_schema, &right_payload_select));
        Ok(Self {
            state,
            params: EquiJoinParams
{1252left_is_build,1253preserve_order_build,1254preserve_order_probe,1255left_key_schema,1256left_key_selectors,1257right_key_schema,1258right_key_selectors,1259left_payload_select,1260right_payload_select,1261left_payload_schema,1262right_payload_schema,1263args,1264random_state: PlRandomState::default(),1265},1266table: new_idx_table(unique_key_schema),1267})1268}1269}12701271impl ComputeNode for EquiJoinNode {1272fn name(&self) -> &str {1273"equi-join"1274}12751276fn update_state(1277&mut self,1278recv: &mut [PortState],1279send: &mut [PortState],1280state: &StreamingExecutionState,1281) -> PolarsResult<()> {1282assert!(recv.len() == 2 && send.len() == 1);12831284// If the output doesn't want any more data, transition to being done.1285if send[0] == PortState::Done {1286self.state = EquiJoinState::Done;1287}12881289// If we are sampling and both sides are done/filled, transition to building.1290if let EquiJoinState::Sample(sample_state) = &mut self.state {1291if let Some(build_state) =1292sample_state.try_transition_to_build(recv, &mut self.params, state)?1293{1294self.state = EquiJoinState::Build(build_state);1295}1296}12971298let build_idx = if self.params.left_is_build == Some(true) {129901300} else {130111302};1303let probe_idx = 1 - build_idx;13041305// If we are building and the build input is done, transition to probing.1306if let EquiJoinState::Build(build_state) = &mut self.state {1307if recv[build_idx] == PortState::Done {1308let probe_state = if self.params.preserve_order_build {1309build_state.finalize_ordered(&self.params, &*self.table)1310} else {1311build_state.finalize_unordered(&self.params, &*self.table)1312};1313self.state = EquiJoinState::Probe(probe_state);1314}1315}13161317// If we are probing and the probe input is done, emit unmatched if1318// necessary, otherwise we're done.1319if let EquiJoinState::Probe(probe_state) = &mut self.state {1320let samples_consumed = probe_state.sampled_probe_morsels.is_empty();1321if samples_consumed && 
recv[probe_idx] == PortState::Done {1322if self.params.emit_unmatched_build() {1323if self.params.preserve_order_build {1324let unmatched = probe_state.ordered_unmatched(&self.params);1325let src = InMemorySourceNode::new(1326Arc::new(unmatched),1327probe_state.max_seq_sent.successor(),1328);1329self.state = EquiJoinState::EmitUnmatchedBuildInOrder(src);1330} else {1331self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState {1332partitions: core::mem::take(&mut probe_state.table_per_partition),1333active_partition_idx: 0,1334offset_in_active_p: 0,1335morsel_seq: probe_state.max_seq_sent.successor(),1336});1337}1338} else {1339self.state = EquiJoinState::Done;1340}1341}1342}13431344// Finally, check if we are done emitting unmatched keys.1345if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state {1346if emit_state.active_partition_idx >= emit_state.partitions.len() {1347self.state = EquiJoinState::Done;1348}1349}13501351match &mut self.state {1352EquiJoinState::Sample(sample_state) => {1353send[0] = PortState::Blocked;1354if recv[0] != PortState::Done {1355recv[0] = if sample_state.left_len < *JOIN_SAMPLE_LIMIT {1356PortState::Ready1357} else {1358PortState::Blocked1359};1360}1361if recv[1] != PortState::Done {1362recv[1] = if sample_state.right_len < *JOIN_SAMPLE_LIMIT {1363PortState::Ready1364} else {1365PortState::Blocked1366};1367}1368},1369EquiJoinState::Build(_) => {1370send[0] = PortState::Blocked;1371if recv[build_idx] != PortState::Done {1372recv[build_idx] = PortState::Ready;1373}1374if recv[probe_idx] != PortState::Done {1375recv[probe_idx] = PortState::Blocked;1376}1377},1378EquiJoinState::Probe(probe_state) => {1379if recv[probe_idx] != PortState::Done {1380core::mem::swap(&mut send[0], &mut recv[probe_idx]);1381} else {1382let samples_consumed = probe_state.sampled_probe_morsels.is_empty();1383send[0] = if samples_consumed {1384PortState::Done1385} else {1386PortState::Ready1387};1388}1389recv[build_idx] = 
PortState::Done;1390},1391EquiJoinState::EmitUnmatchedBuild(_) => {1392send[0] = PortState::Ready;1393recv[build_idx] = PortState::Done;1394recv[probe_idx] = PortState::Done;1395},1396EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => {1397recv[build_idx] = PortState::Done;1398recv[probe_idx] = PortState::Done;1399src_node.update_state(&mut [], &mut send[0..1], state)?;1400if send[0] == PortState::Done {1401self.state = EquiJoinState::Done;1402}1403},1404EquiJoinState::Done => {1405send[0] = PortState::Done;1406recv[0] = PortState::Done;1407recv[1] = PortState::Done;1408},1409}1410Ok(())1411}14121413fn is_memory_intensive_pipeline_blocker(&self) -> bool {1414matches!(1415self.state,1416EquiJoinState::Sample { .. } | EquiJoinState::Build { .. }1417)1418}14191420fn spawn<'env, 's>(1421&'env mut self,1422scope: &'s TaskScope<'s, 'env>,1423recv_ports: &mut [Option<RecvPort<'_>>],1424send_ports: &mut [Option<SendPort<'_>>],1425state: &'s StreamingExecutionState,1426join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,1427) {1428assert!(recv_ports.len() == 2);1429assert!(send_ports.len() == 1);14301431let build_idx = if self.params.left_is_build == Some(true) {143201433} else {143411435};1436let probe_idx = 1 - build_idx;14371438match &mut self.state {1439EquiJoinState::Sample(sample_state) => {1440assert!(send_ports[0].is_none());1441let left_final_len = Arc::new(RelaxedCell::from(if recv_ports[0].is_none() {1442sample_state.left_len1443} else {1444usize::MAX1445}));1446let right_final_len = Arc::new(RelaxedCell::from(if recv_ports[1].is_none() {1447sample_state.right_len1448} else {1449usize::MAX1450}));14511452if let Some(left_recv) = recv_ports[0].take() {1453join_handles.push(scope.spawn_task(1454TaskPriority::High,1455SampleState::sink(1456left_recv.serial(),1457&mut sample_state.left,1458&mut sample_state.left_len,1459left_final_len.clone(),1460right_final_len.clone(),1461),1462));1463}1464if let Some(right_recv) = recv_ports[1].take() 
{1465join_handles.push(scope.spawn_task(1466TaskPriority::High,1467SampleState::sink(1468right_recv.serial(),1469&mut sample_state.right,1470&mut sample_state.right_len,1471right_final_len,1472left_final_len,1473),1474));1475}1476},1477EquiJoinState::Build(build_state) => {1478assert!(send_ports[0].is_none());1479assert!(recv_ports[probe_idx].is_none());1480let receivers = recv_ports[build_idx].take().unwrap().parallel();14811482let partitioner = HashPartitioner::new(state.num_pipelines, 0);1483for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {1484join_handles.push(scope.spawn_task(1485TaskPriority::High,1486BuildState::partition_and_sink(1487recv,1488local_builder,1489partitioner.clone(),1490&self.params,1491state,1492),1493));1494}1495},1496EquiJoinState::Probe(probe_state) => {1497assert!(recv_ports[build_idx].is_none());1498let senders = send_ports[0].take().unwrap().parallel();1499let receivers = probe_state1500.sampled_probe_morsels1501.reinsert(1502state.num_pipelines,1503recv_ports[probe_idx].take(),1504scope,1505join_handles,1506)1507.unwrap();15081509let partitioner = HashPartitioner::new(state.num_pipelines, 0);1510let probe_tasks = receivers1511.into_iter()1512.zip(senders)1513.map(|(recv, send)| {1514scope.spawn_task(1515TaskPriority::High,1516ProbeState::partition_and_probe(1517recv,1518send,1519&probe_state.table_per_partition,1520&probe_state.unordered_morsel_seq,1521partitioner.clone(),1522&self.params,1523state,1524),1525)1526})1527.collect_vec();15281529let max_seq_sent = &mut probe_state.max_seq_sent;1530join_handles.push(scope.spawn_task(TaskPriority::High, async move {1531for probe_task in probe_tasks {1532*max_seq_sent = (*max_seq_sent).max(probe_task.await?);1533}1534Ok(())1535}));1536},1537EquiJoinState::EmitUnmatchedBuild(emit_state) => {1538assert!(recv_ports[build_idx].is_none());1539assert!(recv_ports[probe_idx].is_none());1540let send = 
send_ports[0].take().unwrap().serial();1541join_handles.push(scope.spawn_task(1542TaskPriority::Low,1543emit_state.emit_unmatched(send, &self.params, state.num_pipelines),1544));1545},1546EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => {1547assert!(recv_ports[build_idx].is_none());1548assert!(recv_ports[probe_idx].is_none());1549src_node.spawn(scope, &mut [], send_ports, state, join_handles);1550},1551EquiJoinState::Done => unreachable!(),1552}1553}1554}155515561557