Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/joins/semi_anti_join.rs
8512 views
1
use std::sync::Arc;
2
3
use arrow::array::BooleanArray;
4
use arrow::bitmap::BitmapBuilder;
5
use polars_core::prelude::*;
6
use polars_core::schema::Schema;
7
use polars_expr::groups::{Grouper, new_hash_grouper};
8
use polars_expr::hash_keys::HashKeys;
9
use polars_ops::frame::{JoinArgs, JoinType};
10
use polars_utils::IdxSize;
11
use polars_utils::cardinality_sketch::CardinalitySketch;
12
use polars_utils::hashing::HashPartitioner;
13
use polars_utils::itertools::Itertools;
14
use polars_utils::sparse_init_vec::SparseInitVec;
15
16
use crate::async_executor;
17
use crate::expression::StreamExpr;
18
use crate::nodes::compute_node_prelude::*;
19
20
async fn select_keys(
21
df: &DataFrame,
22
key_selectors: &[StreamExpr],
23
params: &SemiAntiJoinParams,
24
state: &ExecutionState,
25
) -> PolarsResult<HashKeys> {
26
let mut key_columns = Vec::new();
27
for selector in key_selectors {
28
key_columns.push(selector.evaluate(df, state).await?.into_column());
29
}
30
let keys = unsafe { DataFrame::new_unchecked_with_broadcast(df.height(), key_columns) }?;
31
Ok(HashKeys::from_df(
32
&keys,
33
params.random_state.clone(),
34
params.nulls_equal,
35
false,
36
))
37
}
38
39
/// Configuration shared by the build and probe phases of the join.
struct SemiAntiJoinParams {
    // Which input feeds the hash table: true = left side builds, the right
    // side is probed; false = the reverse.
    left_is_build: bool,
    // Expressions computing the join keys for each input side.
    left_key_selectors: Vec<StreamExpr>,
    right_key_selectors: Vec<StreamExpr>,
    // If true, null keys compare equal to each other.
    nulls_equal: bool,
    // True for anti-join semantics (emit non-matches), false for semi.
    is_anti: bool,
    // If true, emit a boolean match column instead of filtering rows; this is
    // the is-in / is-not-in mode (see `ComputeNode::name`).
    return_bool: bool,
    // Shared hash seed so build and probe keys hash identically.
    random_state: PlRandomState,
}
48
49
/// Streaming compute node implementing semi/anti joins, plus the related
/// boolean-returning is-in / is-not-in operations.
pub struct SemiAntiJoinNode {
    // Current phase: building the hash tables, probing them, or finished.
    state: SemiAntiJoinState,
    params: SemiAntiJoinParams,
    // Template grouper; per-partition groupers are created from it via
    // `new_empty()` when the build phase is finalized.
    grouper: Box<dyn Grouper>,
}
54
55
impl SemiAntiJoinNode {
56
pub fn new(
57
unique_key_schema: Arc<Schema>,
58
left_key_selectors: Vec<StreamExpr>,
59
right_key_selectors: Vec<StreamExpr>,
60
args: JoinArgs,
61
return_bool: bool,
62
num_pipelines: usize,
63
) -> PolarsResult<Self> {
64
let left_is_build = false;
65
let is_anti = args.how == JoinType::Anti;
66
67
let state = SemiAntiJoinState::Build(BuildState::new(num_pipelines, num_pipelines));
68
69
Ok(Self {
70
state,
71
params: SemiAntiJoinParams {
72
left_is_build,
73
left_key_selectors,
74
right_key_selectors,
75
random_state: PlRandomState::default(),
76
nulls_equal: args.nulls_equal,
77
return_bool,
78
is_anti,
79
},
80
grouper: new_hash_grouper(unique_key_schema),
81
})
82
}
83
}
84
85
/// Phase of the join node's state machine.
enum SemiAntiJoinState {
    // Consuming the build-side input and collecting keys per partition.
    Build(BuildState),
    // Hash tables finalized; streaming the probe side through them.
    Probe(ProbeState),
    // No more output will be produced (inputs exhausted or output dropped).
    Done,
}
90
91
/// Per-pipeline accumulator for the build phase. Each pipeline worker owns
/// one of these and fills it independently; they are only combined during
/// `BuildState::finalize`.
#[derive(Default)]
struct LocalBuilder {
    // The complete list of keys as seen by this builder, one entry per morsel.
    keys: Vec<HashKeys>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // key_idxs_values_per_p[p][start..stop] contains the offsets into morsels[i]
    // for partition p, where start, stop are:
    // let start = key_idxs_offsets[i * num_partitions + p];
    // let stop = key_idxs_offsets[(i + 1) * num_partitions + p];
    key_idxs_values_per_p: Vec<Vec<IdxSize>>,
    key_idxs_offsets_per_p: Vec<usize>,
}
106
107
/// State for the build phase: one [`LocalBuilder`] per pipeline, each filled
/// by its own `partition_and_sink` task.
struct BuildState {
    local_builders: Vec<LocalBuilder>,
}
110
111
impl BuildState {
    /// Creates `num_pipelines` empty local builders, each pre-sized for
    /// `num_partitions` partitions.
    fn new(num_pipelines: usize, num_partitions: usize) -> Self {
        let local_builders = (0..num_pipelines)
            .map(|_| LocalBuilder {
                keys: Vec::new(),
                sketch_per_p: vec![CardinalitySketch::default(); num_partitions],
                key_idxs_values_per_p: vec![Vec::new(); num_partitions],
                // Starts with one zero offset per partition so that morsel 0's
                // start offsets are well-defined in `finalize`.
                key_idxs_offsets_per_p: vec![0; num_partitions],
            })
            .collect();
        Self { local_builders }
    }

    /// Consumes build-side morsels from `recv`, computing the hash keys for
    /// each morsel and recording per-partition row indices, cardinality
    /// sketches, and offsets into `local`.
    async fn partition_and_sink(
        mut recv: PortReceiver,
        local: &mut LocalBuilder,
        partitioner: HashPartitioner,
        params: &SemiAntiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        // Pick the selectors for whichever side is the build side.
        let key_selectors = if params.left_is_build {
            &params.left_key_selectors
        } else {
            &params.right_key_selectors
        };

        while let Ok(morsel) = recv.recv().await {
            let hash_keys = select_keys(
                morsel.df(),
                key_selectors,
                params,
                &state.in_memory_exec_state,
            )
            .await?;

            // Append this morsel's row indices to each partition's index list
            // and update the per-partition cardinality sketches.
            hash_keys.gen_idxs_per_partition(
                &partitioner,
                &mut local.key_idxs_values_per_p,
                &mut local.sketch_per_p,
                false,
            );

            // Record the new end offset of every partition's index list; these
            // delimit this morsel's slice per partition (see LocalBuilder docs).
            local
                .key_idxs_offsets_per_p
                .extend(local.key_idxs_values_per_p.iter().map(|vp| vp.len()));
            local.keys.push(hash_keys);
        }
        Ok(())
    }

    /// Combines all local builders into one hash table (grouper) per
    /// partition, building each partition on its own task, and returns the
    /// resulting [`ProbeState`]. `grouper` is only used as a template via
    /// `new_empty()`.
    fn finalize(&mut self, grouper: &dyn Grouper) -> ProbeState {
        // To reduce maximum memory usage we want to drop the original keys
        // as soon as they're processed, so we move into Arcs. The drops might
        // also be expensive, so instead of directly dropping we put that on
        // a work queue.
        let keys_per_local_builder = self
            .local_builders
            .iter_mut()
            .map(|b| Arc::new(core::mem::take(&mut b.keys)))
            .collect_vec();
        let (key_drop_q_send, key_drop_q_recv) =
            async_channel::bounded(keys_per_local_builder.len());
        let num_partitions = self.local_builders[0].sketch_per_p.len();
        let local_builders = &self.local_builders;
        // Filled concurrently: each partition task sets exactly its own slot.
        let groupers: SparseInitVec<Box<dyn Grouper>> =
            SparseInitVec::with_capacity(num_partitions);

        async_executor::task_scope(|s| {
            // Wrap in outer Arc to move to each thread, performing the
            // expensive clone on that thread.
            let arc_keys_per_local_builder = Arc::new(keys_per_local_builder);
            let mut join_handles = Vec::new();
            for p in 0..num_partitions {
                let arc_keys_per_local_builder = Arc::clone(&arc_keys_per_local_builder);
                let key_drop_q_send = key_drop_q_send.clone();
                let key_drop_q_recv = key_drop_q_recv.clone();
                let groupers = &groupers;
                join_handles.push(s.spawn_task(TaskPriority::High, async move {
                    // Extract from outer arc and drop outer arc.
                    let keys_per_local_builder = Arc::unwrap_or_clone(arc_keys_per_local_builder);

                    // Compute cardinality estimate.
                    let mut sketch = CardinalitySketch::new();
                    for l in local_builders {
                        sketch.combine(&l.sketch_per_p[p]);
                    }

                    // Allocate hash table.
                    // Reserve 25% headroom over the estimate to limit rehashing.
                    let mut p_grouper = grouper.new_empty();
                    p_grouper.reserve(sketch.estimate() * 5 / 4);

                    // Build.
                    let mut skip_drop_attempt = false;
                    for (l, l_keys) in local_builders.iter().zip(keys_per_local_builder) {
                        // Try to help with dropping the processed keys.
                        if !skip_drop_attempt {
                            drop(key_drop_q_recv.try_recv());
                        }

                        for (i, keys) in l_keys.iter().enumerate() {
                            // SAFETY-relevant: the offsets recorded in
                            // partition_and_sink delimit morsel i's indices for
                            // partition p, which are in-bounds for `keys`.
                            unsafe {
                                let p_key_idxs_start =
                                    l.key_idxs_offsets_per_p[i * num_partitions + p];
                                let p_key_idxs_stop =
                                    l.key_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_key_idxs =
                                    &l.key_idxs_values_per_p[p][p_key_idxs_start..p_key_idxs_stop];
                                p_grouper.insert_keys_subset(keys, p_key_idxs, None);
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_keys) {
                            // If we're the last thread to process this set of keys we're probably
                            // falling behind the rest, since the drop can be quite expensive we skip
                            // a drop attempt hoping someone else will pick up the slack.
                            drop(key_drop_q_send.try_send(l));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // We're done, help others out by doing drops.
                    drop(key_drop_q_send); // So we don't deadlock trying to receive from ourselves.
                    while let Ok(l_keys) = key_drop_q_recv.recv().await {
                        drop(l_keys);
                    }

                    groupers.try_set(p, p_grouper).ok().unwrap();
                }));
            }

            // Drop outer arc after spawning each thread so the inner arcs
            // can get dropped as soon as they're processed. We also have to
            // drop the drop queue sender so we don't deadlock waiting for it
            // to end.
            drop(arc_keys_per_local_builder);
            drop(key_drop_q_send);

            polars_io::pl_async::get_runtime().block_on(async move {
                for handle in join_handles {
                    handle.await;
                }
            });
        });

        ProbeState {
            // Every slot was set exactly once by its partition task above.
            grouper_per_partition: groupers.try_assume_init().ok().unwrap(),
        }
    }
}
262
263
/// State for the probe phase: one finalized hash table per partition, shared
/// read-only by all probe tasks.
struct ProbeState {
    grouper_per_partition: Vec<Box<dyn Grouper>>,
}
266
267
impl ProbeState {
    /// Probes the per-partition hash tables with incoming morsels and sends
    /// join results downstream.
    ///
    /// In `return_bool` mode a single boolean column is emitted per morsel;
    /// otherwise rows are filtered by the (semi/anti) match indices. Returns
    /// `Ok(())` both on normal completion and when the receiver downstream
    /// has been dropped.
    async fn partition_and_probe(
        mut recv: PortReceiver,
        mut send: PortSender,
        partitions: &[Box<dyn Grouper>],
        partitioner: HashPartitioner,
        params: &SemiAntiJoinParams,
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        // Reused scratch buffer for matching row indices, cleared per morsel.
        let mut probe_match = Vec::new();
        // The probe side is the opposite of the build side.
        let key_selectors = if params.left_is_build {
            &params.right_key_selectors
        } else {
            &params.left_key_selectors
        };

        while let Ok(morsel) = recv.recv().await {
            let (df, in_seq, src_token, wait_token) = morsel.into_inner();
            // Empty morsels produce no output; skip (this also releases the
            // wait token by dropping it).
            if df.height() == 0 {
                continue;
            }

            let hash_keys =
                select_keys(&df, key_selectors, params, &state.in_memory_exec_state).await?;

            unsafe {
                let out_df = if params.return_bool {
                    // is-in / is-not-in: build a boolean mask, one bit per row.
                    let mut builder = BitmapBuilder::with_capacity(df.height());
                    partitions[0].contains_key_partitioned_groupers(
                        partitions,
                        &hash_keys,
                        &partitioner,
                        params.is_anti,
                        &mut builder,
                    );
                    let mut arr = BooleanArray::from(builder.freeze());
                    // When nulls don't compare equal, null keys yield null
                    // results rather than true/false.
                    if !params.nulls_equal {
                        arr.set_validity(hash_keys.validity().cloned());
                    }
                    let s = BooleanChunked::with_chunk(df[0].name().clone(), arr).into_series();
                    DataFrame::new_unchecked(s.len(), vec![Column::from(s)])
                } else {
                    // Semi/anti join: gather the indices of (non-)matching rows.
                    probe_match.clear();
                    partitions[0].probe_partitioned_groupers(
                        partitions,
                        &hash_keys,
                        &partitioner,
                        params.is_anti,
                        &mut probe_match,
                    );
                    if probe_match.is_empty() {
                        continue;
                    }
                    df.take_slice_unchecked(&probe_match)
                };

                // Preserve the input sequence number and source token so
                // ordering and backpressure are maintained downstream.
                let mut morsel = Morsel::new(out_df, in_seq, src_token.clone());
                if let Some(token) = wait_token {
                    morsel.set_consume_token(token);
                }
                // A send error means downstream stopped consuming; stop quietly.
                if send.send(morsel).await.is_err() {
                    return Ok(());
                }
            }
        }

        Ok(())
    }
}
337
338
impl ComputeNode for SemiAntiJoinNode {
    /// Human-readable node name; distinguishes the four operation variants.
    fn name(&self) -> &str {
        match (self.params.return_bool, self.params.is_anti) {
            (false, false) => "semi-join",
            (false, true) => "anti-join",
            (true, false) => "is-in",
            (true, true) => "is-not-in",
        }
    }

    /// Advances the Build -> Probe -> Done state machine based on the port
    /// states, then reports the desired state of each port back through
    /// `recv`/`send`. Port 0/1 are the left/right inputs respectively.
    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        _state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 2 && send.len() == 1);

        // If the output doesn't want any more data, transition to being done.
        if send[0] == PortState::Done {
            self.state = SemiAntiJoinState::Done;
        }

        let build_idx = if self.params.left_is_build { 0 } else { 1 };
        let probe_idx = 1 - build_idx;

        // If we are building and the build input is done, transition to probing.
        // Note: finalize runs synchronously here, building all hash tables.
        if let SemiAntiJoinState::Build(build_state) = &mut self.state {
            if recv[build_idx] == PortState::Done {
                let probe_state = build_state.finalize(&*self.grouper);
                self.state = SemiAntiJoinState::Probe(probe_state);
            }
        }

        // If we are probing and the probe input is done, we're done.
        if let SemiAntiJoinState::Probe(_) = &mut self.state {
            if recv[probe_idx] == PortState::Done {
                self.state = SemiAntiJoinState::Done;
            }
        }

        match &mut self.state {
            SemiAntiJoinState::Build(_) => {
                // While building: consume the build side, block the output and
                // the probe side.
                send[0] = PortState::Blocked;
                if recv[build_idx] != PortState::Done {
                    recv[build_idx] = PortState::Ready;
                }
                if recv[probe_idx] != PortState::Done {
                    recv[probe_idx] = PortState::Blocked;
                }
            },
            SemiAntiJoinState::Probe(_) => {
                // While probing: the probe input mirrors the output's state
                // (backpressure passthrough); the build input is exhausted.
                if recv[probe_idx] != PortState::Done {
                    core::mem::swap(&mut send[0], &mut recv[probe_idx]);
                } else {
                    send[0] = PortState::Done;
                }
                recv[build_idx] = PortState::Done;
            },
            SemiAntiJoinState::Done => {
                send[0] = PortState::Done;
                recv[0] = PortState::Done;
                recv[1] = PortState::Done;
            },
        }
        Ok(())
    }

    /// The build phase buffers the whole build side in memory, so it counts
    /// as a memory-intensive pipeline blocker.
    fn is_memory_intensive_pipeline_blocker(&self) -> bool {
        matches!(self.state, SemiAntiJoinState::Build { .. })
    }

    /// Spawns one task per pipeline for the current phase: build-side sinks
    /// during Build, probe-and-send tasks during Probe.
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv_ports.len() == 2);
        assert!(send_ports.len() == 1);

        let build_idx = if self.params.left_is_build { 0 } else { 1 };
        let probe_idx = 1 - build_idx;

        match &mut self.state {
            SemiAntiJoinState::Build(build_state) => {
                // update_state blocked the output and probe ports, so only the
                // build port may be connected.
                assert!(send_ports[0].is_none());
                assert!(recv_ports[probe_idx].is_none());
                let receivers = recv_ports[build_idx].take().unwrap().parallel();

                let partitioner = HashPartitioner::new(state.num_pipelines, 0);
                // One sink task per pipeline, each with its own LocalBuilder.
                for (local_builder, recv) in build_state.local_builders.iter_mut().zip(receivers) {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        BuildState::partition_and_sink(
                            recv,
                            local_builder,
                            partitioner.clone(),
                            &self.params,
                            state,
                        ),
                    ));
                }
            },
            SemiAntiJoinState::Probe(probe_state) => {
                // The build side was marked Done in update_state.
                assert!(recv_ports[build_idx].is_none());
                let senders = send_ports[0].take().unwrap().parallel();
                let receivers = recv_ports[probe_idx].take().unwrap().parallel();

                let partitioner = HashPartitioner::new(state.num_pipelines, 0);
                // One probe task per pipeline, sharing the read-only tables.
                for (recv, send) in receivers.into_iter().zip(senders) {
                    join_handles.push(scope.spawn_task(
                        TaskPriority::High,
                        ProbeState::partition_and_probe(
                            recv,
                            send,
                            &probe_state.grouper_per_partition,
                            partitioner.clone(),
                            &self.params,
                            state,
                        ),
                    ));
                }
            },
            // spawn is never called once the node is Done.
            SemiAntiJoinState::Done => unreachable!(),
        }
    }
}
468
469