Path: blob/main/crates/polars-stream/src/nodes/joins/semi_anti_join.rs
8512 views
use std::sync::Arc;12use arrow::array::BooleanArray;3use arrow::bitmap::BitmapBuilder;4use polars_core::prelude::*;5use polars_core::schema::Schema;6use polars_expr::groups::{Grouper, new_hash_grouper};7use polars_expr::hash_keys::HashKeys;8use polars_ops::frame::{JoinArgs, JoinType};9use polars_utils::IdxSize;10use polars_utils::cardinality_sketch::CardinalitySketch;11use polars_utils::hashing::HashPartitioner;12use polars_utils::itertools::Itertools;13use polars_utils::sparse_init_vec::SparseInitVec;1415use crate::async_executor;16use crate::expression::StreamExpr;17use crate::nodes::compute_node_prelude::*;1819async fn select_keys(20df: &DataFrame,21key_selectors: &[StreamExpr],22params: &SemiAntiJoinParams,23state: &ExecutionState,24) -> PolarsResult<HashKeys> {25let mut key_columns = Vec::new();26for selector in key_selectors {27key_columns.push(selector.evaluate(df, state).await?.into_column());28}29let keys = unsafe { DataFrame::new_unchecked_with_broadcast(df.height(), key_columns) }?;30Ok(HashKeys::from_df(31&keys,32params.random_state.clone(),33params.nulls_equal,34false,35))36}3738struct SemiAntiJoinParams {39left_is_build: bool,40left_key_selectors: Vec<StreamExpr>,41right_key_selectors: Vec<StreamExpr>,42nulls_equal: bool,43is_anti: bool,44return_bool: bool,45random_state: PlRandomState,46}4748pub struct SemiAntiJoinNode {49state: SemiAntiJoinState,50params: SemiAntiJoinParams,51grouper: Box<dyn Grouper>,52}5354impl SemiAntiJoinNode {55pub fn new(56unique_key_schema: Arc<Schema>,57left_key_selectors: Vec<StreamExpr>,58right_key_selectors: Vec<StreamExpr>,59args: JoinArgs,60return_bool: bool,61num_pipelines: usize,62) -> PolarsResult<Self> {63let left_is_build = false;64let is_anti = args.how == JoinType::Anti;6566let state = SemiAntiJoinState::Build(BuildState::new(num_pipelines, num_pipelines));6768Ok(Self {69state,70params: SemiAntiJoinParams {71left_is_build,72left_key_selectors,73right_key_selectors,74random_state: PlRandomState::default(),75nulls_equal: args.nulls_equal,76return_bool,77is_anti,78},79grouper: new_hash_grouper(unique_key_schema),80})81}82}8384enum SemiAntiJoinState {85Build(BuildState),86Probe(ProbeState),87Done,88}8990#[derive(Default)]91struct LocalBuilder {92// The complete list of keys as seen by this builder.93keys: Vec<HashKeys>,9495// A cardinality sketch per partition for the keys seen by this builder.96sketch_per_p: Vec<CardinalitySketch>,9798// key_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i]99// for partition p, where start, stop are:100// let start = key_idxs_offsets[i * num_partitions + p];101// let stop = key_idxs_offsets[(i + 1) * num_partitions + p];102key_idxs_values_per_p: Vec<Vec<IdxSize>>,103key_idxs_offsets_per_p: Vec<usize>,104}105106struct BuildState {107local_builders: Vec<LocalBuilder>,108}109110impl BuildState {111fn new(num_pipelines: usize, num_partitions: usize) -> Self {112let local_builders = (0..num_pipelines)113.map(|_| LocalBuilder {114keys: Vec::new(),115sketch_per_p: vec![CardinalitySketch::default(); num_partitions],116key_idxs_values_per_p: vec![Vec::new(); num_partitions],117key_idxs_offsets_per_p: vec![0; num_partitions],118})119.collect();120Self { local_builders }121}122123async fn partition_and_sink(124mut recv: PortReceiver,125local: &mut LocalBuilder,126partitioner: HashPartitioner,127params: &SemiAntiJoinParams,128state: &StreamingExecutionState,129) -> PolarsResult<()> {130let key_selectors = if params.left_is_build {131¶ms.left_key_selectors132} else {133¶ms.right_key_selectors134};135136while let Ok(morsel) = recv.recv().await {137let hash_keys = select_keys(138morsel.df(),139key_selectors,140params,141&state.in_memory_exec_state,142)143.await?;144145hash_keys.gen_idxs_per_partition(146&partitioner,147&mut local.key_idxs_values_per_p,148&mut local.sketch_per_p,149false,150);151152local153.key_idxs_offsets_per_p154.extend(local.key_idxs_values_per_p.iter().map(|vp| vp.len()));155local.keys.push(hash_keys);156}157Ok(())158}159160fn finalize(&mut self, grouper: &dyn Grouper) -> ProbeState {161// To reduce maximum memory usage we want to drop the original keys162// as soon as they're processed, so we move into Arcs. The drops might163// also be expensive, so instead of directly dropping we put that on164// a work queue.165let keys_per_local_builder = self166.local_builders167.iter_mut()168.map(|b| Arc::new(core::mem::take(&mut b.keys)))169.collect_vec();170let (key_drop_q_send, key_drop_q_recv) =171async_channel::bounded(keys_per_local_builder.len());172let num_partitions = self.local_builders[0].sketch_per_p.len();173let local_builders = &self.local_builders;174let groupers: SparseInitVec<Box<dyn Grouper>> =175SparseInitVec::with_capacity(num_partitions);176177async_executor::task_scope(|s| {178// Wrap in outer Arc to move to each thread, performing the179// expensive clone on that thread.180let arc_keys_per_local_builder = Arc::new(keys_per_local_builder);181let mut join_handles = Vec::new();182for p in 0..num_partitions {183let arc_keys_per_local_builder = Arc::clone(&arc_keys_per_local_builder);184let key_drop_q_send = key_drop_q_send.clone();185let key_drop_q_recv = key_drop_q_recv.clone();186let groupers = &groupers;187join_handles.push(s.spawn_task(TaskPriority::High, async move {188// Extract from outer arc and drop outer arc.189let keys_per_local_builder = Arc::unwrap_or_clone(arc_keys_per_local_builder);190191// Compute cardinality estimate.192let mut sketch = CardinalitySketch::new();193for l in local_builders {194sketch.combine(&l.sketch_per_p[p]);195}196197// Allocate hash table.198let mut p_grouper = grouper.new_empty();199p_grouper.reserve(sketch.estimate() * 5 / 4);200201// Build.202let mut skip_drop_attempt = false;203for (l, l_keys) in local_builders.iter().zip(keys_per_local_builder) {204// Try to help with dropping the processed keys.205if !skip_drop_attempt {206drop(key_drop_q_recv.try_recv());207}208209for (i, keys) in l_keys.iter().enumerate() {210unsafe {211let p_key_idxs_start =212l.key_idxs_offsets_per_p[i * num_partitions + p];213let p_key_idxs_stop =214l.key_idxs_offsets_per_p[(i + 1) * num_partitions + p];215let p_key_idxs =216&l.key_idxs_values_per_p[p][p_key_idxs_start..p_key_idxs_stop];217p_grouper.insert_keys_subset(keys, p_key_idxs, None);218}219}220221if let Some(l) = Arc::into_inner(l_keys) {222// If we're the last thread to process this set of keys we're probably223// falling behind the rest, since the drop can be quite expensive we skip224// a drop attempt hoping someone else will pick up the slack.225drop(key_drop_q_send.try_send(l));226skip_drop_attempt = true;227} else {228skip_drop_attempt = false;229}230}231232// We're done, help others out by doing drops.233drop(key_drop_q_send); // So we don't deadlock trying to receive from ourselves.234while let Ok(l_keys) = key_drop_q_recv.recv().await {235drop(l_keys);236}237238groupers.try_set(p, p_grouper).ok().unwrap();239}));240}241242// Drop outer arc after spawning each thread so the inner arcs243// can get dropped as soon as they're processed. We also have to244// drop the drop queue sender so we don't deadlock waiting for it245// to end.246drop(arc_keys_per_local_builder);247drop(key_drop_q_send);248249polars_io::pl_async::get_runtime().block_on(async move {250for handle in join_handles {251handle.await;252}253});254});255256ProbeState {257grouper_per_partition: groupers.try_assume_init().ok().unwrap(),258}259}260}261262struct ProbeState {263grouper_per_partition: Vec<Box<dyn Grouper>>,264}265266impl ProbeState {267/// Returns the max morsel sequence sent.268async fn partition_and_probe(269mut recv: PortReceiver,270mut send: PortSender,271partitions: &[Box<dyn Grouper>],272partitioner: HashPartitioner,273params: &SemiAntiJoinParams,274state: &StreamingExecutionState,275) -> PolarsResult<()> {276let mut probe_match = Vec::new();277let key_selectors = if params.left_is_build {278¶ms.right_key_selectors279} else {280¶ms.left_key_selectors281};282283while let Ok(morsel) = recv.recv().await {284let (df, in_seq, src_token, wait_token) = morsel.into_inner();285if df.height() == 0 {286continue;287}288289let hash_keys =290select_keys(&df, key_selectors, params, &state.in_memory_exec_state).await?;291292unsafe {293let out_df = if params.return_bool {294let mut builder = BitmapBuilder::with_capacity(df.height());295partitions[0].contains_key_partitioned_groupers(296partitions,297&hash_keys,298&partitioner,299params.is_anti,300&mut builder,301);302let mut arr = BooleanArray::from(builder.freeze());303if !params.nulls_equal {304arr.set_validity(hash_keys.validity().cloned());305}306let s = BooleanChunked::with_chunk(df[0].name().clone(), arr).into_series();307DataFrame::new_unchecked(s.len(), vec![Column::from(s)])308} else {309probe_match.clear();310partitions[0].probe_partitioned_groupers(311partitions,312&hash_keys,313&partitioner,314params.is_anti,315&mut probe_match,316);317if probe_match.is_empty() {318continue;319}320df.take_slice_unchecked(&probe_match)321};322323let mut morsel = Morsel::new(out_df, in_seq, src_token.clone());324if let Some(token) = wait_token {325morsel.set_consume_token(token);326}327if send.send(morsel).await.is_err() {328return Ok(());329}330}331}332333Ok(())334}335}336337impl ComputeNode for SemiAntiJoinNode {338fn name(&self) -> &str {339match (self.params.return_bool, self.params.is_anti) {340(false, false) => "semi-join",341(false, true) => "anti-join",342(true, false) => "is-in",343(true, true) => "is-not-in",344}345}346347fn update_state(348&mut self,349recv: &mut [PortState],350send: &mut [PortState],351_state: &StreamingExecutionState,352) -> PolarsResult<()> {353assert!(recv.len() == 2 && send.len() == 1);354355// If the output doesn't want any more data, transition to being done.356if send[0] == PortState::Done {357self.state = SemiAntiJoinState::Done;358}359360let build_idx = if self.params.left_is_build { 0 } else { 1 };361let probe_idx = 1 - build_idx;362363// If we are building and the build input is done, transition to probing.364if let SemiAntiJoinState::Build(build_state) = &mut self.state {365if recv[build_idx] == PortState::Done {366let probe_state = build_state.finalize(&*self.grouper);367self.state = SemiAntiJoinState::Probe(probe_state);368}369}370371// If we are probing and the probe input is done, we're done.372if let SemiAntiJoinState::Probe(_) = &mut self.state {373if recv[probe_idx] == PortState::Done {374self.state = SemiAntiJoinState::Done;375}376}377378match &mut self.state {379SemiAntiJoinState::Build(_) => {380send[0] = PortState::Blocked;381if recv[build_idx] != PortState::Done {382recv[build_idx] = PortState::Ready;383}384if recv[probe_idx] != PortState::Done {385recv[probe_idx] = PortState::Blocked;386}387},388SemiAntiJoinState::Probe(_) => {389if recv[probe_idx] != PortState::Done {390core::mem::swap(&mut send[0], &mut recv[probe_idx]);391} else {392send[0] = PortState::Done;393}394recv[build_idx] = PortState::Done;395},396SemiAntiJoinState::Done => {397send[0] = PortState::Done;398recv[0] = PortState::Done;399recv[1] = PortState::Done;400},401}402Ok(())403}404405fn is_memory_intensive_pipeline_blocker(&self) -> bool {406matches!(self.state, SemiAntiJoinState::Build { .. })407}408409fn spawn<'env, 's>(410&'env mut self,411scope: &'s TaskScope<'s, 'env>,412recv_ports: &mut [Option<RecvPort<'_>>],413send_ports: &mut [Option<SendPort<'_>>],414state: &'s StreamingExecutionState,415join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,416) {417assert!(recv_ports.len() == 2);418assert!(send_ports.len() == 1);419420let build_idx = if self.params.left_is_build { 0 } else { 1 };421let probe_idx = 1 - build_idx;422423match &mut self.state {424SemiAntiJoinState::Build(build_state) => {425assert!(send_ports[0].is_none());426assert!(recv_ports[probe_idx].is_none());427let receivers = recv_ports[build_idx].take().unwrap().parallel();428429let partitioner = HashPartitioner::new(state.num_pipelines, 0);430for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {431join_handles.push(scope.spawn_task(432TaskPriority::High,433BuildState::partition_and_sink(434recv,435local_builder,436partitioner.clone(),437&self.params,438state,439),440));441}442},443SemiAntiJoinState::Probe(probe_state) => {444assert!(recv_ports[build_idx].is_none());445let senders = send_ports[0].take().unwrap().parallel();446let receivers = recv_ports[probe_idx].take().unwrap().parallel();447448let partitioner = HashPartitioner::new(state.num_pipelines, 0);449for (recv, send) in receivers.into_iter().zip(senders) {450join_handles.push(scope.spawn_task(451TaskPriority::High,452ProbeState::partition_and_probe(453recv,454send,455&probe_state.grouper_per_partition,456partitioner.clone(),457&self.params,458state,459),460));461}462},463SemiAntiJoinState::Done => unreachable!(),464}465}466}467468469