Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/iejoin/mod.rs
8458 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
mod filtered_bit_array;
3
mod l1_l2;
4
5
use std::cmp::min;
6
7
use filtered_bit_array::FilteredBitArray;
8
use l1_l2::*;
9
use polars_core::chunked_array::ChunkedArray;
10
use polars_core::datatypes::{IdxCa, NumericNative, PolarsNumericType};
11
use polars_core::frame::DataFrame;
12
use polars_core::prelude::*;
13
use polars_core::series::IsSorted;
14
use polars_core::utils::{_set_partition_size, split};
15
use polars_core::{POOL, with_match_physical_numeric_polars_type};
16
use polars_error::{PolarsResult, polars_err};
17
use polars_utils::IdxSize;
18
use polars_utils::binary_search::ExponentialSearch;
19
use polars_utils::itertools::Itertools;
20
use polars_utils::total_ord::{TotalEq, TotalOrd};
21
use rayon::prelude::*;
22
#[cfg(feature = "serde")]
23
use serde::{Deserialize, Serialize};
24
25
use crate::frame::_finish_join;
26
27
/// A single comparison used by an inequality join.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum InequalityOperator {
    /// `<` (the default operator).
    #[default]
    Lt,
    /// `<=`
    LtEq,
    /// `>`
    Gt,
    /// `>=`
    GtEq,
}

impl InequalityOperator {
    /// Returns `true` for the strict comparisons (`<`, `>`) and `false` for the
    /// inclusive ones (`<=`, `>=`).
    fn is_strict(&self) -> bool {
        match self {
            Self::Lt | Self::Gt => true,
            Self::LtEq | Self::GtEq => false,
        }
    }
}
43
#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)]
44
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
45
pub struct IEJoinOptions {
46
pub operator1: InequalityOperator,
47
pub operator2: Option<InequalityOperator>,
48
}
49
50
#[allow(clippy::too_many_arguments)]
51
fn ie_join_impl_t<T: PolarsNumericType>(
52
slice: Option<(i64, usize)>,
53
l1_order: IdxCa,
54
l2_order: &[IdxSize],
55
op1: InequalityOperator,
56
op2: InequalityOperator,
57
x: Series,
58
y_ordered_by_x: Series,
59
left_height: usize,
60
) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)> {
61
// Create a bit array with order corresponding to L1,
62
// denoting which entries have been visited while traversing L2.
63
let mut bit_array = FilteredBitArray::from_len_zeroed(l1_order.len());
64
65
let mut left_row_idx: Vec<IdxSize> = vec![];
66
let mut right_row_idx: Vec<IdxSize> = vec![];
67
68
let slice_end = slice_end_index(slice);
69
let mut match_count = 0;
70
71
let ca: &ChunkedArray<T> = x.as_ref().as_ref();
72
let l1_array = build_l1_array(ca, &l1_order, left_height as IdxSize)?;
73
74
if op2.is_strict() {
75
// For strict inequalities, we rely on using a stable sort of l2 so that
76
// p values only increase as we traverse a run of equal y values.
77
// To handle inclusive comparisons in x and duplicate x values we also need the
78
// sort of l1 to be stable, so that the left hand side entries come before the right
79
// hand side entries (as we mark visited entries from the right hand side).
80
for &p in l2_order {
81
match_count += unsafe {
82
l1_array.process_entry(
83
p as usize,
84
&mut bit_array,
85
op1,
86
&mut left_row_idx,
87
&mut right_row_idx,
88
)
89
};
90
91
if slice_end.is_some_and(|end| match_count >= end) {
92
break;
93
}
94
}
95
} else {
96
let l2_array = build_l2_array(&y_ordered_by_x, l2_order)?;
97
98
// For non-strict inequalities in l2, we need to track runs of equal y values and only
99
// check for matches after we reach the end of the run and have marked all rhs entries
100
// in the run as visited.
101
let mut run_start = 0;
102
103
for i in 0..l2_array.len() {
104
// Elide bound checks
105
unsafe {
106
let item = l2_array.get_unchecked(i);
107
let p = item.l1_index;
108
l1_array.mark_visited(p as usize, &mut bit_array);
109
110
if item.run_end {
111
for l2_item in l2_array.get_unchecked(run_start..i + 1) {
112
let p = l2_item.l1_index;
113
match_count += l1_array.process_lhs_entry(
114
p as usize,
115
&bit_array,
116
op1,
117
&mut left_row_idx,
118
&mut right_row_idx,
119
);
120
}
121
122
run_start = i + 1;
123
124
if slice_end.is_some_and(|end| match_count >= end) {
125
break;
126
}
127
}
128
}
129
}
130
}
131
Ok((left_row_idx, right_row_idx))
132
}
133
134
fn piecewise_merge_join_impl_t<T, P>(
135
slice: Option<(i64, usize)>,
136
left_order: Option<&[IdxSize]>,
137
right_order: Option<&[IdxSize]>,
138
left_ordered: Series,
139
right_ordered: Series,
140
mut pred: P,
141
) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>
142
where
143
T: PolarsNumericType,
144
P: FnMut(&T::Native, &T::Native) -> bool,
145
{
146
let slice_end = slice_end_index(slice);
147
148
let mut left_row_idx: Vec<IdxSize> = vec![];
149
let mut right_row_idx: Vec<IdxSize> = vec![];
150
151
let left_ca: &ChunkedArray<T> = left_ordered.as_ref().as_ref();
152
let right_ca: &ChunkedArray<T> = right_ordered.as_ref().as_ref();
153
154
debug_assert!(left_order.is_none_or(|order| order.len() == left_ca.len()));
155
debug_assert!(right_order.is_none_or(|order| order.len() == right_ca.len()));
156
157
let mut left_idx = 0;
158
let mut right_idx = 0;
159
let mut match_count = 0;
160
161
while left_idx < left_ca.len() {
162
debug_assert!(left_ca.get(left_idx).is_some());
163
let left_val = unsafe { left_ca.value_unchecked(left_idx) };
164
while right_idx < right_ca.len() {
165
debug_assert!(right_ca.get(right_idx).is_some());
166
let right_val = unsafe { right_ca.value_unchecked(right_idx) };
167
if pred(&left_val, &right_val) {
168
// If the predicate is true, then it will also be true for all
169
// remaining rows from the right side.
170
let left_row = match left_order {
171
None => left_idx as IdxSize,
172
Some(order) => order[left_idx],
173
};
174
let right_end_idx = match slice_end {
175
None => right_ca.len(),
176
Some(end) => min(right_ca.len(), (end as usize) - match_count + right_idx),
177
};
178
for included_right_row_idx in right_idx..right_end_idx {
179
let right_row = match right_order {
180
None => included_right_row_idx as IdxSize,
181
Some(order) => order[included_right_row_idx],
182
};
183
left_row_idx.push(left_row);
184
right_row_idx.push(right_row);
185
}
186
match_count += right_end_idx - right_idx;
187
break;
188
} else {
189
right_idx += 1;
190
}
191
}
192
if right_idx == right_ca.len() {
193
// We've reached the end of the right side
194
// so there can be no more matches for LHS rows
195
break;
196
}
197
if slice_end.is_some_and(|end| match_count >= end as usize) {
198
break;
199
}
200
left_idx += 1;
201
}
202
203
Ok((left_row_idx, right_row_idx))
204
}
205
206
pub(super) fn iejoin_par(
207
left: &DataFrame,
208
right: &DataFrame,
209
selected_left: Vec<Series>,
210
selected_right: Vec<Series>,
211
options: &IEJoinOptions,
212
suffix: Option<PlSmallStr>,
213
slice: Option<(i64, usize)>,
214
) -> PolarsResult<DataFrame> {
215
let l1_descending = matches!(
216
options.operator1,
217
InequalityOperator::Gt | InequalityOperator::GtEq
218
);
219
220
let l1_sort_options = SortOptions::default()
221
.with_maintain_order(true)
222
.with_nulls_last(false)
223
.with_order_descending(l1_descending);
224
225
let sl = &selected_left[0];
226
let l1_s_l = sl
227
.arg_sort(l1_sort_options)
228
.slice(sl.null_count() as i64, sl.len() - sl.null_count());
229
230
let sr = &selected_right[0];
231
let l1_s_r = sr
232
.arg_sort(l1_sort_options)
233
.slice(sr.null_count() as i64, sr.len() - sr.null_count());
234
235
// Because we do a cartesian product, the number of partitions is squared.
236
// We take the sqrt, but we don't expect every partition to produce results and work can be
237
// imbalanced, so we multiply the number of partitions by 2, which leads to 2^2= 4
238
let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2;
239
let splitted_a = split(&l1_s_l, n_partitions);
240
let splitted_b = split(&l1_s_r, n_partitions);
241
242
let cartesian_prod = splitted_a
243
.iter()
244
.flat_map(|l| splitted_b.iter().map(move |r| (l, r)))
245
.collect::<Vec<_>>();
246
247
let iter = cartesian_prod.par_iter().map(|(l_l1_idx, r_l1_idx)| {
248
if l_l1_idx.is_empty() || r_l1_idx.is_empty() {
249
return Ok(None);
250
}
251
fn get_extrema<'a>(
252
l1_idx: &'a IdxCa,
253
s: &'a Series,
254
) -> Option<(AnyValue<'a>, AnyValue<'a>)> {
255
let first = l1_idx.first()?;
256
let last = l1_idx.last()?;
257
258
let start = s.get(first as usize).unwrap();
259
let end = s.get(last as usize).unwrap();
260
261
Some(if start < end {
262
(start, end)
263
} else {
264
(end, start)
265
})
266
}
267
let Some((min_l, max_l)) = get_extrema(l_l1_idx, sl) else {
268
return Ok(None);
269
};
270
let Some((min_r, max_r)) = get_extrema(r_l1_idx, sr) else {
271
return Ok(None);
272
};
273
274
let include_block = match options.operator1 {
275
InequalityOperator::Lt => min_l < max_r,
276
InequalityOperator::LtEq => min_l <= max_r,
277
InequalityOperator::Gt => max_l > min_r,
278
InequalityOperator::GtEq => max_l >= min_r,
279
};
280
281
if include_block {
282
let (mut l, mut r) = unsafe {
283
(
284
selected_left
285
.iter()
286
.map(|s| s.take_unchecked(l_l1_idx))
287
.collect_vec(),
288
selected_right
289
.iter()
290
.map(|s| s.take_unchecked(r_l1_idx))
291
.collect_vec(),
292
)
293
};
294
let sorted_flag = if l1_descending {
295
IsSorted::Descending
296
} else {
297
IsSorted::Ascending
298
};
299
// We sorted using the first series
300
l[0].set_sorted_flag(sorted_flag);
301
r[0].set_sorted_flag(sorted_flag);
302
303
// Compute the row indexes
304
let (idx_l, idx_r) = if options.operator2.is_some() {
305
iejoin_tuples(l, r, options, None)
306
} else {
307
piecewise_merge_join_tuples(l, r, options, None)
308
}?;
309
310
if idx_l.is_empty() {
311
return Ok(None);
312
}
313
314
// These are row indexes in the slices we have given, so we use those to gather in the
315
// original l1 offset arrays. This gives us indexes in the original tables.
316
unsafe {
317
Ok(Some((
318
l_l1_idx.take_unchecked(&idx_l),
319
r_l1_idx.take_unchecked(&idx_r),
320
)))
321
}
322
} else {
323
Ok(None)
324
}
325
});
326
327
let row_indices = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
328
329
let mut left_idx = IdxCa::default();
330
let mut right_idx = IdxCa::default();
331
for (l, r) in row_indices.into_iter().flatten() {
332
left_idx.append(&l)?;
333
right_idx.append(&r)?;
334
}
335
if let Some((offset, end)) = slice {
336
left_idx = left_idx.slice(offset, end);
337
right_idx = right_idx.slice(offset, end);
338
}
339
340
unsafe { materialize_join(left, right, &left_idx, &right_idx, suffix) }
341
}
342
343
pub(super) fn iejoin(
344
left: &DataFrame,
345
right: &DataFrame,
346
selected_left: Vec<Series>,
347
selected_right: Vec<Series>,
348
options: &IEJoinOptions,
349
suffix: Option<PlSmallStr>,
350
slice: Option<(i64, usize)>,
351
) -> PolarsResult<DataFrame> {
352
let (left_row_idx, right_row_idx) = if options.operator2.is_some() {
353
iejoin_tuples(selected_left, selected_right, options, slice)
354
} else {
355
piecewise_merge_join_tuples(selected_left, selected_right, options, slice)
356
}?;
357
unsafe { materialize_join(left, right, &left_row_idx, &right_row_idx, suffix) }
358
}
359
360
unsafe fn materialize_join(
361
left: &DataFrame,
362
right: &DataFrame,
363
left_row_idx: &IdxCa,
364
right_row_idx: &IdxCa,
365
suffix: Option<PlSmallStr>,
366
) -> PolarsResult<DataFrame> {
367
try_raise_keyboard_interrupt();
368
let (join_left, join_right) = {
369
POOL.join(
370
|| left.take_unchecked(left_row_idx),
371
|| right.take_unchecked(right_row_idx),
372
)
373
};
374
375
_finish_join(join_left, join_right, suffix)
376
}
377
378
/// Inequality join. Matches rows between two DataFrames using two inequality operators
379
/// (one of [<, <=, >, >=]).
380
/// Based on Khayyat et al. 2015, "Lightning Fast and Space Efficient Inequality Joins"
381
/// and extended to work with duplicate values.
382
fn iejoin_tuples(
383
selected_left: Vec<Series>,
384
selected_right: Vec<Series>,
385
options: &IEJoinOptions,
386
slice: Option<(i64, usize)>,
387
) -> PolarsResult<(IdxCa, IdxCa)> {
388
if selected_left.len() != 2 {
389
return Err(
390
polars_err!(ComputeError: "IEJoin requires exactly two expressions from the left DataFrame"),
391
);
392
};
393
if selected_right.len() != 2 {
394
return Err(
395
polars_err!(ComputeError: "IEJoin requires exactly two expressions from the right DataFrame"),
396
);
397
};
398
399
let op1 = options.operator1;
400
let op2 = match options.operator2 {
401
None => {
402
return Err(polars_err!(ComputeError: "IEJoin requires two inequality operators"));
403
},
404
Some(op2) => op2,
405
};
406
407
// Determine the sort order based on the comparison operators used.
408
// We want to sort L1 so that "x[i] op1 x[j]" is true for j > i,
409
// and L2 so that "y[i] op2 y[j]" is true for j < i
410
// (except in the case of duplicates and strict inequalities).
411
// Note that the algorithms published in Khayyat et al. have incorrect logic for
412
// determining whether to sort descending.
413
let l1_descending = matches!(op1, InequalityOperator::Gt | InequalityOperator::GtEq);
414
let l2_descending = matches!(op2, InequalityOperator::Lt | InequalityOperator::LtEq);
415
416
let mut x = selected_left[0].to_physical_repr().into_owned();
417
let left_height = x.len();
418
419
x.extend(&selected_right[0].to_physical_repr())?;
420
// Rechunk because we will gather.
421
let x = x.rechunk();
422
423
let mut y = selected_left[1].to_physical_repr().into_owned();
424
y.extend(&selected_right[1].to_physical_repr())?;
425
// Rechunk because we will gather.
426
let y = y.rechunk();
427
428
let l1_sort_options = SortOptions::default()
429
.with_maintain_order(true)
430
.with_nulls_last(false)
431
.with_order_descending(l1_descending);
432
// Get ordering of x, skipping any null entries as these cannot be matches
433
let l1_order = x
434
.arg_sort(l1_sort_options)
435
.slice(x.null_count() as i64, x.len() - x.null_count());
436
437
let y_ordered_by_x = unsafe { y.take_unchecked(&l1_order) };
438
let l2_sort_options = SortOptions::default()
439
.with_maintain_order(true)
440
.with_nulls_last(false)
441
.with_order_descending(l2_descending);
442
// Get the indexes into l1, ordered by y values.
443
// l2_order is the same as "p" from Khayyat et al.
444
let l2_order = y_ordered_by_x.arg_sort(l2_sort_options).slice(
445
y_ordered_by_x.null_count() as i64,
446
y_ordered_by_x.len() - y_ordered_by_x.null_count(),
447
);
448
let l2_order = l2_order.rechunk();
449
let l2_order = l2_order.downcast_as_array().values().as_slice();
450
451
let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(x.dtype(), |$T| {
452
ie_join_impl_t::<$T>(
453
slice,
454
l1_order,
455
l2_order,
456
op1,
457
op2,
458
x,
459
y_ordered_by_x,
460
left_height
461
)
462
})?;
463
464
debug_assert_eq!(left_row_idx.len(), right_row_idx.len());
465
let left_row_idx = IdxCa::from_vec("".into(), left_row_idx);
466
let right_row_idx = IdxCa::from_vec("".into(), right_row_idx);
467
let (left_row_idx, right_row_idx) = match slice {
468
None => (left_row_idx, right_row_idx),
469
Some((offset, len)) => (
470
left_row_idx.slice(offset, len),
471
right_row_idx.slice(offset, len),
472
),
473
};
474
Ok((left_row_idx, right_row_idx))
475
}
476
477
/// Piecewise merge join, for joins with only a single inequality.
478
fn piecewise_merge_join_tuples(
479
selected_left: Vec<Series>,
480
selected_right: Vec<Series>,
481
options: &IEJoinOptions,
482
slice: Option<(i64, usize)>,
483
) -> PolarsResult<(IdxCa, IdxCa)> {
484
if selected_left.len() != 1 {
485
return Err(
486
polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the left DataFrame"),
487
);
488
};
489
if selected_right.len() != 1 {
490
return Err(
491
polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the right DataFrame"),
492
);
493
};
494
if options.operator2.is_some() {
495
return Err(
496
polars_err!(ComputeError: "Piecewise merge join expects only one inequality operator"),
497
);
498
}
499
500
let op = options.operator1;
501
// The left side is sorted such that if the condition is false, it will also
502
// be false for the same RHS row and all following LHS rows.
503
// The right side is sorted such that if the condition is true then it is also
504
// true for the same LHS row and all following RHS rows.
505
// The desired sort order should match the l1 order used in iejoin_par
506
// so we don't need to re-sort slices when doing a parallel join.
507
let descending = matches!(op, InequalityOperator::Gt | InequalityOperator::GtEq);
508
509
let left = selected_left[0].to_physical_repr().into_owned();
510
let mut right = selected_right[0].to_physical_repr().into_owned();
511
let must_cast = right.dtype().matches_schema_type(left.dtype())?;
512
if must_cast {
513
right = right.cast(left.dtype())?;
514
}
515
516
fn get_sorted(series: Series, descending: bool) -> (Series, Option<IdxCa>) {
517
let expected_flag = if descending {
518
IsSorted::Descending
519
} else {
520
IsSorted::Ascending
521
};
522
if (series.is_sorted_flag() == expected_flag || series.len() <= 1) && !series.has_nulls() {
523
// Fast path, no need to re-sort
524
(series, None)
525
} else {
526
let sort_options = SortOptions::default()
527
.with_nulls_last(false)
528
.with_order_descending(descending);
529
530
// Get order and slice to ignore any null values, which cannot be match results
531
let mut order = series.arg_sort(sort_options).slice(
532
series.null_count() as i64,
533
series.len() - series.null_count(),
534
);
535
order.rechunk_mut();
536
let ordered = unsafe { series.take_unchecked(&order) };
537
(ordered, Some(order))
538
}
539
}
540
541
let (left_ordered, left_order) = get_sorted(left, descending);
542
debug_assert!(
543
left_order
544
.as_ref()
545
.is_none_or(|order| order.chunks().len() == 1)
546
);
547
let left_order = left_order
548
.as_ref()
549
.map(|order| order.downcast_get(0).unwrap().values().as_slice());
550
551
let (right_ordered, right_order) = get_sorted(right, descending);
552
debug_assert!(
553
right_order
554
.as_ref()
555
.is_none_or(|order| order.chunks().len() == 1)
556
);
557
let right_order = right_order
558
.as_ref()
559
.map(|order| order.downcast_get(0).unwrap().values().as_slice());
560
561
let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(left_ordered.dtype(), |$T| {
562
match op {
563
InequalityOperator::Lt => piecewise_merge_join_impl_t::<$T, _>(
564
slice,
565
left_order,
566
right_order,
567
left_ordered,
568
right_ordered,
569
|l, r| l.tot_lt(r),
570
),
571
InequalityOperator::LtEq => piecewise_merge_join_impl_t::<$T, _>(
572
slice,
573
left_order,
574
right_order,
575
left_ordered,
576
right_ordered,
577
|l, r| l.tot_le(r),
578
),
579
InequalityOperator::Gt => piecewise_merge_join_impl_t::<$T, _>(
580
slice,
581
left_order,
582
right_order,
583
left_ordered,
584
right_ordered,
585
|l, r| l.tot_gt(r),
586
),
587
InequalityOperator::GtEq => piecewise_merge_join_impl_t::<$T, _>(
588
slice,
589
left_order,
590
right_order,
591
left_ordered,
592
right_ordered,
593
|l, r| l.tot_ge(r),
594
),
595
}
596
})?;
597
598
debug_assert_eq!(left_row_idx.len(), right_row_idx.len());
599
let left_row_idx = IdxCa::from_vec("".into(), left_row_idx);
600
let right_row_idx = IdxCa::from_vec("".into(), right_row_idx);
601
let (left_row_idx, right_row_idx) = match slice {
602
None => (left_row_idx, right_row_idx),
603
Some((offset, len)) => (
604
left_row_idx.slice(offset, len),
605
right_row_idx.slice(offset, len),
606
),
607
};
608
Ok((left_row_idx, right_row_idx))
609
}
610
611
/// Exclusive end position of `slice` within the join output, used to stop producing
/// matches early.
///
/// Returns `None` when there is no slice, or when the offset is negative. A negative offset
/// is presumably relative to the end of the output, so the required number of matches
/// is unknown up front. The sum saturates at `i64::MAX`.
fn slice_end_index(slice: Option<(i64, usize)>) -> Option<i64> {
    let (offset, len) = slice?;
    (offset >= 0).then(|| offset.saturating_add_unsigned(len as u64))
}
617
618