Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/merge_join.rs
8480 views
1
use std::borrow::Cow;
2
use std::cmp::Ordering;
3
use std::iter::repeat_n;
4
5
use arrow::array::Array;
6
use arrow::array::builder::ShareStrategy;
7
use polars_core::frame::builder::DataFrameBuilder;
8
use polars_core::prelude::*;
9
use polars_core::with_match_physical_numeric_polars_type;
10
use polars_utils::itertools::Itertools;
11
use polars_utils::total_ord::TotalOrd;
12
use polars_utils::{IdxSize, format_pl_smallstr};
13
14
use crate::frame::{JoinArgs, JoinType};
15
use crate::series::coalesce_columns;
16
17
#[allow(clippy::too_many_arguments)]
18
pub fn match_keys(
19
build_keys: &Series,
20
probe_keys: &Series,
21
gather_build: &mut Vec<IdxSize>,
22
gather_probe: &mut Vec<IdxSize>,
23
gather_probe_unmatched: Option<&mut Vec<IdxSize>>,
24
build_emit_unmatched: bool,
25
descending: bool,
26
nulls_equal: bool,
27
limit_results: usize,
28
build_row_offset: &mut usize,
29
probe_row_offset: &mut usize,
30
probe_last_matched: &mut usize,
31
) {
32
macro_rules! dispatch {
33
($build_keys_ca:expr) => {
34
match_keys_impl(
35
$build_keys_ca,
36
probe_keys.as_ref().as_ref(),
37
gather_build,
38
gather_probe,
39
gather_probe_unmatched,
40
build_emit_unmatched,
41
descending,
42
nulls_equal,
43
limit_results,
44
build_row_offset,
45
probe_row_offset,
46
probe_last_matched,
47
)
48
};
49
}
50
51
assert_eq!(build_keys.dtype(), probe_keys.dtype());
52
match build_keys.dtype() {
53
dt if dt.is_primitive_numeric() => {
54
with_match_physical_numeric_polars_type!(dt, |$T| {
55
type PhysCa = ChunkedArray<$T>;
56
let build_keys_ca: &PhysCa = build_keys.as_ref().as_ref();
57
dispatch!(build_keys_ca)
58
})
59
},
60
DataType::Boolean => dispatch!(build_keys.bool().unwrap()),
61
DataType::String => dispatch!(build_keys.str().unwrap()),
62
DataType::Binary => dispatch!(build_keys.binary().unwrap()),
63
DataType::BinaryOffset => dispatch!(build_keys.binary_offset().unwrap()),
64
#[cfg(feature = "dtype-categorical")]
65
DataType::Enum(cats, _) => with_match_categorical_physical_type!(cats.physical(), |$C| {
66
type PhysCa = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>;
67
let build_keys_ca: &PhysCa = build_keys.as_ref().as_ref();
68
dispatch!(build_keys_ca)
69
}),
70
DataType::Null => match_null_keys_impl(
71
build_keys.len(),
72
probe_keys.len(),
73
gather_build,
74
gather_probe,
75
gather_probe_unmatched,
76
build_emit_unmatched,
77
descending,
78
nulls_equal,
79
limit_results,
80
build_row_offset,
81
probe_row_offset,
82
probe_last_matched,
83
),
84
dt => unimplemented!("merge-join kernel not implemented for {:?}", dt),
85
}
86
}
87
88
#[allow(clippy::mut_range_bound, clippy::too_many_arguments)]
89
fn match_keys_impl<'a, T: PolarsDataType>(
90
build_keys: &'a ChunkedArray<T>,
91
probe_keys: &'a ChunkedArray<T>,
92
gather_build: &mut Vec<IdxSize>,
93
gather_probe: &mut Vec<IdxSize>,
94
mut gather_probe_unmatched: Option<&mut Vec<IdxSize>>,
95
build_emit_unmatched: bool,
96
descending: bool,
97
nulls_equal: bool,
98
limit_results: usize,
99
build_row_offset: &mut usize,
100
probe_row_offset: &mut usize,
101
probe_first_unmatched: &mut usize,
102
) where
103
T::Physical<'a>: TotalOrd,
104
{
105
assert!(gather_build.is_empty());
106
assert!(gather_probe.is_empty());
107
108
let build_key = build_keys.downcast_as_array();
109
let probe_key = probe_keys.downcast_as_array();
110
111
while *build_row_offset < build_key.len() {
112
if gather_build.len() >= limit_results {
113
return;
114
}
115
116
let build_keyval = unsafe { build_key.get_unchecked(*build_row_offset) };
117
let build_keyval = build_keyval.as_ref();
118
let mut build_keyval_matched = false;
119
120
if nulls_equal || build_keyval.is_some() {
121
for probe_idx in *probe_row_offset..probe_key.len() {
122
let probe_keyval = unsafe { probe_key.get_unchecked(probe_idx) };
123
let probe_keyval = probe_keyval.as_ref();
124
125
let mut ord: Ordering = match (&build_keyval, &probe_keyval) {
126
(None, None) if nulls_equal => Ordering::Equal,
127
(Some(l), Some(r)) => TotalOrd::tot_cmp(*l, *r),
128
_ => continue,
129
};
130
if descending {
131
ord = ord.reverse();
132
}
133
134
match ord {
135
Ordering::Equal => {
136
if let Some(probe_unmatched) = gather_probe_unmatched.as_mut() {
137
// All probe keys up to and *excluding* this matched key are unmatched
138
probe_unmatched
139
.extend(*probe_first_unmatched as IdxSize..probe_idx as IdxSize);
140
*probe_first_unmatched = (*probe_first_unmatched).max(probe_idx + 1);
141
}
142
gather_build.push(*build_row_offset as IdxSize);
143
gather_probe.push(probe_idx as IdxSize);
144
build_keyval_matched = true;
145
},
146
Ordering::Greater => {
147
if let Some(probe_unmatched) = gather_probe_unmatched.as_mut() {
148
// All probe keys up to and *including* this matched key are unmatched
149
probe_unmatched
150
.extend(*probe_first_unmatched as IdxSize..=probe_idx as IdxSize);
151
*probe_first_unmatched = (*probe_first_unmatched).max(probe_idx + 1);
152
}
153
*probe_row_offset = probe_idx + 1;
154
},
155
Ordering::Less => {
156
break;
157
},
158
}
159
}
160
}
161
if build_emit_unmatched && !build_keyval_matched {
162
gather_build.push(*build_row_offset as IdxSize);
163
gather_probe.push(IdxSize::MAX);
164
}
165
*build_row_offset += 1;
166
}
167
if let Some(probe_unmatched) = gather_probe_unmatched {
168
probe_unmatched.extend(*probe_first_unmatched as IdxSize..probe_key.len() as IdxSize);
169
*probe_first_unmatched = probe_key.len();
170
}
171
*probe_row_offset = probe_key.len();
172
}
173
174
#[allow(clippy::mut_range_bound, clippy::too_many_arguments)]
175
fn match_null_keys_impl(
176
build_n: usize,
177
probe_n: usize,
178
gather_build: &mut Vec<IdxSize>,
179
gather_probe: &mut Vec<IdxSize>,
180
gather_probe_unmatched: Option<&mut Vec<IdxSize>>,
181
build_emit_unmatched: bool,
182
_descending: bool,
183
nulls_equal: bool,
184
limit_results: usize,
185
build_row_offset: &mut usize,
186
probe_row_offset: &mut usize,
187
probe_last_matched: &mut usize,
188
) {
189
assert!(gather_build.is_empty());
190
assert!(gather_probe.is_empty());
191
192
if nulls_equal {
193
// All keys will match all other keys, so just emit the Cartesian product
194
while *build_row_offset < build_n {
195
if gather_build.len() >= limit_results {
196
return;
197
}
198
for probe_idx in *probe_row_offset..probe_n {
199
gather_build.push(*build_row_offset as IdxSize);
200
gather_probe.push(probe_idx as IdxSize);
201
}
202
*build_row_offset += 1;
203
}
204
} else {
205
// No keys can ever match, so just emit all build keys into gather_build
206
// and all probe keys into gather_probe_unmatched.
207
if build_emit_unmatched {
208
gather_build.extend(0..build_n as IdxSize);
209
gather_probe.extend(repeat_n(IdxSize::MAX, build_n));
210
}
211
if let Some(probe_unmatched) = gather_probe_unmatched {
212
probe_unmatched.extend(*probe_last_matched as IdxSize..probe_n as IdxSize);
213
*probe_last_matched = probe_n;
214
}
215
}
216
*build_row_offset = build_n;
217
*probe_row_offset = probe_n;
218
}
219
220
#[allow(clippy::too_many_arguments)]
221
pub fn gather_and_postprocess(
222
build: DataFrame,
223
probe: DataFrame,
224
gather_build: Option<&[IdxSize]>,
225
gather_probe: Option<&[IdxSize]>,
226
df_builders: &mut Option<(DataFrameBuilder, DataFrameBuilder)>,
227
args: &JoinArgs,
228
left_on: &[PlSmallStr],
229
right_on: &[PlSmallStr],
230
left_is_build: bool,
231
output_schema: &Schema,
232
) -> PolarsResult<DataFrame> {
233
let should_coalesce = args.should_coalesce();
234
let left_emit_unmatched = matches!(args.how, JoinType::Left | JoinType::Full);
235
let right_emit_unmatched = matches!(args.how, JoinType::Right | JoinType::Full);
236
237
let (mut left, mut right);
238
let (gather_left, gather_right);
239
if left_is_build {
240
(left, right) = (build, probe);
241
(gather_left, gather_right) = (gather_build, gather_probe);
242
} else {
243
(left, right) = (probe, build);
244
(gather_left, gather_right) = (gather_probe, gather_build);
245
}
246
247
// Remove non-payload columns
248
for col in left
249
.columns()
250
.iter()
251
.map(Column::name)
252
.cloned()
253
.collect_vec()
254
{
255
if left_on.contains(&col) && should_coalesce {
256
continue;
257
}
258
if !output_schema.contains(&col) {
259
left.drop_in_place(&col).unwrap();
260
}
261
}
262
for col in right
263
.columns()
264
.iter()
265
.map(Column::name)
266
.cloned()
267
.collect_vec()
268
{
269
if left_on.contains(&col) && should_coalesce {
270
continue;
271
}
272
let renamed = match left.schema().contains(&col) {
273
true => Cow::Owned(format_pl_smallstr!("{}{}", col, args.suffix())),
274
false => Cow::Borrowed(&col),
275
};
276
if !output_schema.contains(&renamed) {
277
right.drop_in_place(&col).unwrap();
278
}
279
}
280
281
if df_builders.is_none() {
282
*df_builders = Some((
283
DataFrameBuilder::new(left.schema().clone()),
284
DataFrameBuilder::new(right.schema().clone()),
285
));
286
}
287
288
let (left_build, right_build) = df_builders.as_mut().unwrap();
289
let mut left = match gather_left {
290
Some(gather_left) if right_emit_unmatched => {
291
left_build.opt_gather_extend(&left, gather_left, ShareStrategy::Never);
292
left_build.freeze_reset()
293
},
294
Some(gather_left) => unsafe {
295
left_build.gather_extend(&left, gather_left, ShareStrategy::Never);
296
left_build.freeze_reset()
297
},
298
None => DataFrame::full_null(left.schema(), gather_right.unwrap().len()),
299
};
300
let mut right = match gather_right {
301
Some(gather_right) if left_emit_unmatched => {
302
right_build.opt_gather_extend(&right, gather_right, ShareStrategy::Never);
303
right_build.freeze_reset()
304
},
305
Some(gather_right) => unsafe {
306
right_build.gather_extend(&right, gather_right, ShareStrategy::Never);
307
right_build.freeze_reset()
308
},
309
None => DataFrame::full_null(right.schema(), gather_left.unwrap().len()),
310
};
311
312
// Coalsesce the key columns
313
if args.how == JoinType::Left && should_coalesce {
314
for c in left_on {
315
if right.schema().contains(c) {
316
right.drop_in_place(c.as_str())?;
317
}
318
}
319
} else if args.how == JoinType::Right && should_coalesce {
320
for c in right_on {
321
if left.schema().contains(c) {
322
left.drop_in_place(c.as_str())?;
323
}
324
}
325
}
326
327
// Rename any right columns to "{}_right"
328
let left_cols: PlHashSet<_> = left.columns().iter().map(Column::name).cloned().collect();
329
let right_cols_vec = right.get_column_names_owned();
330
let renames = right_cols_vec
331
.iter()
332
.filter(|c| left_cols.contains(*c))
333
.map(|c| {
334
let renamed = format_pl_smallstr!("{}{}", c, args.suffix());
335
(c.as_str(), renamed)
336
});
337
right.rename_many(renames).unwrap();
338
339
left.hstack_mut(right.columns())?;
340
341
if args.how == JoinType::Full && should_coalesce {
342
// Coalesce key columns
343
for (left_keycol, right_keycol) in Iterator::zip(left_on.iter(), right_on.iter()) {
344
let right_keycol = format_pl_smallstr!("{}{}", right_keycol, args.suffix());
345
let left_col = left.column(left_keycol).unwrap();
346
let right_col = left.column(&right_keycol).unwrap();
347
let coalesced = coalesce_columns(&[left_col.clone(), right_col.clone()]).unwrap();
348
left.replace(left_keycol, coalesced)
349
.unwrap()
350
.drop_in_place(&right_keycol)
351
.unwrap();
352
}
353
}
354
355
if should_coalesce {
356
for col in left_on {
357
if left.schema().contains(col) && !output_schema.contains(col) {
358
left.drop_in_place(col).unwrap();
359
}
360
}
361
for col in right_on {
362
let renamed = match left.schema().contains(col) {
363
true => Cow::Owned(format_pl_smallstr!("{}{}", col, args.suffix())),
364
false => Cow::Borrowed(col),
365
};
366
if left.schema().contains(&renamed) && !output_schema.contains(&renamed) {
367
left.drop_in_place(&renamed).unwrap();
368
}
369
}
370
}
371
372
debug_assert_eq!(**left.schema(), *output_schema);
373
Ok(left)
374
}
375
376