Path: blob/main/crates/polars-stream/src/nodes/joins/semi_anti_join.rs
6939 views
use std::sync::Arc;12use arrow::array::BooleanArray;3use arrow::bitmap::BitmapBuilder;4use polars_core::prelude::*;5use polars_core::schema::Schema;6use polars_expr::groups::{Grouper, new_hash_grouper};7use polars_expr::hash_keys::HashKeys;8use polars_ops::frame::{JoinArgs, JoinType};9use polars_utils::IdxSize;10use polars_utils::cardinality_sketch::CardinalitySketch;11use polars_utils::hashing::HashPartitioner;12use polars_utils::itertools::Itertools;13use polars_utils::sparse_init_vec::SparseInitVec;1415use crate::async_executor;16use crate::async_primitives::connector::{Receiver, Sender};17use crate::expression::StreamExpr;18use crate::nodes::compute_node_prelude::*;1920async fn select_keys(21df: &DataFrame,22key_selectors: &[StreamExpr],23params: &SemiAntiJoinParams,24state: &ExecutionState,25) -> PolarsResult<HashKeys> {26let mut key_columns = Vec::new();27for selector in key_selectors {28key_columns.push(selector.evaluate(df, state).await?.into_column());29}30let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?;31Ok(HashKeys::from_df(32&keys,33params.random_state,34params.nulls_equal,35false,36))37}3839struct SemiAntiJoinParams {40left_is_build: bool,41left_key_selectors: Vec<StreamExpr>,42right_key_selectors: Vec<StreamExpr>,43nulls_equal: bool,44is_anti: bool,45return_bool: bool,46random_state: PlRandomState,47}4849pub struct SemiAntiJoinNode {50state: SemiAntiJoinState,51params: SemiAntiJoinParams,52grouper: Box<dyn Grouper>,53}5455impl SemiAntiJoinNode {56pub fn new(57unique_key_schema: Arc<Schema>,58left_key_selectors: Vec<StreamExpr>,59right_key_selectors: Vec<StreamExpr>,60args: JoinArgs,61return_bool: bool,62num_pipelines: usize,63) -> PolarsResult<Self> {64let left_is_build = false;65let is_anti = args.how == JoinType::Anti;6667let state = SemiAntiJoinState::Build(BuildState::new(num_pipelines, num_pipelines));6869Ok(Self {70state,71params: SemiAntiJoinParams {72left_is_build,73left_key_selectors,74right_key_selectors,75random_state: PlRandomState::default(),76nulls_equal: args.nulls_equal,77return_bool,78is_anti,79},80grouper: new_hash_grouper(unique_key_schema),81})82}83}8485enum SemiAntiJoinState {86Build(BuildState),87Probe(ProbeState),88Done,89}9091#[derive(Default)]92struct LocalBuilder {93// The complete list of keys as seen by this builder.94keys: Vec<HashKeys>,9596// A cardinality sketch per partition for the keys seen by this builder.97sketch_per_p: Vec<CardinalitySketch>,9899// key_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i]100// for partition p, where start, stop are:101// let start = key_idxs_offsets[i * num_partitions + p];102// let stop = key_idxs_offsets[(i + 1) * num_partitions + p];103key_idxs_values_per_p: Vec<Vec<IdxSize>>,104key_idxs_offsets_per_p: Vec<usize>,105}106107struct BuildState {108local_builders: Vec<LocalBuilder>,109}110111impl BuildState {112fn new(num_pipelines: usize, num_partitions: usize) -> Self {113let local_builders = (0..num_pipelines)114.map(|_| LocalBuilder {115keys: Vec::new(),116sketch_per_p: vec![CardinalitySketch::default(); num_partitions],117key_idxs_values_per_p: vec![Vec::new(); num_partitions],118key_idxs_offsets_per_p: vec![0; num_partitions],119})120.collect();121Self { local_builders }122}123124async fn partition_and_sink(125mut recv: Receiver<Morsel>,126local: &mut LocalBuilder,127partitioner: HashPartitioner,128params: &SemiAntiJoinParams,129state: &StreamingExecutionState,130) -> PolarsResult<()> {131let key_selectors = if params.left_is_build {132¶ms.left_key_selectors133} else {134¶ms.right_key_selectors135};136137while let Ok(morsel) = recv.recv().await {138let hash_keys = select_keys(139morsel.df(),140key_selectors,141params,142&state.in_memory_exec_state,143)144.await?;145146hash_keys.gen_idxs_per_partition(147&partitioner,148&mut local.key_idxs_values_per_p,149&mut local.sketch_per_p,150false,151);152153local154.key_idxs_offsets_per_p155.extend(local.key_idxs_values_per_p.iter().map(|vp| vp.len()));156local.keys.push(hash_keys);157}158Ok(())159}160161fn finalize(&mut self, grouper: &dyn Grouper) -> ProbeState {162// To reduce maximum memory usage we want to drop the original keys163// as soon as they're processed, so we move into Arcs. The drops might164// also be expensive, so instead of directly dropping we put that on165// a work queue.166let keys_per_local_builder = self167.local_builders168.iter_mut()169.map(|b| Arc::new(core::mem::take(&mut b.keys)))170.collect_vec();171let (key_drop_q_send, key_drop_q_recv) =172async_channel::bounded(keys_per_local_builder.len());173let num_partitions = self.local_builders[0].sketch_per_p.len();174let local_builders = &self.local_builders;175let groupers: SparseInitVec<Box<dyn Grouper>> =176SparseInitVec::with_capacity(num_partitions);177178async_executor::task_scope(|s| {179// Wrap in outer Arc to move to each thread, performing the180// expensive clone on that thread.181let arc_keys_per_local_builder = Arc::new(keys_per_local_builder);182let mut join_handles = Vec::new();183for p in 0..num_partitions {184let arc_keys_per_local_builder = Arc::clone(&arc_keys_per_local_builder);185let key_drop_q_send = key_drop_q_send.clone();186let key_drop_q_recv = key_drop_q_recv.clone();187let groupers = &groupers;188join_handles.push(s.spawn_task(TaskPriority::High, async move {189// Extract from outer arc and drop outer arc.190let keys_per_local_builder = Arc::unwrap_or_clone(arc_keys_per_local_builder);191192// Compute cardinality estimate.193let mut sketch = CardinalitySketch::new();194for l in local_builders {195sketch.combine(&l.sketch_per_p[p]);196}197198// Allocate hash table.199let mut p_grouper = grouper.new_empty();200p_grouper.reserve(sketch.estimate() * 5 / 4);201202// Build.203let mut skip_drop_attempt = false;204for (l, l_keys) in local_builders.iter().zip(keys_per_local_builder) {205// Try to help with dropping the processed keys.206if !skip_drop_attempt {207drop(key_drop_q_recv.try_recv());208}209210for (i, keys) in l_keys.iter().enumerate() {211unsafe {212let p_key_idxs_start =213l.key_idxs_offsets_per_p[i * num_partitions + p];214let p_key_idxs_stop =215l.key_idxs_offsets_per_p[(i + 1) * num_partitions + p];216let p_key_idxs =217&l.key_idxs_values_per_p[p][p_key_idxs_start..p_key_idxs_stop];218p_grouper.insert_keys_subset(keys, p_key_idxs, None);219}220}221222if let Some(l) = Arc::into_inner(l_keys) {223// If we're the last thread to process this set of keys we're probably224// falling behind the rest, since the drop can be quite expensive we skip225// a drop attempt hoping someone else will pick up the slack.226drop(key_drop_q_send.try_send(l));227skip_drop_attempt = true;228} else {229skip_drop_attempt = false;230}231}232233// We're done, help others out by doing drops.234drop(key_drop_q_send); // So we don't deadlock trying to receive from ourselves.235while let Ok(l_keys) = key_drop_q_recv.recv().await {236drop(l_keys);237}238239groupers.try_set(p, p_grouper).ok().unwrap();240}));241}242243// Drop outer arc after spawning each thread so the inner arcs244// can get dropped as soon as they're processed. We also have to245// drop the drop queue sender so we don't deadlock waiting for it246// to end.247drop(arc_keys_per_local_builder);248drop(key_drop_q_send);249250polars_io::pl_async::get_runtime().block_on(async move {251for handle in join_handles {252handle.await;253}254});255});256257ProbeState {258grouper_per_partition: groupers.try_assume_init().ok().unwrap(),259}260}261}262263struct ProbeState {264grouper_per_partition: Vec<Box<dyn Grouper>>,265}266267impl ProbeState {268/// Returns the max morsel sequence sent.269async fn partition_and_probe(270mut recv: Receiver<Morsel>,271mut send: Sender<Morsel>,272partitions: &[Box<dyn Grouper>],273partitioner: HashPartitioner,274params: &SemiAntiJoinParams,275state: &StreamingExecutionState,276) -> PolarsResult<()> {277let mut probe_match = Vec::new();278let key_selectors = if params.left_is_build {279¶ms.right_key_selectors280} else {281¶ms.left_key_selectors282};283284while let Ok(morsel) = recv.recv().await {285let (df, in_seq, src_token, wait_token) = morsel.into_inner();286if df.height() == 0 {287continue;288}289290let hash_keys =291select_keys(&df, key_selectors, params, &state.in_memory_exec_state).await?;292293unsafe {294let out_df = if params.return_bool {295let mut builder = BitmapBuilder::with_capacity(df.height());296partitions[0].contains_key_partitioned_groupers(297partitions,298&hash_keys,299&partitioner,300params.is_anti,301&mut builder,302);303let mut arr = BooleanArray::from(builder.freeze());304if !params.nulls_equal {305arr.set_validity(hash_keys.validity().cloned());306}307let s = BooleanChunked::with_chunk(df[0].name().clone(), arr).into_series();308DataFrame::new(vec![Column::from(s)])?309} else {310probe_match.clear();311partitions[0].probe_partitioned_groupers(312partitions,313&hash_keys,314&partitioner,315params.is_anti,316&mut probe_match,317);318if probe_match.is_empty() {319continue;320}321df.take_slice_unchecked(&probe_match)322};323324let mut morsel = Morsel::new(out_df, in_seq, src_token.clone());325if let Some(token) = wait_token {326morsel.set_consume_token(token);327}328if send.send(morsel).await.is_err() {329return Ok(());330}331}332}333334Ok(())335}336}337338impl ComputeNode for SemiAntiJoinNode {339fn name(&self) -> &str {340match (self.params.return_bool, self.params.is_anti) {341(false, false) => "semi-join",342(false, true) => "anti-join",343(true, false) => "is-in",344(true, true) => "is-not-in",345}346}347348fn update_state(349&mut self,350recv: &mut [PortState],351send: &mut [PortState],352_state: &StreamingExecutionState,353) -> PolarsResult<()> {354assert!(recv.len() == 2 && send.len() == 1);355356// If the output doesn't want any more data, transition to being done.357if send[0] == PortState::Done {358self.state = SemiAntiJoinState::Done;359}360361let build_idx = if self.params.left_is_build { 0 } else { 1 };362let probe_idx = 1 - build_idx;363364// If we are building and the build input is done, transition to probing.365if let SemiAntiJoinState::Build(build_state) = &mut self.state {366if recv[build_idx] == PortState::Done {367let probe_state = build_state.finalize(&*self.grouper);368self.state = SemiAntiJoinState::Probe(probe_state);369}370}371372// If we are probing and the probe input is done, we're done.373if let SemiAntiJoinState::Probe(_) = &mut self.state {374if recv[probe_idx] == PortState::Done {375self.state = SemiAntiJoinState::Done;376}377}378379match &mut self.state {380SemiAntiJoinState::Build(_) => {381send[0] = PortState::Blocked;382if recv[build_idx] != PortState::Done {383recv[build_idx] = PortState::Ready;384}385if recv[probe_idx] != PortState::Done {386recv[probe_idx] = PortState::Blocked;387}388},389SemiAntiJoinState::Probe(_) => {390if recv[probe_idx] != PortState::Done {391core::mem::swap(&mut send[0], &mut recv[probe_idx]);392} else {393send[0] = PortState::Done;394}395recv[build_idx] = PortState::Done;396},397SemiAntiJoinState::Done => {398send[0] = PortState::Done;399recv[0] = PortState::Done;400recv[1] = PortState::Done;401},402}403Ok(())404}405406fn is_memory_intensive_pipeline_blocker(&self) -> bool {407matches!(self.state, SemiAntiJoinState::Build { .. })408}409410fn spawn<'env, 's>(411&'env mut self,412scope: &'s TaskScope<'s, 'env>,413recv_ports: &mut [Option<RecvPort<'_>>],414send_ports: &mut [Option<SendPort<'_>>],415state: &'s StreamingExecutionState,416join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,417) {418assert!(recv_ports.len() == 2);419assert!(send_ports.len() == 1);420421let build_idx = if self.params.left_is_build { 0 } else { 1 };422let probe_idx = 1 - build_idx;423424match &mut self.state {425SemiAntiJoinState::Build(build_state) => {426assert!(send_ports[0].is_none());427assert!(recv_ports[probe_idx].is_none());428let receivers = recv_ports[build_idx].take().unwrap().parallel();429430let partitioner = HashPartitioner::new(state.num_pipelines, 0);431for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {432join_handles.push(scope.spawn_task(433TaskPriority::High,434BuildState::partition_and_sink(435recv,436local_builder,437partitioner.clone(),438&self.params,439state,440),441));442}443},444SemiAntiJoinState::Probe(probe_state) => {445assert!(recv_ports[build_idx].is_none());446let senders = send_ports[0].take().unwrap().parallel();447let receivers = recv_ports[probe_idx].take().unwrap().parallel();448449let partitioner = HashPartitioner::new(state.num_pipelines, 0);450for (recv, send) in receivers.into_iter().zip(senders) {451join_handles.push(scope.spawn_task(452TaskPriority::High,453ProbeState::partition_and_probe(454recv,455send,456&probe_state.grouper_per_partition,457partitioner.clone(),458&self.params,459state,460),461));462}463},464SemiAntiJoinState::Done => unreachable!(),465}466}467}468469470