Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/joins/semi_anti_join.rs
6939 views
1
use std::sync::Arc;
2
3
use arrow::array::BooleanArray;
4
use arrow::bitmap::BitmapBuilder;
5
use polars_core::prelude::*;
6
use polars_core::schema::Schema;
7
use polars_expr::groups::{Grouper, new_hash_grouper};
8
use polars_expr::hash_keys::HashKeys;
9
use polars_ops::frame::{JoinArgs, JoinType};
10
use polars_utils::IdxSize;
11
use polars_utils::cardinality_sketch::CardinalitySketch;
12
use polars_utils::hashing::HashPartitioner;
13
use polars_utils::itertools::Itertools;
14
use polars_utils::sparse_init_vec::SparseInitVec;
15
16
use crate::async_executor;
17
use crate::async_primitives::connector::{Receiver, Sender};
18
use crate::expression::StreamExpr;
19
use crate::nodes::compute_node_prelude::*;
20
21
async fn select_keys(
22
df: &DataFrame,
23
key_selectors: &[StreamExpr],
24
params: &SemiAntiJoinParams,
25
state: &ExecutionState,
26
) -> PolarsResult<HashKeys> {
27
let mut key_columns = Vec::new();
28
for selector in key_selectors {
29
key_columns.push(selector.evaluate(df, state).await?.into_column());
30
}
31
let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?;
32
Ok(HashKeys::from_df(
33
&keys,
34
params.random_state,
35
params.nulls_equal,
36
false,
37
))
38
}
39
40
/// Configuration shared by the build and probe phases of the semi/anti join.
struct SemiAntiJoinParams {
    // Whether the left input is the build side; `new` currently always sets
    // this to false (the right side is built).
    left_is_build: bool,
    // Key expressions evaluated against the left input.
    left_key_selectors: Vec<StreamExpr>,
    // Key expressions evaluated against the right input.
    right_key_selectors: Vec<StreamExpr>,
    // If true, null keys compare equal to each other.
    nulls_equal: bool,
    // If true this is an anti-join (keep non-matches); otherwise a semi-join.
    is_anti: bool,
    // If true, emit a boolean match column (is-in / is-not-in) instead of
    // filtering the probe rows.
    return_bool: bool,
    // Hashing seed state shared by both sides so keys hash consistently.
    random_state: PlRandomState,
}
49
50
/// Streaming compute node implementing semi/anti joins (and their scalar
/// variants is-in / is-not-in) via a build-then-probe state machine.
pub struct SemiAntiJoinNode {
    // Current phase: Build -> Probe -> Done.
    state: SemiAntiJoinState,
    // Shared configuration for both phases.
    params: SemiAntiJoinParams,
    // Template grouper; per-partition groupers are created from it via
    // `new_empty` during finalize.
    grouper: Box<dyn Grouper>,
}
55
56
impl SemiAntiJoinNode {
57
pub fn new(
58
unique_key_schema: Arc<Schema>,
59
left_key_selectors: Vec<StreamExpr>,
60
right_key_selectors: Vec<StreamExpr>,
61
args: JoinArgs,
62
return_bool: bool,
63
num_pipelines: usize,
64
) -> PolarsResult<Self> {
65
let left_is_build = false;
66
let is_anti = args.how == JoinType::Anti;
67
68
let state = SemiAntiJoinState::Build(BuildState::new(num_pipelines, num_pipelines));
69
70
Ok(Self {
71
state,
72
params: SemiAntiJoinParams {
73
left_is_build,
74
left_key_selectors,
75
right_key_selectors,
76
random_state: PlRandomState::default(),
77
nulls_equal: args.nulls_equal,
78
return_bool,
79
is_anti,
80
},
81
grouper: new_hash_grouper(unique_key_schema),
82
})
83
}
84
}
85
86
/// Phase of the join node's lifecycle.
enum SemiAntiJoinState {
    // Consuming the build side, accumulating keys per pipeline.
    Build(BuildState),
    // Build side finalized into per-partition groupers; probing the other side.
    Probe(ProbeState),
    // All inputs/outputs finished.
    Done,
}
91
92
/// Per-pipeline accumulator for build-side keys and their partitioning
/// metadata, merged across pipelines during `BuildState::finalize`.
#[derive(Default)]
struct LocalBuilder {
    // The complete list of keys as seen by this builder.
    keys: Vec<HashKeys>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // key_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i]
    // for partition p, where start, stop are:
    // let start = key_idxs_offsets[i * num_partitions + p];
    // let stop = key_idxs_offsets[(i + 1) * num_partitions + p];
    key_idxs_values_per_p: Vec<Vec<IdxSize>>,
    // Flattened (morsel-major) prefix lengths of key_idxs_values_per_p,
    // num_partitions entries appended per morsel; see indexing formula above.
    key_idxs_offsets_per_p: Vec<usize>,
}
107
108
/// Build-phase state: one `LocalBuilder` per pipeline, filled concurrently by
/// `partition_and_sink` and merged by `finalize`.
struct BuildState {
    local_builders: Vec<LocalBuilder>,
}
111
112
impl BuildState {
113
fn new(num_pipelines: usize, num_partitions: usize) -> Self {
114
let local_builders = (0..num_pipelines)
115
.map(|_| LocalBuilder {
116
keys: Vec::new(),
117
sketch_per_p: vec![CardinalitySketch::default(); num_partitions],
118
key_idxs_values_per_p: vec![Vec::new(); num_partitions],
119
key_idxs_offsets_per_p: vec![0; num_partitions],
120
})
121
.collect();
122
Self { local_builders }
123
}
124
125
/// Consumes build-side morsels on one pipeline: hashes the key columns,
/// records per-partition row indices and cardinality sketches into `local`,
/// and retains the hashed keys for the later insert pass in `finalize`.
async fn partition_and_sink(
    mut recv: Receiver<Morsel>,
    local: &mut LocalBuilder,
    partitioner: HashPartitioner,
    params: &SemiAntiJoinParams,
    state: &StreamingExecutionState,
) -> PolarsResult<()> {
    // The build side determines which key selectors to evaluate here.
    let key_selectors = if params.left_is_build {
        &params.left_key_selectors
    } else {
        &params.right_key_selectors
    };

    while let Ok(morsel) = recv.recv().await {
        let hash_keys = select_keys(
            morsel.df(),
            key_selectors,
            params,
            &state.in_memory_exec_state,
        )
        .await?;

        // Appends this morsel's row indices to each partition's value list
        // and updates the per-partition cardinality sketches.
        hash_keys.gen_idxs_per_partition(
            &partitioner,
            &mut local.key_idxs_values_per_p,
            &mut local.sketch_per_p,
            false,
        );

        // Record the new cumulative lengths so finalize can recover this
        // morsel's index range per partition (see LocalBuilder's comment).
        local
            .key_idxs_offsets_per_p
            .extend(local.key_idxs_values_per_p.iter().map(|vp| vp.len()));
        local.keys.push(hash_keys);
    }
    Ok(())
}
161
162
/// Merges all pipelines' accumulated keys into one grouper (hash table) per
/// partition, building the partitions in parallel, and returns the resulting
/// `ProbeState`. Consumes the builders' key lists via `mem::take`.
fn finalize(&mut self, grouper: &dyn Grouper) -> ProbeState {
    // To reduce maximum memory usage we want to drop the original keys
    // as soon as they're processed, so we move into Arcs. The drops might
    // also be expensive, so instead of directly dropping we put that on
    // a work queue.
    let keys_per_local_builder = self
        .local_builders
        .iter_mut()
        .map(|b| Arc::new(core::mem::take(&mut b.keys)))
        .collect_vec();
    let (key_drop_q_send, key_drop_q_recv) =
        async_channel::bounded(keys_per_local_builder.len());
    let num_partitions = self.local_builders[0].sketch_per_p.len();
    let local_builders = &self.local_builders;
    // Filled out-of-order by the partition tasks, one grouper per partition.
    let groupers: SparseInitVec<Box<dyn Grouper>> =
        SparseInitVec::with_capacity(num_partitions);

    async_executor::task_scope(|s| {
        // Wrap in outer Arc to move to each thread, performing the
        // expensive clone on that thread.
        let arc_keys_per_local_builder = Arc::new(keys_per_local_builder);
        let mut join_handles = Vec::new();
        // One task per partition; each task scans every builder's keys but
        // only inserts the rows belonging to its partition `p`.
        for p in 0..num_partitions {
            let arc_keys_per_local_builder = Arc::clone(&arc_keys_per_local_builder);
            let key_drop_q_send = key_drop_q_send.clone();
            let key_drop_q_recv = key_drop_q_recv.clone();
            let groupers = &groupers;
            join_handles.push(s.spawn_task(TaskPriority::High, async move {
                // Extract from outer arc and drop outer arc.
                let keys_per_local_builder = Arc::unwrap_or_clone(arc_keys_per_local_builder);

                // Compute cardinality estimate.
                let mut sketch = CardinalitySketch::new();
                for l in local_builders {
                    sketch.combine(&l.sketch_per_p[p]);
                }

                // Allocate hash table, with 25% headroom over the estimate.
                let mut p_grouper = grouper.new_empty();
                p_grouper.reserve(sketch.estimate() * 5 / 4);

                // Build.
                let mut skip_drop_attempt = false;
                for (l, l_keys) in local_builders.iter().zip(keys_per_local_builder) {
                    // Try to help with dropping the processed keys.
                    if !skip_drop_attempt {
                        drop(key_drop_q_recv.try_recv());
                    }

                    for (i, keys) in l_keys.iter().enumerate() {
                        // SAFETY-relevant: index ranges come from the offsets
                        // recorded in partition_and_sink (see LocalBuilder).
                        unsafe {
                            let p_key_idxs_start =
                                l.key_idxs_offsets_per_p[i * num_partitions + p];
                            let p_key_idxs_stop =
                                l.key_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                            let p_key_idxs =
                                &l.key_idxs_values_per_p[p][p_key_idxs_start..p_key_idxs_stop];
                            p_grouper.insert_keys_subset(keys, p_key_idxs, None);
                        }
                    }

                    if let Some(l) = Arc::into_inner(l_keys) {
                        // If we're the last thread to process this set of keys we're probably
                        // falling behind the rest, since the drop can be quite expensive we skip
                        // a drop attempt hoping someone else will pick up the slack.
                        drop(key_drop_q_send.try_send(l));
                        skip_drop_attempt = true;
                    } else {
                        skip_drop_attempt = false;
                    }
                }

                // We're done, help others out by doing drops.
                drop(key_drop_q_send); // So we don't deadlock trying to receive from ourselves.
                while let Ok(l_keys) = key_drop_q_recv.recv().await {
                    drop(l_keys);
                }

                // Each partition index is set exactly once, so this can't fail.
                groupers.try_set(p, p_grouper).ok().unwrap();
            }));
        }

        // Drop outer arc after spawning each thread so the inner arcs
        // can get dropped as soon as they're processed. We also have to
        // drop the drop queue sender so we don't deadlock waiting for it
        // to end.
        drop(arc_keys_per_local_builder);
        drop(key_drop_q_send);

        polars_io::pl_async::get_runtime().block_on(async move {
            for handle in join_handles {
                handle.await;
            }
        });
    });

    ProbeState {
        // All partitions were set above, so initialization is complete.
        grouper_per_partition: groupers.try_assume_init().ok().unwrap(),
    }
}
262
}
263
264
/// Probe-phase state: the fully built hash table, split into one grouper per
/// partition as produced by `BuildState::finalize`.
struct ProbeState {
    grouper_per_partition: Vec<Box<dyn Grouper>>,
}
267
268
impl ProbeState {
269
/// Probes incoming morsels against the partitioned build-side groupers and
/// forwards the result morsels. In `return_bool` mode each morsel yields a
/// single boolean match column; otherwise matching (or, for anti, non-matching)
/// rows are gathered from the input. Returns Ok(()) when either channel closes.
async fn partition_and_probe(
    mut recv: Receiver<Morsel>,
    mut send: Sender<Morsel>,
    partitions: &[Box<dyn Grouper>],
    partitioner: HashPartitioner,
    params: &SemiAntiJoinParams,
    state: &StreamingExecutionState,
) -> PolarsResult<()> {
    // Reused across morsels to avoid reallocating the match-index buffer.
    let mut probe_match = Vec::new();
    // The probe side is the opposite of the build side.
    let key_selectors = if params.left_is_build {
        &params.right_key_selectors
    } else {
        &params.left_key_selectors
    };

    while let Ok(morsel) = recv.recv().await {
        let (df, in_seq, src_token, wait_token) = morsel.into_inner();
        if df.height() == 0 {
            continue;
        }

        let hash_keys =
            select_keys(&df, key_selectors, params, &state.in_memory_exec_state).await?;

        unsafe {
            let out_df = if params.return_bool {
                // is-in / is-not-in: one boolean per input row.
                let mut builder = BitmapBuilder::with_capacity(df.height());
                partitions[0].contains_key_partitioned_groupers(
                    partitions,
                    &hash_keys,
                    &partitioner,
                    params.is_anti,
                    &mut builder,
                );
                let mut arr = BooleanArray::from(builder.freeze());
                if !params.nulls_equal {
                    // Null keys can't match; mask them out of the result.
                    arr.set_validity(hash_keys.validity().cloned());
                }
                let s = BooleanChunked::with_chunk(df[0].name().clone(), arr).into_series();
                DataFrame::new(vec![Column::from(s)])?
            } else {
                // semi/anti join: gather the surviving row indices.
                probe_match.clear();
                partitions[0].probe_partitioned_groupers(
                    partitions,
                    &hash_keys,
                    &partitioner,
                    params.is_anti,
                    &mut probe_match,
                );
                if probe_match.is_empty() {
                    continue;
                }
                df.take_slice_unchecked(&probe_match)
            };

            // Preserve sequence/source tokens so backpressure keeps working.
            let mut morsel = Morsel::new(out_df, in_seq, src_token.clone());
            if let Some(token) = wait_token {
                morsel.set_consume_token(token);
            }
            // Receiver hung up: stop quietly, this is not an error.
            if send.send(morsel).await.is_err() {
                return Ok(());
            }
        }
    }

    Ok(())
}
337
}
338
339
impl ComputeNode for SemiAntiJoinNode {
340
fn name(&self) -> &str {
341
match (self.params.return_bool, self.params.is_anti) {
342
(false, false) => "semi-join",
343
(false, true) => "anti-join",
344
(true, false) => "is-in",
345
(true, true) => "is-not-in",
346
}
347
}
348
349
/// Advances the Build -> Probe -> Done state machine based on port states,
/// then publishes the desired state of each input/output port.
///
/// Port layout: two inputs (build side and probe side, chosen by
/// `left_is_build`) and one output.
fn update_state(
    &mut self,
    recv: &mut [PortState],
    send: &mut [PortState],
    _state: &StreamingExecutionState,
) -> PolarsResult<()> {
    assert!(recv.len() == 2 && send.len() == 1);

    // If the output doesn't want any more data, transition to being done.
    if send[0] == PortState::Done {
        self.state = SemiAntiJoinState::Done;
    }

    let build_idx = if self.params.left_is_build { 0 } else { 1 };
    let probe_idx = 1 - build_idx;

    // If we are building and the build input is done, transition to probing.
    if let SemiAntiJoinState::Build(build_state) = &mut self.state {
        if recv[build_idx] == PortState::Done {
            let probe_state = build_state.finalize(&*self.grouper);
            self.state = SemiAntiJoinState::Probe(probe_state);
        }
    }

    // If we are probing and the probe input is done, we're done.
    if let SemiAntiJoinState::Probe(_) = &mut self.state {
        if recv[probe_idx] == PortState::Done {
            self.state = SemiAntiJoinState::Done;
        }
    }

    match &mut self.state {
        SemiAntiJoinState::Build(_) => {
            // While building we produce no output and stall the probe input.
            send[0] = PortState::Blocked;
            if recv[build_idx] != PortState::Done {
                recv[build_idx] = PortState::Ready;
            }
            if recv[probe_idx] != PortState::Done {
                recv[probe_idx] = PortState::Blocked;
            }
        },
        SemiAntiJoinState::Probe(_) => {
            if recv[probe_idx] != PortState::Done {
                // Pass-through: output readiness mirrors probe-input readiness.
                core::mem::swap(&mut send[0], &mut recv[probe_idx]);
            } else {
                send[0] = PortState::Done;
            }
            recv[build_idx] = PortState::Done;
        },
        SemiAntiJoinState::Done => {
            send[0] = PortState::Done;
            recv[0] = PortState::Done;
            recv[1] = PortState::Done;
        },
    }
    Ok(())
}
406
407
fn is_memory_intensive_pipeline_blocker(&self) -> bool {
408
matches!(self.state, SemiAntiJoinState::Build { .. })
409
}
410
411
/// Spawns the per-pipeline worker tasks for the current phase: build workers
/// sinking into the local builders, or probe workers streaming results out.
///
/// Panics (via `unreachable!`) if called in the `Done` state, and asserts the
/// port wiring matches the state published by `update_state`.
fn spawn<'env, 's>(
    &'env mut self,
    scope: &'s TaskScope<'s, 'env>,
    recv_ports: &mut [Option<RecvPort<'_>>],
    send_ports: &mut [Option<SendPort<'_>>],
    state: &'s StreamingExecutionState,
    join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
) {
    assert!(recv_ports.len() == 2);
    assert!(send_ports.len() == 1);

    let build_idx = if self.params.left_is_build { 0 } else { 1 };
    let probe_idx = 1 - build_idx;

    match &mut self.state {
        SemiAntiJoinState::Build(build_state) => {
            // During build: no output, no probe input.
            assert!(send_ports[0].is_none());
            assert!(recv_ports[probe_idx].is_none());
            let receivers = recv_ports[build_idx].take().unwrap().parallel();

            let partitioner = HashPartitioner::new(state.num_pipelines, 0);
            // One build task per pipeline, each owning its LocalBuilder.
            for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {
                join_handles.push(scope.spawn_task(
                    TaskPriority::High,
                    BuildState::partition_and_sink(
                        recv,
                        local_builder,
                        partitioner.clone(),
                        &self.params,
                        state,
                    ),
                ));
            }
        },
        SemiAntiJoinState::Probe(probe_state) => {
            // During probe: the build input is finished.
            assert!(recv_ports[build_idx].is_none());
            let senders = send_ports[0].take().unwrap().parallel();
            let receivers = recv_ports[probe_idx].take().unwrap().parallel();

            let partitioner = HashPartitioner::new(state.num_pipelines, 0);
            // One probe task per pipeline, sharing the read-only groupers.
            for (recv, send) in receivers.into_iter().zip(senders) {
                join_handles.push(scope.spawn_task(
                    TaskPriority::High,
                    ProbeState::partition_and_probe(
                        recv,
                        send,
                        &probe_state.grouper_per_partition,
                        partitioner.clone(),
                        &self.params,
                        state,
                    ),
                ));
            }
        },
        SemiAntiJoinState::Done => unreachable!(),
    }
}
468
}
469
470