// Path: crates/polars-stream/src/nodes/joins/equi_join.rs
//! Streaming equi-join (hash join) compute node: sample, build and probe
//! phases plus payload-selection/coalescing helpers.

use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use arrow::array::builder::ShareStrategy;
use polars_core::frame::builder::DataFrameBuilder;
use polars_core::prelude::*;
use polars_core::schema::{Schema, SchemaExt};
use polars_core::{POOL, config};
use polars_expr::hash_keys::HashKeys;
use polars_expr::idx_table::{IdxTable, new_idx_table};
use polars_io::pl_async::get_runtime;
use polars_ops::frame::{JoinArgs, JoinType, MaintainOrderJoin};
use polars_ops::series::coalesce_columns;
use polars_utils::cardinality_sketch::CardinalitySketch;
use polars_utils::hashing::HashPartitioner;
use polars_utils::itertools::Itertools;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::priority::Priority;
use polars_utils::relaxed_cell::RelaxedCell;
use polars_utils::sparse_init_vec::SparseInitVec;
use polars_utils::{IdxSize, format_pl_smallstr};
use rayon::prelude::*;

use super::{BufferedStream, JOIN_SAMPLE_LIMIT, LOPSIDED_SAMPLE_FACTOR};
use crate::async_executor;
use crate::async_primitives::connector::{Receiver, Sender};
use crate::async_primitives::wait_group::WaitGroup;
use crate::expression::StreamExpr;
use crate::morsel::{SourceToken, get_ideal_morsel_size};
use crate::nodes::compute_node_prelude::*;
use crate::nodes::in_memory_source::InMemorySourceNode;

/// Shared configuration for one equi-join: which side builds the hash table,
/// key/payload selection for both sides, and the join arguments.
struct EquiJoinParams {
    // None while still sampling to pick a build side; Some(..) once chosen.
    left_is_build: Option<bool>,
    preserve_order_build: bool,
    preserve_order_probe: bool,
    left_key_schema: Arc<Schema>,
    left_key_selectors: Vec<StreamExpr>,
    #[allow(dead_code)]
    right_key_schema: Arc<Schema>,
    right_key_selectors: Vec<StreamExpr>,
    // Per input column: None = drop from payload, Some(name) = keep under that name.
    left_payload_select: Vec<Option<PlSmallStr>>,
    right_payload_select: Vec<Option<PlSmallStr>>,
    left_payload_schema: Arc<Schema>,
    right_payload_schema: Arc<Schema>,
    args: JoinArgs,
    random_state: PlRandomState,
}

impl EquiJoinParams {
    /// Should we emit unmatched rows from the build side?
    ///
    /// NOTE(review): `left_is_build.unwrap()` assumes a build side has
    /// already been chosen (i.e. we are past the sampling phase).
    fn emit_unmatched_build(&self) -> bool {
        if self.left_is_build.unwrap() {
            self.args.how == JoinType::Left || self.args.how == JoinType::Full
        } else {
            self.args.how == JoinType::Right || self.args.how == JoinType::Full
        }
    }

    /// Should we emit unmatched rows from the probe side?
    fn emit_unmatched_probe(&self) -> bool {
        if self.left_is_build.unwrap() {
            self.args.how == JoinType::Right || self.args.how == JoinType::Full
        } else {
            self.args.how == JoinType::Left || self.args.how == JoinType::Full
        }
    }
}

/// A payload selector contains for each column whether that column should be
/// included in the payload, and if yes with what name.
///
/// Errors with a `Duplicate` error if a suffixed column name would collide
/// with an existing column on the other side.
fn compute_payload_selector(
    this: &Schema,
    other: &Schema,
    this_key_schema: &Schema,
    other_key_schema: &Schema,
    is_left: bool,
    args: &JoinArgs,
) -> PolarsResult<Vec<Option<PlSmallStr>>> {
    let should_coalesce = args.should_coalesce();

    let mut coalesce_idx = 0;
    this.iter_names()
        .map(|c| {
            // The loop is only used as a breakable block: `break` falls
            // through to the suffixing logic below.
            #[expect(clippy::never_loop)]
            loop {
                let selector = if args.how == JoinType::Right {
                    if is_left {
                        if should_coalesce && this_key_schema.contains(c) {
                            // Coalesced to RHS output key.
                            None
                        } else {
                            Some(c.clone())
                        }
                    } else if !other.contains(c) || (should_coalesce && other_key_schema.contains(c)) {
                        Some(c.clone())
                    } else {
                        break;
                    }
                } else if should_coalesce && this_key_schema.contains(c) {
                    if is_left {
                        Some(c.clone())
                    } else if args.how == JoinType::Full {
                        // We must keep the right-hand side keycols around for
                        // coalescing.
                        let name = format_pl_smallstr!("__POLARS_COALESCE_KEYCOL{coalesce_idx}");
                        coalesce_idx += 1;
                        Some(name)
                    } else {
                        None
                    }
                } else if !other.contains(c) || is_left {
                    Some(c.clone())
                } else {
                    break;
                };

                return Ok(selector);
            }

            // Name collides with the other side: apply the suffix.
            let suffixed = format_pl_smallstr!("{}{}", c, args.suffix());
            if other.contains(&suffixed) {
                polars_bail!(Duplicate: "column with name '{suffixed}' already exists\n\n\
                You may want to try:\n\
                - renaming the column prior to joining\n\
                - using the `suffix` parameter to specify a suffix different to the default one ('_right')")
            }

            Ok(Some(suffixed))
        })
        .collect()
}
/// Fixes names and does coalescing of columns post-join.
///
/// For coalescing full joins, the right key columns were kept under
/// `__POLARS_COALESCE_KEYCOL{i}` names; they are coalesced into the left key
/// columns here and then dropped.
fn postprocess_join(df: DataFrame, params: &EquiJoinParams) -> DataFrame {
    if params.args.how == JoinType::Full && params.args.should_coalesce() {
        // TODO: don't do string-based column lookups for each dataframe, pre-compute coalesce indices.
        let mut coalesce_idx = 0;
        df.get_columns()
            .iter()
            .filter_map(|c| {
                if params.left_key_schema.contains(c.name()) {
                    let other = df
                        .column(&format_pl_smallstr!(
                            "__POLARS_COALESCE_KEYCOL{coalesce_idx}"
                        ))
                        .unwrap();
                    coalesce_idx += 1;
                    return Some(coalesce_columns(&[c.clone(), other.clone()]).unwrap());
                }

                // Helper keycols are consumed by the coalesce above.
                if c.name().starts_with("__POLARS_COALESCE_KEYCOL") {
                    return None;
                }

                Some(c.clone())
            })
            .collect()
    } else {
        df
    }
}

/// Applies a payload selector to a schema: keeps only selected fields,
/// renamed to their output names.
fn select_schema(schema: &Schema, selector: &[Option<PlSmallStr>]) -> Schema {
    schema
        .iter_fields()
        .zip(selector)
        .filter_map(|(f, name)| Some(f.with_name(name.clone()?)))
        .collect()
}

/// Evaluates the key selectors against `df` and hashes the resulting key
/// columns into `HashKeys`.
async fn select_keys(
    df: &DataFrame,
    key_selectors: &[StreamExpr],
    params: &EquiJoinParams,
    state: &ExecutionState,
) -> PolarsResult<HashKeys> {
    let mut key_columns = Vec::new();
    for selector in key_selectors {
        key_columns.push(selector.evaluate(df, state).await?.into_column());
    }
    let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?;
    Ok(HashKeys::from_df(
        &keys,
        params.random_state,
        params.args.nulls_equal,
        false,
    ))
}

/// Applies a payload selector to a dataframe: keeps only selected columns,
/// renamed to their output names.
fn select_payload(df: DataFrame, selector: &[Option<PlSmallStr>]) -> DataFrame {
    // Maintain height of zero-width dataframes.
    if df.width() == 0 {
        return df;
    }

    df.take_columns()
        .into_iter()
        .zip(selector)
        .filter_map(|(c, name)| Some(c.with_name(name.clone()?)))
        .collect()
}

/// Estimates the average number of distinct keys per row over (up to) the
/// first `JOIN_SAMPLE_LIMIT` rows of `morsels`, using a cardinality sketch
/// computed in parallel on the rayon pool.
fn estimate_cardinality(
    morsels: &[Morsel],
    key_selectors: &[StreamExpr],
    params: &EquiJoinParams,
    state: &ExecutionState,
) -> PolarsResult<f64> {
    let sample_limit = *JOIN_SAMPLE_LIMIT;
    if morsels.is_empty() || sample_limit == 0 {
        return Ok(0.0);
    }

    // Find how many morsels we need to reach the sample limit, and how much
    // of the last one falls within the limit.
    let mut total_height = 0;
    let mut to_process_end = 0;
    while to_process_end < morsels.len() && total_height < sample_limit {
        total_height += morsels[to_process_end].df().height();
        to_process_end += 1;
    }
    let last_morsel_idx = to_process_end - 1;
    let last_morsel_len = morsels[last_morsel_idx].df().height();
    let last_morsel_slice = last_morsel_len - total_height.saturating_sub(sample_limit);
    let runtime = get_runtime();

    POOL.install(|| {
        let sample_cardinality = morsels[..to_process_end]
            .par_iter()
            .enumerate()
            .try_fold(
                CardinalitySketch::new,
                |mut sketch, (morsel_idx, morsel)| {
                    // Only the in-limit prefix of the last morsel is sampled.
                    let sliced;
                    let df = if morsel_idx == last_morsel_idx {
                        sliced = morsel.df().slice(0, last_morsel_slice);
                        &sliced
                    } else {
                        morsel.df()
                    };
                    let hash_keys =
                        runtime.block_on(select_keys(df, key_selectors, params, state))?;
                    hash_keys.sketch_cardinality(&mut sketch);
                    PolarsResult::Ok(sketch)
                },
            )
            .map(|sketch| PolarsResult::Ok(sketch?.estimate()))
            .try_reduce_with(|a, b| Ok(a + b))
            .unwrap()?;
        Ok(sample_cardinality as f64 / total_height.min(sample_limit) as f64)
    })
}
/// Buffered morsels from both sides gathered while deciding which side
/// should build the hash table.
#[derive(Default)]
struct SampleState {
    left: Vec<Morsel>,
    left_len: usize,
    right: Vec<Morsel>,
    right_len: usize,
}

impl SampleState {
    /// Buffers incoming morsels for one side, requesting a stop once this
    /// side has sampled enough (absolute limit, or lopsidedly more than the
    /// other side's final length).
    async fn sink(
        mut recv: Receiver<Morsel>,
        morsels: &mut Vec<Morsel>,
        len: &mut usize,
        this_final_len: Arc<RelaxedCell<usize>>,
        other_final_len: Arc<RelaxedCell<usize>>,
    ) -> PolarsResult<()> {
        while let Ok(mut morsel) = recv.recv().await {
            *len += morsel.df().height();
            if *len >= *JOIN_SAMPLE_LIMIT
                || *len
                    >= other_final_len
                        .load()
                        .saturating_mul(LOPSIDED_SAMPLE_FACTOR)
            {
                // Politely ask the source to stop; we keep draining whatever
                // is still in flight.
                morsel.source_token().stop();
            }

            drop(morsel.take_consume_token());
            morsels.push(morsel);
        }
        this_final_len.store(*len);
        Ok(())
    }

    /// Decides, if possible, which side should build the hash table and
    /// transitions to the build phase, replaying the sampled build morsels
    /// into the local builders. Returns `None` while sampling should continue.
    fn try_transition_to_build(
        &mut self,
        recv: &[PortState],
        params: &mut EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<Option<BuildState>> {
        let left_saturated = self.left_len >= *JOIN_SAMPLE_LIMIT;
        let right_saturated = self.right_len >= *JOIN_SAMPLE_LIMIT;
        let left_done = recv[0] == PortState::Done || left_saturated;
        let right_done = recv[1] == PortState::Done || right_saturated;
        #[expect(clippy::nonminimal_bool)]
        let stop_sampling = (left_done && right_done)
            || (left_done && self.right_len >= LOPSIDED_SAMPLE_FACTOR * self.left_len)
            || (right_done && self.left_len >= LOPSIDED_SAMPLE_FACTOR * self.right_len);
        if !stop_sampling {
            return Ok(None);
        }

        if config::verbose() {
            eprintln!(
                "choosing build side, sample lengths are: {} vs. {}",
                self.left_len, self.right_len
            );
        }

        let estimate_cardinalities = || {
            let left_cardinality = estimate_cardinality(
                &self.left,
                &params.left_key_selectors,
                params,
                &state.in_memory_exec_state,
            )?;
            let right_cardinality = estimate_cardinality(
                &self.right,
                &params.right_key_selectors,
                params,
                &state.in_memory_exec_state,
            )?;
            if config::verbose() {
                eprintln!(
                    "estimated cardinalities are: {left_cardinality} vs. {right_cardinality}"
                );
            }
            PolarsResult::Ok((left_cardinality, right_cardinality))
        };

        let left_is_build = match (left_saturated, right_saturated) {
            // Don't bother estimating cardinality, just choose smaller side as
            // we have everything in-memory anyway.
            (false, false) => self.left_len < self.right_len,

            // Choose the unsaturated side, the saturated side could be
            // arbitrarily big.
            (false, true) => true,
            (true, false) => false,

            // Estimate cardinality and choose smaller.
            (true, true) => {
                let (lc, rc) = estimate_cardinalities()?;
                lc < rc
            },
        };

        if config::verbose() {
            eprintln!(
                "build side chosen: {}",
                if left_is_build { "left" } else { "right" }
            );
        }

        // Transition to building state.
        params.left_is_build = Some(left_is_build);
        let mut sampled_build_morsels =
            BufferedStream::new(core::mem::take(&mut self.left), MorselSeq::default());
        let mut sampled_probe_morsels =
            BufferedStream::new(core::mem::take(&mut self.right), MorselSeq::default());
        if !left_is_build {
            core::mem::swap(&mut sampled_build_morsels, &mut sampled_probe_morsels);
        }

        let partitioner = HashPartitioner::new(state.num_pipelines, 0);
        let mut build_state = BuildState::new(
            state.num_pipelines,
            state.num_pipelines,
            sampled_probe_morsels,
        );

        // Simulate the sample build morsels flowing into the build side.
        if !sampled_build_morsels.is_empty() {
            crate::async_executor::task_scope(|scope| {
                let mut join_handles = Vec::new();
                let receivers = sampled_build_morsels
                    .reinsert(state.num_pipelines, None, scope, &mut join_handles)
                    .unwrap();

                for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        BuildState::partition_and_sink(
                            recv,
                            local_builder,
                            partitioner.clone(),
                            params,
                            state,
                        ),
                    ));
                }

                polars_io::pl_async::get_runtime().block_on(async move {
                    for handle in join_handles {
                        handle.await?;
                    }
                    PolarsResult::Ok(())
                })
            })?;
        }

        Ok(Some(build_state))
    }
}
/// Per-pipeline accumulator for the build side of the join.
#[derive(Default)]
struct LocalBuilder {
    // The complete list of morsels and their computed hashes seen by this builder.
    morsels: Vec<(MorselSeq, DataFrame, HashKeys)>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // morsel_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i]
    // for partition p, where start, stop are:
    // let start = morsel_idxs_offsets[i * num_partitions + p];
    // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p];
    morsel_idxs_values_per_p: Vec<Vec<IdxSize>>,
    morsel_idxs_offsets_per_p: Vec<usize>,
}

/// State of the build phase: one `LocalBuilder` per pipeline, plus the probe
/// morsels that were buffered during sampling.
struct BuildState {
    local_builders: Vec<LocalBuilder>,
    sampled_probe_morsels: BufferedStream,
}

impl BuildState {
    fn new(
        num_pipelines: usize,
        num_partitions: usize,
        sampled_probe_morsels: BufferedStream,
    ) -> Self {
        let local_builders = (0..num_pipelines)
            .map(|_| LocalBuilder {
                morsels: Vec::new(),
                sketch_per_p: vec![CardinalitySketch::default(); num_partitions],
                morsel_idxs_values_per_p: vec![Vec::new(); num_partitions],
                morsel_idxs_offsets_per_p: vec![0; num_partitions],
            })
            .collect();
        Self {
            local_builders,
            sampled_probe_morsels,
        }
    }

    /// Sink task for one pipeline: hashes the keys of each incoming morsel,
    /// records per-partition row indices and cardinality sketches, and stores
    /// the (rechunked) payload for later gathering.
    async fn partition_and_sink(
        mut recv: Receiver<Morsel>,
        local: &mut LocalBuilder,
        partitioner: HashPartitioner,
        params: &EquiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        let track_unmatchable = params.emit_unmatched_build();
        let (key_selectors, payload_selector);
        if params.left_is_build.unwrap() {
            payload_selector = &params.left_payload_select;
            key_selectors = &params.left_key_selectors;
        } else {
            payload_selector = &params.right_payload_select;
            key_selectors = &params.right_key_selectors;
        };

        while let Ok(morsel) = recv.recv().await {
            // Compute hashed keys and payload. We must rechunk the payload for
            // later gathers.
            let hash_keys = select_keys(
                morsel.df(),
                key_selectors,
                params,
                &state.in_memory_exec_state,
            )
            .await?;
            let mut payload = select_payload(morsel.df().clone(), payload_selector);
            payload.rechunk_mut();

            hash_keys.gen_idxs_per_partition(
                &partitioner,
                &mut local.morsel_idxs_values_per_p,
                &mut local.sketch_per_p,
                track_unmatchable,
            );

            local
                .morsel_idxs_offsets_per_p
                .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len()));
            local.morsels.push((morsel.seq(), payload, hash_keys));
        }
        Ok(())
    }

    /// Builds the per-partition probe tables while preserving the global
    /// morsel order: each partition k-way merges the local builders' morsels
    /// by sequence number so build-side output can later be emitted in order.
    fn finalize_ordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState {
        let track_unmatchable = params.emit_unmatched_build();
        let payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        let num_partitions = self.local_builders[0].sketch_per_p.len();
        let local_builders = &self.local_builders;
        let probe_tables: SparseInitVec<ProbeTable> = SparseInitVec::with_capacity(num_partitions);

        POOL.scope(|s| {
            for p in 0..num_partitions {
                let probe_tables = &probe_tables;
                s.spawn(move |_| {
                    // TODO: every thread does an identical linearize, we can do a single parallel one.
                    let mut kmerge = BinaryHeap::with_capacity(local_builders.len());
                    let mut cur_idx_per_loc = vec![0; local_builders.len()];

                    // Compute cardinality estimate and total amount of
                    // payload for this partition, and initialize k-way merge.
                    let mut sketch = CardinalitySketch::new();
                    let mut payload_rows = 0;
                    for (l_idx, l) in local_builders.iter().enumerate() {
                        let Some((seq, _, _)) = l.morsels.first() else {
                            continue;
                        };
                        kmerge.push(Priority(Reverse(seq), l_idx));

                        sketch.combine(&l.sketch_per_p[p]);
                        let offsets_len = l.morsel_idxs_offsets_per_p.len();
                        payload_rows +=
                            l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p];
                    }

                    // Allocate hash table and payload builder.
                    let mut p_table = table.new_empty();
                    p_table.reserve(sketch.estimate() * 5 / 4);
                    let mut p_payload = DataFrameBuilder::new(payload_schema.clone());
                    p_payload.reserve(payload_rows);

                    let mut p_seq_ids = Vec::new();
                    if track_unmatchable {
                        p_seq_ids.reserve(payload_rows);
                    }

                    // Linearize and build.
                    unsafe {
                        let mut norm_seq_id = 0 as IdxSize;
                        while let Some(Priority(Reverse(_seq), l_idx)) = kmerge.pop() {
                            let l = local_builders.get_unchecked(l_idx);
                            let idx_in_l = *cur_idx_per_loc.get_unchecked(l_idx);
                            *cur_idx_per_loc.get_unchecked_mut(l_idx) += 1;
                            if let Some((next_seq, _, _)) = l.morsels.get(idx_in_l + 1) {
                                kmerge.push(Priority(Reverse(next_seq), l_idx));
                            }

                            let (_mseq, payload, keys) = l.morsels.get_unchecked(idx_in_l);
                            let p_morsel_idxs_start =
                                l.morsel_idxs_offsets_per_p[idx_in_l * num_partitions + p];
                            let p_morsel_idxs_stop =
                                l.morsel_idxs_offsets_per_p[(idx_in_l + 1) * num_partitions + p];
                            let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                [p_morsel_idxs_start..p_morsel_idxs_stop];
                            p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable);
                            p_payload.gather_extend(payload, p_morsel_idxs, ShareStrategy::Never);

                            if track_unmatchable {
                                // Record a normalized per-morsel sequence id for
                                // every payload row so unmatched rows can later
                                // be emitted in order.
                                p_seq_ids.resize(p_payload.len(), norm_seq_id);
                                norm_seq_id += 1;
                            }
                        }
                    }

                    probe_tables
                        .try_set(
                            p,
                            ProbeTable {
                                hash_table: p_table,
                                payload: p_payload.freeze(),
                                seq_ids: p_seq_ids,
                            },
                        )
                        .ok()
                        .unwrap();
                });
            }
        });

        ProbeState {
            table_per_partition: probe_tables.try_assume_init().ok().unwrap(),
            max_seq_sent: MorselSeq::default(),
            sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),
            unordered_morsel_seq: AtomicU64::new(0),
        }
    }

    /// Builds the per-partition probe tables without preserving morsel order;
    /// processed morsels are dropped eagerly via a shared drop queue to cap
    /// peak memory usage.
    fn finalize_unordered(&mut self, params: &EquiJoinParams, table: &dyn IdxTable) -> ProbeState {
        let track_unmatchable = params.emit_unmatched_build();
        let payload_schema = if params.left_is_build.unwrap() {
            &params.left_payload_schema
        } else {
            &params.right_payload_schema
        };

        // To reduce maximum memory usage we want to drop the morsels
        // as soon as they're processed, so we move into Arcs. The drops might
        // also be expensive, so instead of directly dropping we put that on
        // a work queue.
        let morsels_per_local_builder = self
            .local_builders
            .iter_mut()
            .map(|b| Arc::new(core::mem::take(&mut b.morsels)))
            .collect_vec();
        let (morsel_drop_q_send, morsel_drop_q_recv) =
            async_channel::bounded(morsels_per_local_builder.len());
        let num_partitions = self.local_builders[0].sketch_per_p.len();
        let local_builders = &self.local_builders;
        let probe_tables: SparseInitVec<ProbeTable> = SparseInitVec::with_capacity(num_partitions);

        async_executor::task_scope(|s| {
            // Wrap in outer Arc to move to each thread, performing the
            // expensive clone on that thread.
            let arc_morsels_per_local_builder = Arc::new(morsels_per_local_builder);
            let mut join_handles = Vec::new();
            for p in 0..num_partitions {
                let arc_morsels_per_local_builder = Arc::clone(&arc_morsels_per_local_builder);
                let morsel_drop_q_send = morsel_drop_q_send.clone();
                let morsel_drop_q_recv = morsel_drop_q_recv.clone();
                let probe_tables = &probe_tables;
                join_handles.push(s.spawn_task(TaskPriority::High, async move {
                    // Extract from outer arc and drop outer arc.
                    let morsels_per_local_builder =
                        Arc::unwrap_or_clone(arc_morsels_per_local_builder);

                    // Compute cardinality estimate and total amount of
                    // payload for this partition.
                    let mut sketch = CardinalitySketch::new();
                    let mut payload_rows = 0;
                    for l in local_builders {
                        sketch.combine(&l.sketch_per_p[p]);
                        let offsets_len = l.morsel_idxs_offsets_per_p.len();
                        payload_rows +=
                            l.morsel_idxs_offsets_per_p[offsets_len - num_partitions + p];
                    }

                    // Allocate hash table and payload builder.
                    let mut p_table = table.new_empty();
                    p_table.reserve(sketch.estimate() * 5 / 4);
                    let mut p_payload = DataFrameBuilder::new(payload_schema.clone());
                    p_payload.reserve(payload_rows);

                    // Build.
                    let mut skip_drop_attempt = false;
                    for (l, l_morsels) in local_builders.iter().zip(morsels_per_local_builder) {
                        // Try to help with dropping the processed morsels.
                        if !skip_drop_attempt {
                            drop(morsel_drop_q_recv.try_recv());
                        }

                        for (i, morsel) in l_morsels.iter().enumerate() {
                            let (_mseq, payload, keys) = morsel;
                            unsafe {
                                let p_morsel_idxs_start =
                                    l.morsel_idxs_offsets_per_p[i * num_partitions + p];
                                let p_morsel_idxs_stop =
                                    l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                    [p_morsel_idxs_start..p_morsel_idxs_stop];
                                p_table.insert_keys_subset(keys, p_morsel_idxs, track_unmatchable);
                                p_payload.gather_extend(
                                    payload,
                                    p_morsel_idxs,
                                    ShareStrategy::Never,
                                );
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_morsels) {
                            // If we're the last thread to process this set of morsels we're probably
                            // falling behind the rest, since the drop can be quite expensive we skip
                            // a drop attempt hoping someone else will pick up the slack.
                            drop(morsel_drop_q_send.try_send(l));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // We're done, help others out by doing drops.
                    drop(morsel_drop_q_send); // So we don't deadlock trying to receive from ourselves.
                    while let Ok(l_morsels) = morsel_drop_q_recv.recv().await {
                        drop(l_morsels);
                    }

                    probe_tables
                        .try_set(
                            p,
                            ProbeTable {
                                hash_table: p_table,
                                payload: p_payload.freeze(),
                                seq_ids: Vec::new(),
                            },
                        )
                        .ok()
                        .unwrap();
                }));
            }

            // Drop outer arc after spawning each thread so the inner arcs
            // can get dropped as soon as they're processed. We also have to
            // drop the drop queue sender so we don't deadlock waiting for it
            // to end.
            drop(arc_morsels_per_local_builder);
            drop(morsel_drop_q_send);

            polars_io::pl_async::get_runtime().block_on(async move {
                for handle in join_handles {
                    handle.await;
                }
            });
        });

        ProbeState {
            table_per_partition: probe_tables.try_assume_init().ok().unwrap(),
            max_seq_sent: MorselSeq::default(),
            sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),
            unordered_morsel_seq: AtomicU64::new(0),
        }
    }
}
We also have to707// drop the drop queue sender so we don't deadlock waiting for it708// to end.709drop(arc_morsels_per_local_builder);710drop(morsel_drop_q_send);711712polars_io::pl_async::get_runtime().block_on(async move {713for handle in join_handles {714handle.await;715}716});717});718719ProbeState {720table_per_partition: probe_tables.try_assume_init().ok().unwrap(),721max_seq_sent: MorselSeq::default(),722sampled_probe_morsels: core::mem::take(&mut self.sampled_probe_morsels),723unordered_morsel_seq: AtomicU64::new(0),724}725}726}727728struct ProbeTable {729hash_table: Box<dyn IdxTable>,730payload: DataFrame,731seq_ids: Vec<IdxSize>,732}733734struct ProbeState {735table_per_partition: Vec<ProbeTable>,736max_seq_sent: MorselSeq,737sampled_probe_morsels: BufferedStream,738739// For unordered joins we relabel output morsels to speed up the linearizer.740unordered_morsel_seq: AtomicU64,741}742743impl ProbeState {744/// Returns the max morsel sequence sent.745async fn partition_and_probe(746mut recv: Receiver<Morsel>,747mut send: Sender<Morsel>,748partitions: &[ProbeTable],749unordered_morsel_seq: &AtomicU64,750partitioner: HashPartitioner,751params: &EquiJoinParams,752state: &StreamingExecutionState,753) -> PolarsResult<MorselSeq> {754// TODO: shuffle after partitioning and keep probe tables thread-local.755let mut partition_idxs = vec![Vec::new(); partitioner.num_partitions()];756let mut probe_partitions = Vec::new();757let mut materialized_idxsize_range = Vec::new();758let mut table_match = Vec::new();759let mut probe_match = Vec::new();760let mut max_seq = MorselSeq::default();761762let probe_limit = get_ideal_morsel_size() as IdxSize;763let mark_matches = params.emit_unmatched_build();764let emit_unmatched = params.emit_unmatched_probe();765766let (key_selectors, payload_selector, build_payload_schema, probe_payload_schema);767if params.left_is_build.unwrap() {768key_selectors = ¶ms.right_key_selectors;769payload_selector = 
¶ms.right_payload_select;770build_payload_schema = ¶ms.left_payload_schema;771probe_payload_schema = ¶ms.right_payload_schema;772} else {773key_selectors = ¶ms.left_key_selectors;774payload_selector = ¶ms.left_payload_select;775build_payload_schema = ¶ms.right_payload_schema;776probe_payload_schema = ¶ms.left_payload_schema;777};778779let mut build_out = DataFrameBuilder::new(build_payload_schema.clone());780let mut probe_out = DataFrameBuilder::new(probe_payload_schema.clone());781782// A simple estimate used to size reserves.783let mut selectivity_estimate = 1.0;784let mut selectivity_estimate_confidence = 0.0;785786while let Ok(morsel) = recv.recv().await {787// Compute hashed keys and payload.788let (df, in_seq, src_token, wait_token) = morsel.into_inner();789790let df_height = df.height();791if df_height == 0 {792continue;793}794795let hash_keys =796select_keys(&df, key_selectors, params, &state.in_memory_exec_state).await?;797let mut payload = select_payload(df, payload_selector);798let mut payload_rechunked = false; // We don't eagerly rechunk because there might be no matches.799let mut total_matches = 0;800801// Use selectivity estimate to reserve for morsel builders.802let max_match_per_key_est = (selectivity_estimate * 1.2) as usize + 16;803let out_est_size = ((selectivity_estimate * 1.2 * df_height as f64) as usize)804.min(probe_limit as usize);805build_out.reserve(out_est_size + max_match_per_key_est);806807unsafe {808let mut new_morsel =809|build: &mut DataFrameBuilder, probe: &mut DataFrameBuilder| {810let mut build_df = build.freeze_reset();811let mut probe_df = probe.freeze_reset();812let out_df = if params.left_is_build.unwrap() {813build_df.hstack_mut_unchecked(probe_df.get_columns());814build_df815} else {816probe_df.hstack_mut_unchecked(build_df.get_columns());817probe_df818};819let out_df = postprocess_join(out_df, params);820let out_seq = if params.preserve_order_probe {821in_seq822} else {823MorselSeq::new(unordered_morsel_seq.fetch_add(1, 
Ordering::Relaxed))824};825max_seq = out_seq;826Morsel::new(out_df, out_seq, src_token.clone())827};828829if params.preserve_order_probe {830// To preserve the order we can't do bulk probes per partition and must follow831// the order of the probe morsel. We can still group probes that are832// consecutively on the same partition.833probe_partitions.clear();834hash_keys.gen_partitions(&partitioner, &mut probe_partitions, emit_unmatched);835836let mut probe_group_start = 0;837while probe_group_start < probe_partitions.len() {838let p_idx = probe_partitions[probe_group_start];839let mut probe_group_end = probe_group_start + 1;840while probe_partitions.get(probe_group_end) == Some(&p_idx) {841probe_group_end += 1;842}843let Some(p) = partitions.get(p_idx as usize) else {844probe_group_start = probe_group_end;845continue;846};847848materialized_idxsize_range.extend(849materialized_idxsize_range.len() as IdxSize..probe_group_end as IdxSize,850);851852while probe_group_start < probe_group_end {853let matches_before_limit = probe_limit - probe_match.len() as IdxSize;854table_match.clear();855probe_group_start += p.hash_table.probe_subset(856&hash_keys,857&materialized_idxsize_range[probe_group_start..probe_group_end],858&mut table_match,859&mut probe_match,860mark_matches,861emit_unmatched,862matches_before_limit,863) as usize;864865if emit_unmatched {866build_out.opt_gather_extend(867&p.payload,868&table_match,869ShareStrategy::Always,870);871} else {872build_out.gather_extend(873&p.payload,874&table_match,875ShareStrategy::Always,876);877};878879if probe_match.len() >= probe_limit as usize880|| probe_group_start == probe_partitions.len()881{882if !payload_rechunked {883payload.rechunk_mut();884payload_rechunked = true;885}886probe_out.gather_extend(887&payload,888&probe_match,889ShareStrategy::Always,890);891let out_len = probe_match.len();892probe_match.clear();893let out_morsel = new_morsel(&mut build_out, &mut probe_out);894if send.send(out_morsel).await.is_err() 
{895return Ok(max_seq);896}897if probe_group_end != probe_partitions.len() {898// We had enough matches to need a mid-partition flush, let's assume there are a lot of899// matches and just do a large reserve.900let old_est = probe_limit as usize + max_match_per_key_est;901build_out.reserve(old_est.max(out_len + 16));902}903}904}905}906} else {907// Partition and probe the tables.908for p in partition_idxs.iter_mut() {909p.clear();910}911hash_keys.gen_idxs_per_partition(912&partitioner,913&mut partition_idxs,914&mut [],915emit_unmatched,916);917918for (p, idxs_in_p) in partitions.iter().zip(&partition_idxs) {919let mut offset = 0;920while offset < idxs_in_p.len() {921let matches_before_limit = probe_limit - probe_match.len() as IdxSize;922table_match.clear();923offset += p.hash_table.probe_subset(924&hash_keys,925&idxs_in_p[offset..],926&mut table_match,927&mut probe_match,928mark_matches,929emit_unmatched,930matches_before_limit,931) as usize;932933if table_match.is_empty() {934continue;935}936total_matches += table_match.len();937938if emit_unmatched {939build_out.opt_gather_extend(940&p.payload,941&table_match,942ShareStrategy::Always,943);944} else {945build_out.gather_extend(946&p.payload,947&table_match,948ShareStrategy::Always,949);950};951952if probe_match.len() >= probe_limit as usize {953if !payload_rechunked {954payload.rechunk_mut();955payload_rechunked = true;956}957probe_out.gather_extend(958&payload,959&probe_match,960ShareStrategy::Always,961);962let out_len = probe_match.len();963probe_match.clear();964let out_morsel = new_morsel(&mut build_out, &mut probe_out);965if send.send(out_morsel).await.is_err() {966return Ok(max_seq);967}968// We had enough matches to need a mid-partition flush, let's assume there are a lot of969// matches and just do a large reserve.970let old_est = probe_limit as usize + max_match_per_key_est;971build_out.reserve(old_est.max(out_len + 16));972}973}974}975}976977if !probe_match.is_empty() {978if !payload_rechunked 
{979payload.rechunk_mut();980}981probe_out.gather_extend(&payload, &probe_match, ShareStrategy::Always);982probe_match.clear();983let out_morsel = new_morsel(&mut build_out, &mut probe_out);984if send.send(out_morsel).await.is_err() {985return Ok(max_seq);986}987}988}989990drop(wait_token);991992// Move selectivity estimate a bit towards latest value. Allows rapid changes at first.993// TODO: implement something more re-usable and robust.994selectivity_estimate = selectivity_estimate_confidence * selectivity_estimate995+ (1.0 - selectivity_estimate_confidence)996* (total_matches as f64 / df_height as f64);997selectivity_estimate_confidence = (selectivity_estimate_confidence + 0.1).min(0.8);998}9991000Ok(max_seq)1001}10021003fn ordered_unmatched(&mut self, params: &EquiJoinParams) -> DataFrame {1004// TODO: parallelize this operator.10051006let build_payload_schema = if params.left_is_build.unwrap() {1007¶ms.left_payload_schema1008} else {1009¶ms.right_payload_schema1010};10111012let mut unmarked_idxs = Vec::new();1013let mut linearized_idxs = Vec::new();10141015for (p_idx, p) in self.table_per_partition.iter().enumerate_idx() {1016p.hash_table1017.unmarked_keys(&mut unmarked_idxs, 0, IdxSize::MAX);1018linearized_idxs.extend(1019unmarked_idxs1020.iter()1021.map(|i| (unsafe { *p.seq_ids.get_unchecked(*i as usize) }, p_idx, *i)),1022);1023}10241025linearized_idxs.sort_by_key(|(seq_id, _, _)| *seq_id);10261027unsafe {1028let mut build_out = DataFrameBuilder::new(build_payload_schema.clone());1029build_out.reserve(linearized_idxs.len());10301031// Group indices from the same partition.1032let mut group_start = 0;1033let mut gather_idxs = Vec::new();1034while group_start < linearized_idxs.len() {1035gather_idxs.clear();10361037let (_seq, p_idx, idx_in_p) = linearized_idxs[group_start];1038gather_idxs.push(idx_in_p);1039let mut group_end = group_start + 1;1040while group_end < linearized_idxs.len() && linearized_idxs[group_end].1 == p_idx 
{1041gather_idxs.push(linearized_idxs[group_end].2);1042group_end += 1;1043}10441045build_out.gather_extend(1046&self.table_per_partition[p_idx as usize].payload,1047&gather_idxs,1048ShareStrategy::Never, // Don't keep entire table alive for unmatched indices.1049);10501051group_start = group_end;1052}10531054let mut build_df = build_out.freeze();1055let out_df = if params.left_is_build.unwrap() {1056let probe_df =1057DataFrame::full_null(¶ms.right_payload_schema, build_df.height());1058build_df.hstack_mut_unchecked(probe_df.get_columns());1059build_df1060} else {1061let mut probe_df =1062DataFrame::full_null(¶ms.left_payload_schema, build_df.height());1063probe_df.hstack_mut_unchecked(build_df.get_columns());1064probe_df1065};1066postprocess_join(out_df, params)1067}1068}1069}10701071impl Drop for ProbeState {1072fn drop(&mut self) {1073POOL.install(|| {1074// Parallel drop as the state might be quite big.1075self.table_per_partition.par_drain(..).for_each(drop);1076})1077}1078}10791080struct EmitUnmatchedState {1081partitions: Vec<ProbeTable>,1082active_partition_idx: usize,1083offset_in_active_p: usize,1084morsel_seq: MorselSeq,1085}10861087impl EmitUnmatchedState {1088async fn emit_unmatched(1089&mut self,1090mut send: Sender<Morsel>,1091params: &EquiJoinParams,1092num_pipelines: usize,1093) -> PolarsResult<()> {1094let total_len: usize = self1095.partitions1096.iter()1097.map(|p| p.hash_table.num_keys() as usize)1098.sum();1099let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1);1100let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines);1101let morsel_size = total_len.div_ceil(morsel_count).max(1);11021103let wait_group = WaitGroup::default();1104let source_token = SourceToken::new();1105let mut unmarked_idxs = Vec::new();1106while let Some(p) = self.partitions.get(self.active_partition_idx) {1107loop {1108// Generate a chunk of unmarked key indices.1109self.offset_in_active_p += p.hash_table.unmarked_keys(1110&mut 
// (review note: this chunk opens mid-call — the arguments below complete the
// `unmarked_keys(&mut ...)` call begun in the previous chunk)
                    unmarked_idxs,
                    self.offset_in_active_p as IdxSize,
                    morsel_size as IdxSize,
                ) as usize;
                // Partition exhausted: move on to the next one.
                if unmarked_idxs.is_empty() {
                    break;
                }

                // Gather and create full-null counterpart.
                let out_df = unsafe {
                    let mut build_df = p.payload.take_slice_unchecked_impl(&unmarked_idxs, false);
                    let len = build_df.height();
                    if params.left_is_build.unwrap() {
                        let probe_df = DataFrame::full_null(&params.right_payload_schema, len);
                        build_df.hstack_mut_unchecked(probe_df.get_columns());
                        build_df
                    } else {
                        let mut probe_df = DataFrame::full_null(&params.left_payload_schema, len);
                        probe_df.hstack_mut_unchecked(build_df.get_columns());
                        probe_df
                    }
                };
                let out_df = postprocess_join(out_df, params);

                // Send and wait until consume token is consumed.
                let mut morsel = Morsel::new(out_df, self.morsel_seq, source_token.clone());
                self.morsel_seq = self.morsel_seq.successor();
                morsel.set_consume_token(wait_group.token());
                if send.send(morsel).await.is_err() {
                    return Ok(());
                }

                wait_group.wait().await;
                if source_token.stop_requested() {
                    return Ok(());
                }
            }

            // Advance the resumable cursor to the next partition.
            self.active_partition_idx += 1;
            self.offset_in_active_p = 0;
        }

        Ok(())
    }
}

/// The states this node moves through:
/// Sample -> Build -> Probe -> (EmitUnmatchedBuild | EmitUnmatchedBuildInOrder) -> Done.
enum EquiJoinState {
    /// Buffer rows from both sides until we can decide which side to build.
    Sample(SampleState),
    /// Sinking the build side into per-pipeline hash-table builders.
    Build(BuildState),
    /// Probing the finalized tables with the probe side.
    Probe(ProbeState),
    /// Streaming out unmatched build-side rows (unordered).
    EmitUnmatchedBuild(EmitUnmatchedState),
    /// Emitting unmatched build-side rows in order, via a pre-materialized
    /// in-memory source.
    EmitUnmatchedBuildInOrder(InMemorySourceNode),
    Done,
}

pub struct EquiJoinNode {
    state: EquiJoinState,
    params: EquiJoinParams,
    table: Box<dyn IdxTable>,
}

impl EquiJoinNode {
    /// Creates a new streaming equi-join node.
    ///
    /// `left_is_build` is decided up-front when `maintain_order` dictates it
    /// (or when sampling is disabled via a zero `JOIN_SAMPLE_LIMIT`);
    /// otherwise it stays `None` and is settled after sampling both inputs.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        left_input_schema: Arc<Schema>,
        right_input_schema: Arc<Schema>,
        left_key_schema: Arc<Schema>,
        right_key_schema: Arc<Schema>,
        unique_key_schema: Arc<Schema>,
        left_key_selectors: Vec<StreamExpr>,
        right_key_selectors: Vec<StreamExpr>,
        args: JoinArgs,
        num_pipelines: usize,
    ) -> PolarsResult<Self> {
        // To preserve a side's order we must probe with it, i.e. build the
        // other side.
        let left_is_build = match args.maintain_order {
            MaintainOrderJoin::None => {
                if *JOIN_SAMPLE_LIMIT == 0 {
                    Some(true)
                } else {
                    None
                }
            },
            MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => Some(false),
            MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => Some(true),
        };

        let preserve_order_probe = args.maintain_order != MaintainOrderJoin::None;
        let preserve_order_build = matches!(
            args.maintain_order,
            MaintainOrderJoin::LeftRight | MaintainOrderJoin::RightLeft
        );

        let left_payload_select = compute_payload_selector(
            &left_input_schema,
            &right_input_schema,
            &left_key_schema,
            &right_key_schema,
            true,
            &args,
        )?;
        let right_payload_select = compute_payload_selector(
            &right_input_schema,
            &left_input_schema,
            &right_key_schema,
            &left_key_schema,
            false,
            &args,
        )?;

        // Skip the sampling phase entirely when the build side is already known.
        let state = if left_is_build.is_some() {
            EquiJoinState::Build(BuildState::new(
                num_pipelines,
                num_pipelines,
                BufferedStream::default(),
            ))
        } else {
            EquiJoinState::Sample(SampleState::default())
        };

        let left_payload_schema = Arc::new(select_schema(&left_input_schema, &left_payload_select));
        let right_payload_schema =
            Arc::new(select_schema(&right_input_schema, &right_payload_select));
        Ok(Self {
            state,
            params: EquiJoinParams {
                left_is_build,
                preserve_order_build,
                preserve_order_probe,
                left_key_schema,
                left_key_selectors,
                right_key_schema,
                right_key_selectors,
                left_payload_select,
                right_payload_select,
                left_payload_schema,
                right_payload_schema,
                args,
                random_state: PlRandomState::default(),
            },
            table: new_idx_table(unique_key_schema),
        })
    }
}

impl ComputeNode for EquiJoinNode {
    fn name(&self) -> &str {
        "equi-join"
    }

    /// Performs state-machine transitions and then configures the port states
    /// for the current state (see the `match` at the end of the body).
    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        // Two inputs (0 = left, 1 = right), one output.
        assert!(recv.len() == 2 && send.len() == 
// (review note: the line below completes the `assert!` begun in the previous
// chunk)
        1);

        // If the output doesn't want any more data, transition to being done.
        if send[0] == PortState::Done {
            self.state = EquiJoinState::Done;
        }

        // If we are sampling and both sides are done/filled, transition to building.
        if let EquiJoinState::Sample(sample_state) = &mut self.state {
            if let Some(build_state) =
                sample_state.try_transition_to_build(recv, &mut self.params, state)?
            {
                self.state = EquiJoinState::Build(build_state);
            }
        }

        // Port 0 carries the left input, port 1 the right input.
        let build_idx = if self.params.left_is_build == Some(true) {
            0
        } else {
            1
        };
        let probe_idx = 1 - build_idx;

        // If we are building and the build input is done, transition to probing.
        if let EquiJoinState::Build(build_state) = &mut self.state {
            if recv[build_idx] == PortState::Done {
                let probe_state = if self.params.preserve_order_build {
                    build_state.finalize_ordered(&self.params, &*self.table)
                } else {
                    build_state.finalize_unordered(&self.params, &*self.table)
                };
                self.state = EquiJoinState::Probe(probe_state);
            }
        }

        // If we are probing and the probe input is done, emit unmatched if
        // necessary, otherwise we're done.
        if let EquiJoinState::Probe(probe_state) = &mut self.state {
            let samples_consumed = probe_state.sampled_probe_morsels.is_empty();
            if samples_consumed && recv[probe_idx] == PortState::Done {
                if self.params.emit_unmatched_build() {
                    if self.params.preserve_order_build {
                        // Ordered case: materialize the unmatched rows now and
                        // serve them from an in-memory source node.
                        let unmatched = probe_state.ordered_unmatched(&self.params);
                        let src = InMemorySourceNode::new(
                            Arc::new(unmatched),
                            probe_state.max_seq_sent.successor(),
                        );
                        self.state = EquiJoinState::EmitUnmatchedBuildInOrder(src);
                    } else {
                        // Unordered case: hand the partitions to a resumable
                        // emit state, continuing the morsel sequence.
                        self.state = EquiJoinState::EmitUnmatchedBuild(EmitUnmatchedState {
                            partitions: core::mem::take(&mut probe_state.table_per_partition),
                            active_partition_idx: 0,
                            offset_in_active_p: 0,
                            morsel_seq: probe_state.max_seq_sent.successor(),
                        });
                    }
                } else {
                    self.state = EquiJoinState::Done;
                }
            }
        }

        // Finally, check if we are done emitting unmatched keys.
        if let EquiJoinState::EmitUnmatchedBuild(emit_state) = &mut self.state {
            if emit_state.active_partition_idx >= emit_state.partitions.len() {
                self.state = EquiJoinState::Done;
            }
        }

        // Configure the port states for whatever state we ended up in.
        match &mut self.state {
            EquiJoinState::Sample(sample_state) => {
                send[0] = PortState::Blocked;
                // Each side keeps receiving until its sample limit is reached.
                if recv[0] != PortState::Done {
                    recv[0] = if sample_state.left_len < *JOIN_SAMPLE_LIMIT {
                        PortState::Ready
                    } else {
                        PortState::Blocked
                    };
                }
                if recv[1] != PortState::Done {
                    recv[1] = if sample_state.right_len < *JOIN_SAMPLE_LIMIT {
                        PortState::Ready
                    } else {
                        PortState::Blocked
                    };
                }
            },
            EquiJoinState::Build(_) => {
                send[0] = PortState::Blocked;
                if recv[build_idx] != PortState::Done {
                    recv[build_idx] = PortState::Ready;
                }
                // The probe side waits until the table is built.
                if recv[probe_idx] != PortState::Done {
                    recv[probe_idx] = PortState::Blocked;
                }
            },
            EquiJoinState::Probe(probe_state) => {
                if recv[probe_idx] != PortState::Done {
                    // Pass the output's demand straight through to the probe input.
                    core::mem::swap(&mut send[0], &mut recv[probe_idx]);
                } else {
                    // Input is done, but buffered sample morsels may remain.
                    let samples_consumed = probe_state.sampled_probe_morsels.is_empty();
                    send[0] = if samples_consumed {
                        PortState::Done
                    } else {
                        PortState::Ready
                    };
                }
                recv[build_idx] = PortState::Done;
            },
            EquiJoinState::EmitUnmatchedBuild(_) => {
                send[0] = PortState::Ready;
                recv[build_idx] = PortState::Done;
                recv[probe_idx] = PortState::Done;
            },
            EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => {
                recv[build_idx] = PortState::Done;
                recv[probe_idx] = PortState::Done;
                // Delegate the output port to the in-memory source node.
                src_node.update_state(&mut [], &mut send[0..1], state)?;
                if send[0] == PortState::Done {
                    self.state = EquiJoinState::Done;
                }
            },
            EquiJoinState::Done => {
                send[0] = PortState::Done;
                recv[0] = PortState::Done;
                recv[1] = PortState::Done;
            },
        }
        Ok(())
    }

    // The sampling and building phases buffer their input in memory, so they
    // block memory-intensive pipelines.
    fn is_memory_intensive_pipeline_blocker(&self) -> bool 
{
        matches!(
            self.state,
            EquiJoinState::Sample { .. } | EquiJoinState::Build { .. }
        )
    }

    /// Spawns the worker tasks for the current state onto `scope`, wiring the
    /// input/output ports to them. Which ports must be present/absent per
    /// state mirrors the port configuration set in `update_state`.
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv_ports.len() == 2);
        assert!(send_ports.len() == 1);

        // Port 0 carries the left input, port 1 the right input.
        let build_idx = if self.params.left_is_build == Some(true) {
            0
        } else {
            1
        };
        let probe_idx = 1 - build_idx;

        match &mut self.state {
            EquiJoinState::Sample(sample_state) => {
                assert!(send_ports[0].is_none());
                // A side whose port is absent is finished; publish its final
                // length so the other side's sink can react, otherwise
                // usize::MAX marks it as still unknown.
                let left_final_len = Arc::new(RelaxedCell::from(if recv_ports[0].is_none() {
                    sample_state.left_len
                } else {
                    usize::MAX
                }));
                let right_final_len = Arc::new(RelaxedCell::from(if recv_ports[1].is_none() {
                    sample_state.right_len
                } else {
                    usize::MAX
                }));

                if let Some(left_recv) = recv_ports[0].take() {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        SampleState::sink(
                            left_recv.serial(),
                            &mut sample_state.left,
                            &mut sample_state.left_len,
                            left_final_len.clone(),
                            right_final_len.clone(),
                        ),
                    ));
                }
                if let Some(right_recv) = recv_ports[1].take() {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        SampleState::sink(
                            right_recv.serial(),
                            &mut sample_state.right,
                            &mut sample_state.right_len,
                            right_final_len,
                            left_final_len,
                        ),
                    ));
                }
            },
            EquiJoinState::Build(build_state) => {
                assert!(send_ports[0].is_none());
                assert!(recv_ports[probe_idx].is_none());
                let receivers = recv_ports[build_idx].take().unwrap().parallel();

                // One partition-and-sink task per pipeline, each feeding its
                // own local builder.
                let partitioner = HashPartitioner::new(state.num_pipelines, 0);
                for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers)
                {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        BuildState::partition_and_sink(
                            recv,
                            local_builder,
                            partitioner.clone(),
                            &self.params,
                            state,
                        ),
                    ));
                }
            },
            EquiJoinState::Probe(probe_state) => {
                assert!(recv_ports[build_idx].is_none());
                let senders = send_ports[0].take().unwrap().parallel();
                // Replay any morsels buffered during sampling ahead of the
                // live probe input.
                let receivers = probe_state
                    .sampled_probe_morsels
                    .reinsert(
                        state.num_pipelines,
                        recv_ports[probe_idx].take(),
                        scope,
                        join_handles,
                    )
                    .unwrap();

                let partitioner = HashPartitioner::new(state.num_pipelines, 0);
                let probe_tasks = receivers
                    .into_iter()
                    .zip(senders)
                    .map(|(recv, send)| {
                        scope.spawn_task(
                            TaskPriority::High,
                            ProbeState::partition_and_probe(
                                recv,
                                send,
                                &probe_state.table_per_partition,
                                &probe_state.unordered_morsel_seq,
                                partitioner.clone(),
                                &self.params,
                                state,
                            ),
                        )
                    })
                    .collect_vec();

                // Collect the maximum morsel sequence sent over all probe
                // tasks, needed to continue the sequence when emitting
                // unmatched rows later.
                let max_seq_sent = &mut probe_state.max_seq_sent;
                join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                    for probe_task in probe_tasks {
                        *max_seq_sent = (*max_seq_sent).max(probe_task.await?);
                    }
                    Ok(())
                }));
            },
            EquiJoinState::EmitUnmatchedBuild(emit_state) => {
                assert!(recv_ports[build_idx].is_none());
                assert!(recv_ports[probe_idx].is_none());
                let send = send_ports[0].take().unwrap().serial();
                join_handles.push(scope.spawn_task(
                    TaskPriority::Low,
                    emit_state.emit_unmatched(send, &self.params, state.num_pipelines),
                ));
            },
            EquiJoinState::EmitUnmatchedBuildInOrder(src_node) => {
                assert!(recv_ports[build_idx].is_none());
                assert!(recv_ports[probe_idx].is_none());
                // Delegate output entirely to the in-memory source.
                src_node.spawn(scope, &mut [], send_ports, state, join_handles);
            },
            // update_state never leaves us runnable in Done.
            EquiJoinState::Done => unreachable!(),
        }
    }
}