Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/asof/groups.rs
8446 views
1
use std::hash::Hash;
2
3
use num_traits::Zero;
4
use polars_core::hashing::_HASHMAP_INIT_SIZE;
5
use polars_core::prelude::*;
6
use polars_core::series::BitRepr;
7
use polars_core::utils::flatten::flatten_nullable;
8
use polars_core::utils::split_and_flatten;
9
use polars_core::{POOL, with_match_physical_float_polars_type};
10
use polars_utils::abs_diff::AbsDiff;
11
use polars_utils::hashing::{DirtyHash, hash_to_partition};
12
use polars_utils::nulls::IsNull;
13
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
14
use rayon::prelude::*;
15
16
use super::*;
17
use crate::frame::join::{prepare_binary, prepare_keys_multiple};
18
19
/// Returns the exclusive running sum of the lengths produced by `iter`.
///
/// Element `i` of the result is the sum of all lengths that come before
/// position `i`, so the first element (if any) is always `0`.
fn compute_len_offsets<I: IntoIterator<Item = usize>>(iter: I) -> Vec<usize> {
    iter.into_iter()
        .scan(0usize, |running_total, len| {
            // Emit the total *before* adding the current length.
            let start = *running_total;
            *running_total += len;
            Some(start)
        })
        .collect()
}
29
30
#[inline(always)]
31
fn materialize_nullable(idx: Option<IdxSize>) -> NullableIdxSize {
32
match idx {
33
Some(t) => NullableIdxSize::from(t),
34
None => NullableIdxSize::null(),
35
}
36
}
37
38
fn asof_in_group<'a, T, A, F>(
39
left_val: T::Physical<'a>,
40
right_val_arr: &'a T::Array,
41
right_grp_idxs: &[IdxSize],
42
group_states: &mut PlHashMap<IdxSize, A>,
43
filter: F,
44
allow_eq: bool,
45
) -> Option<IdxSize>
46
where
47
T: PolarsDataType,
48
A: AsofJoinState<T::Physical<'a>>,
49
F: Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
50
{
51
// We use the index of the first element in a group as an identifier to
52
// associate with the group state.
53
let id = right_grp_idxs.first()?;
54
let grp_state = group_states.entry(*id).or_insert_with(|| A::new(allow_eq));
55
56
unsafe {
57
let r_grp_idx = grp_state.next(
58
&left_val,
59
|i| {
60
// SAFETY: the group indices are valid, and next() only calls with
61
// i < right_grp_idxs.len().
62
right_val_arr.get_unchecked(*right_grp_idxs.get_unchecked(i as usize) as usize)
63
},
64
right_grp_idxs.len() as IdxSize,
65
)?;
66
67
// SAFETY: r_grp_idx is valid, as is r_idx (which must be non-null) if
68
// we get here.
69
let r_idx = *right_grp_idxs.get_unchecked(r_grp_idx as usize);
70
let right_val = right_val_arr.value_unchecked(r_idx as usize);
71
filter(left_val, right_val).then_some(r_idx)
72
}
73
}
74
75
fn asof_join_by_numeric<T, S, A, F>(
76
by_left: &ChunkedArray<S>,
77
by_right: &ChunkedArray<S>,
78
left_asof: &ChunkedArray<T>,
79
right_asof: &ChunkedArray<T>,
80
filter: F,
81
allow_eq: bool,
82
) -> PolarsResult<IdxArr>
83
where
84
T: PolarsDataType,
85
S: PolarsNumericType,
86
S::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
87
<S::Native as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
88
A: for<'a> AsofJoinState<T::Physical<'a>>,
89
F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
90
{
91
let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
92
let left_val_arr = left_asof.downcast_as_array();
93
let right_val_arr = right_asof.downcast_as_array();
94
95
let n_threads = POOL.current_num_threads();
96
// `strict` is false so that we always flatten. Even if there are more chunks than threads.
97
let split_by_left = split_and_flatten(by_left, n_threads);
98
let split_by_right = split_and_flatten(by_right, n_threads);
99
let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len()));
100
101
// TODO: handle nulls more efficiently. Right now we just join on the value
102
// ignoring the validity mask, and ignore the nulls later.
103
let right_slices = split_by_right
104
.iter()
105
.map(|ca| {
106
assert_eq!(ca.chunks().len(), 1);
107
ca.downcast_iter().next().unwrap().values_iter().copied()
108
})
109
.collect();
110
let hash_tbls = build_tables(right_slices, false);
111
let n_tables = hash_tbls.len();
112
113
// Now we probe the right hand side for each left hand side.
114
let out = split_by_left
115
.into_par_iter()
116
.zip(offsets)
117
.map(|(by_left, offset)| {
118
let mut results = Vec::with_capacity(by_left.len());
119
let mut group_states: PlHashMap<IdxSize, A> =
120
PlHashMap::with_capacity(_HASHMAP_INIT_SIZE);
121
122
assert_eq!(by_left.chunks().len(), 1);
123
let by_left_chunk = by_left.downcast_iter().next().unwrap();
124
for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() {
125
let Some(by_left_k) = opt_by_left_k else {
126
results.push(NullableIdxSize::null());
127
continue;
128
};
129
let by_left_k = by_left_k.to_total_ord();
130
let idx_left = (rel_idx_left + offset) as IdxSize;
131
let Some(left_val) = left_val_arr.get(idx_left as usize) else {
132
results.push(NullableIdxSize::null());
133
continue;
134
};
135
136
let group_probe_table = unsafe {
137
hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables))
138
};
139
let Some(right_grp_idxs) = group_probe_table.get(&by_left_k) else {
140
results.push(NullableIdxSize::null());
141
continue;
142
};
143
let id = asof_in_group::<T, A, &F>(
144
left_val,
145
right_val_arr,
146
right_grp_idxs.as_slice(),
147
&mut group_states,
148
&filter,
149
allow_eq,
150
);
151
results.push(materialize_nullable(id));
152
}
153
results
154
});
155
156
let bufs = POOL.install(|| out.collect::<Vec<_>>());
157
Ok(flatten_nullable(&bufs))
158
}
159
160
fn asof_join_by_binary<B, T, A, F>(
161
by_left: &ChunkedArray<B>,
162
by_right: &ChunkedArray<B>,
163
left_asof: &ChunkedArray<T>,
164
right_asof: &ChunkedArray<T>,
165
filter: F,
166
allow_eq: bool,
167
) -> IdxArr
168
where
169
B: PolarsDataType,
170
for<'b> <B::Array as StaticArray>::ValueT<'b>: AsRef<[u8]>,
171
T: PolarsDataType,
172
A: for<'a> AsofJoinState<T::Physical<'a>>,
173
F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
174
{
175
let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
176
let left_val_arr = left_asof.downcast_as_array();
177
let right_val_arr = right_asof.downcast_as_array();
178
179
let (prep_by_left, prep_by_right, _, _) = prepare_binary::<B>(by_left, by_right, false);
180
let offsets = compute_len_offsets(prep_by_left.iter().map(|s| s.len()));
181
let hash_tbls = build_tables(prep_by_right, false);
182
let n_tables = hash_tbls.len();
183
184
// Now we probe the right hand side for each left hand side.
185
let iter = prep_by_left
186
.into_par_iter()
187
.zip(offsets)
188
.map(|(by_left, offset)| {
189
let mut results = Vec::with_capacity(by_left.len());
190
let mut group_states: PlHashMap<_, A> = PlHashMap::with_capacity(_HASHMAP_INIT_SIZE);
191
192
for (rel_idx_left, by_left_k) in by_left.iter().enumerate() {
193
let idx_left = (rel_idx_left + offset) as IdxSize;
194
let Some(left_val) = left_val_arr.get(idx_left as usize) else {
195
results.push(NullableIdxSize::null());
196
continue;
197
};
198
199
let group_probe_table = unsafe {
200
hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables))
201
};
202
let Some(right_grp_idxs) = group_probe_table.get(by_left_k) else {
203
results.push(NullableIdxSize::null());
204
continue;
205
};
206
let id = asof_in_group::<T, A, &F>(
207
left_val,
208
right_val_arr,
209
right_grp_idxs.as_slice(),
210
&mut group_states,
211
&filter,
212
allow_eq,
213
);
214
215
results.push(materialize_nullable(id));
216
}
217
results
218
});
219
let bufs = POOL.install(|| iter.collect::<Vec<_>>());
220
flatten_nullable(&bufs)
221
}
222
223
#[allow(clippy::too_many_arguments)]
224
fn dispatch_join_by_type<T, A, F>(
225
left_asof: &ChunkedArray<T>,
226
right_asof: &ChunkedArray<T>,
227
left_by: &mut DataFrame,
228
right_by: &mut DataFrame,
229
filter: F,
230
allow_eq: bool,
231
) -> PolarsResult<IdxArr>
232
where
233
T: PolarsDataType,
234
A: for<'a> AsofJoinState<T::Physical<'a>>,
235
F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
236
{
237
let out = if left_by.width() == 1 {
238
let left_by_s = left_by.columns()[0].to_physical_repr();
239
let right_by_s = right_by.columns()[0].to_physical_repr();
240
let left_dtype = left_by_s.dtype();
241
let right_dtype = right_by_s.dtype();
242
polars_ensure!(left_dtype == right_dtype,
243
ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{left_dtype}` and `{right_dtype}`",
244
);
245
match left_dtype {
246
DataType::String => {
247
let left_by = &left_by_s.str().unwrap().as_binary();
248
let right_by = right_by_s.str().unwrap().as_binary();
249
asof_join_by_binary::<BinaryType, T, A, F>(
250
left_by, &right_by, left_asof, right_asof, filter, allow_eq,
251
)
252
},
253
DataType::Binary => {
254
let left_by = &left_by_s.binary().unwrap();
255
let right_by = right_by_s.binary().unwrap();
256
asof_join_by_binary::<BinaryType, T, A, F>(
257
left_by, right_by, left_asof, right_asof, filter, allow_eq,
258
)
259
},
260
x if x.is_float() => {
261
with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| {
262
let left_by: &ChunkedArray<$T> = left_by_s.as_materialized_series().as_ref().as_ref().as_ref();
263
let right_by: &ChunkedArray<$T> = right_by_s.as_materialized_series().as_ref().as_ref().as_ref();
264
asof_join_by_numeric::<T, $T, A, F>(
265
left_by, right_by, left_asof, right_asof, filter, allow_eq
266
)?
267
})
268
},
269
_ => {
270
let left_by = left_by_s.bit_repr();
271
let right_by = right_by_s.bit_repr();
272
273
let (Some(left_by), Some(right_by)) = (left_by, right_by) else {
274
polars_bail!(nyi = "Dispatch join for {left_dtype} and {right_dtype}");
275
};
276
277
use BitRepr as B;
278
match (left_by, right_by) {
279
(B::U8(left_by), B::U8(right_by)) => {
280
asof_join_by_numeric::<T, UInt8Type, A, F>(
281
&left_by, &right_by, left_asof, right_asof, filter, allow_eq,
282
)?
283
},
284
(B::U16(left_by), B::U16(right_by)) => {
285
asof_join_by_numeric::<T, UInt16Type, A, F>(
286
&left_by, &right_by, left_asof, right_asof, filter, allow_eq,
287
)?
288
},
289
(B::U32(left_by), B::U32(right_by)) => {
290
asof_join_by_numeric::<T, UInt32Type, A, F>(
291
&left_by, &right_by, left_asof, right_asof, filter, allow_eq,
292
)?
293
},
294
(B::U64(left_by), B::U64(right_by)) => {
295
asof_join_by_numeric::<T, UInt64Type, A, F>(
296
&left_by, &right_by, left_asof, right_asof, filter, allow_eq,
297
)?
298
},
299
#[cfg(feature = "dtype-u128")]
300
(B::U128(left_by), B::U128(right_by)) => {
301
asof_join_by_numeric::<T, UInt128Type, A, F>(
302
&left_by, &right_by, left_asof, right_asof, filter, allow_eq,
303
)?
304
},
305
// We have already asserted that the datatypes are the same.
306
_ => unreachable!(),
307
}
308
},
309
}
310
} else {
311
for (lhs, rhs) in left_by.columns().iter().zip(right_by.columns()) {
312
polars_ensure!(lhs.dtype() == rhs.dtype(),
313
ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{}` and `{}`", lhs.dtype(), rhs.dtype()
314
);
315
}
316
317
// TODO: @scalar-opt.
318
let left_by_series: Vec<_> = left_by.materialized_column_iter().cloned().collect();
319
let right_by_series: Vec<_> = right_by.materialized_column_iter().cloned().collect();
320
let lhs_keys = prepare_keys_multiple(&left_by_series, false)?;
321
let rhs_keys = prepare_keys_multiple(&right_by_series, false)?;
322
asof_join_by_binary::<BinaryOffsetType, T, A, F>(
323
&lhs_keys, &rhs_keys, left_asof, right_asof, filter, allow_eq,
324
)
325
};
326
Ok(out)
327
}
328
329
#[allow(clippy::too_many_arguments)]
330
fn dispatch_join_strategy<T: PolarsDataType>(
331
left_asof: &ChunkedArray<T>,
332
right_asof: &Series,
333
left_by: &mut DataFrame,
334
right_by: &mut DataFrame,
335
strategy: AsofStrategy,
336
allow_eq: bool,
337
) -> PolarsResult<IdxArr>
338
where
339
for<'a> T::Physical<'a>: TotalOrd,
340
{
341
let right_asof = left_asof.unpack_series_matching_type(right_asof)?;
342
343
let filter = |_a: T::Physical<'_>, _b: T::Physical<'_>| true;
344
match strategy {
345
AsofStrategy::Backward => dispatch_join_by_type::<T, AsofJoinBackwardState, _>(
346
left_asof, right_asof, left_by, right_by, filter, allow_eq,
347
),
348
AsofStrategy::Forward => dispatch_join_by_type::<T, AsofJoinForwardState, _>(
349
left_asof, right_asof, left_by, right_by, filter, allow_eq,
350
),
351
AsofStrategy::Nearest => unimplemented!(),
352
}
353
}
354
355
#[allow(clippy::too_many_arguments)]
356
fn dispatch_join_strategy_numeric<T: PolarsNumericType>(
357
left_asof: &ChunkedArray<T>,
358
right_asof: &Series,
359
left_by: &mut DataFrame,
360
right_by: &mut DataFrame,
361
strategy: AsofStrategy,
362
tolerance: Option<AnyValue<'static>>,
363
allow_eq: bool,
364
) -> PolarsResult<IdxArr> {
365
let right_ca = left_asof.unpack_series_matching_type(right_asof)?;
366
367
if let Some(tol) = tolerance {
368
let native_tolerance: T::Native = tol.try_extract()?;
369
let abs_tolerance = native_tolerance.abs_diff(T::Native::zero());
370
let filter = |a: T::Native, b: T::Native| a.abs_diff(b) <= abs_tolerance;
371
match strategy {
372
AsofStrategy::Backward => dispatch_join_by_type::<T, AsofJoinBackwardState, _>(
373
left_asof, right_ca, left_by, right_by, filter, allow_eq,
374
),
375
AsofStrategy::Forward => dispatch_join_by_type::<T, AsofJoinForwardState, _>(
376
left_asof, right_ca, left_by, right_by, filter, allow_eq,
377
),
378
AsofStrategy::Nearest => dispatch_join_by_type::<T, AsofJoinNearestState, _>(
379
left_asof, right_ca, left_by, right_by, filter, allow_eq,
380
),
381
}
382
} else {
383
let filter = |_a: T::Physical<'_>, _b: T::Physical<'_>| true;
384
match strategy {
385
AsofStrategy::Backward => dispatch_join_by_type::<T, AsofJoinBackwardState, _>(
386
left_asof, right_ca, left_by, right_by, filter, allow_eq,
387
),
388
AsofStrategy::Forward => dispatch_join_by_type::<T, AsofJoinForwardState, _>(
389
left_asof, right_ca, left_by, right_by, filter, allow_eq,
390
),
391
AsofStrategy::Nearest => dispatch_join_by_type::<T, AsofJoinNearestState, _>(
392
left_asof, right_ca, left_by, right_by, filter, allow_eq,
393
),
394
}
395
}
396
}
397
398
#[allow(clippy::too_many_arguments)]
399
fn dispatch_join_type(
400
left_asof: &Series,
401
right_asof: &Series,
402
left_by: &mut DataFrame,
403
right_by: &mut DataFrame,
404
strategy: AsofStrategy,
405
tolerance: Option<AnyValue<'static>>,
406
allow_eq: bool,
407
) -> PolarsResult<IdxArr> {
408
match left_asof.dtype() {
409
DataType::Int64 => {
410
let ca = left_asof.i64().unwrap();
411
dispatch_join_strategy_numeric(
412
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
413
)
414
},
415
DataType::Int32 => {
416
let ca = left_asof.i32().unwrap();
417
dispatch_join_strategy_numeric(
418
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
419
)
420
},
421
DataType::UInt64 => {
422
let ca = left_asof.u64().unwrap();
423
dispatch_join_strategy_numeric(
424
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
425
)
426
},
427
DataType::UInt32 => {
428
let ca = left_asof.u32().unwrap();
429
dispatch_join_strategy_numeric(
430
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
431
)
432
},
433
#[cfg(feature = "dtype-i128")]
434
DataType::Int128 => {
435
let ca = left_asof.i128().unwrap();
436
dispatch_join_strategy_numeric(
437
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
438
)
439
},
440
#[cfg(feature = "dtype-u128")]
441
DataType::UInt128 => {
442
let ca = left_asof.u128().unwrap();
443
dispatch_join_strategy_numeric(
444
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
445
)
446
},
447
#[cfg(feature = "dtype-f16")]
448
DataType::Float16 => {
449
let ca = left_asof.f16().unwrap();
450
dispatch_join_strategy_numeric(
451
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
452
)
453
},
454
DataType::Float32 => {
455
let ca = left_asof.f32().unwrap();
456
dispatch_join_strategy_numeric(
457
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
458
)
459
},
460
DataType::Float64 => {
461
let ca = left_asof.f64().unwrap();
462
dispatch_join_strategy_numeric(
463
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
464
)
465
},
466
DataType::Boolean => {
467
let ca = left_asof.bool().unwrap();
468
dispatch_join_strategy::<BooleanType>(
469
ca, right_asof, left_by, right_by, strategy, allow_eq,
470
)
471
},
472
DataType::Binary => {
473
let ca = left_asof.binary().unwrap();
474
dispatch_join_strategy::<BinaryType>(
475
ca, right_asof, left_by, right_by, strategy, allow_eq,
476
)
477
},
478
DataType::String => {
479
let ca = left_asof.str().unwrap();
480
let right_binary = right_asof.cast(&DataType::Binary).unwrap();
481
dispatch_join_strategy::<BinaryType>(
482
&ca.as_binary(),
483
&right_binary,
484
left_by,
485
right_by,
486
strategy,
487
allow_eq,
488
)
489
},
490
DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 => {
491
let left_asof = left_asof.cast(&DataType::Int32).unwrap();
492
let right_asof = right_asof.cast(&DataType::Int32).unwrap();
493
let ca = left_asof.i32().unwrap();
494
dispatch_join_strategy_numeric(
495
ca,
496
&right_asof,
497
left_by,
498
right_by,
499
strategy,
500
tolerance,
501
allow_eq,
502
)
503
},
504
dt => polars_bail!(opq = asof_join, dt),
505
}
506
}
507
508
pub trait AsofJoinBy: IntoDf {
509
#[allow(clippy::too_many_arguments)]
510
#[doc(hidden)]
511
fn _join_asof_by(
512
&self,
513
other: &DataFrame,
514
left_on: &Series,
515
right_on: &Series,
516
left_by: Vec<PlSmallStr>,
517
right_by: Vec<PlSmallStr>,
518
strategy: AsofStrategy,
519
tolerance: Option<AnyValue<'static>>,
520
suffix: Option<PlSmallStr>,
521
slice: Option<(i64, usize)>,
522
coalesce: bool,
523
allow_eq: bool,
524
check_sortedness: bool,
525
) -> PolarsResult<DataFrame> {
526
let (self_sliced_slot, left_slice_s); // Keeps temporaries alive.
527
let (self_df, other_df, left_key, right_key);
528
if let Some((offset, len)) = slice {
529
self_sliced_slot = self.to_df().slice(offset, len);
530
left_slice_s = left_on.slice(offset, len);
531
left_key = &left_slice_s;
532
right_key = right_on;
533
self_df = &self_sliced_slot;
534
other_df = other;
535
} else {
536
self_df = self.to_df();
537
other_df = other;
538
left_key = left_on;
539
right_key = right_on;
540
}
541
542
let left_asof = left_key.to_physical_repr();
543
let right_asof = right_key.to_physical_repr();
544
let right_asof_name = right_asof.name();
545
let left_asof_name = left_asof.name();
546
check_asof_columns(
547
&left_asof,
548
&right_asof,
549
tolerance.is_some(),
550
check_sortedness,
551
!(left_by.is_empty() && right_by.is_empty()),
552
)?;
553
554
let mut left_by = self_df.select(left_by)?;
555
let mut right_by = other_df.select(right_by)?;
556
557
for (l, r) in unsafe { left_by.columns_mut() }
558
.iter_mut()
559
.zip(unsafe { right_by.columns_mut() }.iter_mut())
560
{
561
*l = l.to_physical_repr();
562
*r = r.to_physical_repr();
563
}
564
565
let right_join_tuples = dispatch_join_type(
566
&left_asof,
567
&right_asof,
568
&mut left_by,
569
&mut right_by,
570
strategy,
571
tolerance,
572
allow_eq,
573
)?;
574
575
let mut drop_these = right_by.get_column_names();
576
if coalesce && left_asof_name == right_asof_name {
577
drop_these.push(right_asof_name);
578
}
579
580
let cols = other_df
581
.columns()
582
.iter()
583
.filter(|s| !drop_these.contains(&s.name()))
584
.cloned()
585
.collect();
586
let proj_other_df = unsafe { DataFrame::new_unchecked(other_df.height(), cols) };
587
588
let left = self_df.clone();
589
590
// SAFETY: join tuples are in bounds.
591
let right_df = unsafe {
592
proj_other_df.take_unchecked(&IdxCa::with_chunk(PlSmallStr::EMPTY, right_join_tuples))
593
};
594
595
_finish_join(left, right_df, suffix)
596
}
597
598
/// This is similar to a left-join except that we match on nearest key
599
/// rather than equal keys. The keys must be sorted to perform an asof join.
600
/// This is a special implementation of an asof join that searches for the
601
/// nearest keys within a subgroup set by `by`.
602
#[allow(clippy::too_many_arguments)]
603
fn join_asof_by<I, S>(
604
&self,
605
other: &DataFrame,
606
left_on: &str,
607
right_on: &str,
608
left_by: I,
609
right_by: I,
610
strategy: AsofStrategy,
611
tolerance: Option<AnyValue<'static>>,
612
allow_eq: bool,
613
check_sortedness: bool,
614
) -> PolarsResult<DataFrame>
615
where
616
I: IntoIterator<Item = S>,
617
S: AsRef<str>,
618
{
619
let self_df = self.to_df();
620
let left_by = left_by.into_iter().map(|s| s.as_ref().into()).collect();
621
let right_by = right_by.into_iter().map(|s| s.as_ref().into()).collect();
622
let left_key = self_df.column(left_on)?.as_materialized_series();
623
let right_key = other.column(right_on)?.as_materialized_series();
624
self_df._join_asof_by(
625
other,
626
left_key,
627
right_key,
628
left_by,
629
right_by,
630
strategy,
631
tolerance,
632
None,
633
None,
634
true,
635
allow_eq,
636
check_sortedness,
637
)
638
}
639
}
640
641
// `DataFrame` relies entirely on the trait's provided methods, so the impl body is empty.
impl AsofJoinBy for DataFrame {}
642
643
#[cfg(test)]
mod test {
    use super::*;

    /// Collects the `i32` column `name` of `df` into optional values.
    fn i32_values(df: &DataFrame, name: &str) -> Vec<Option<i32>> {
        Vec::from(df.column(name).unwrap().i32().unwrap())
    }

    /// Collects the `f64` column `name` of `df` into optional values.
    fn f64_values(df: &DataFrame, name: &str) -> Vec<Option<f64>> {
        Vec::from(df.column(name).unwrap().f64().unwrap())
    }

    /// Left-hand frame shared by the trades/quotes tests.
    fn trades_frame() -> PolarsResult<DataFrame> {
        df![
            "time" => [23i64, 38, 48, 48, 48],
            "ticker" => ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"],
            "groups_numeric" => [1, 1, 2, 2, 3],
            "bid" => [51.95, 51.95, 720.77, 720.92, 98.0]
        ]
    }

    /// Joins `trades` with `quotes` on `time`, grouped by the column `by`,
    /// and returns the resulting `bid_right` values.
    fn joined_bids(
        trades: &DataFrame,
        quotes: &DataFrame,
        by: &str,
        strategy: AsofStrategy,
    ) -> PolarsResult<Vec<Option<f64>>> {
        let joined = trades.join_asof_by(
            quotes,
            "time",
            "time",
            [by],
            [by],
            strategy,
            None,
            true,
            true,
        )?;
        Ok(f64_values(&joined, "bid_right"))
    }

    #[test]
    fn test_asof_by() -> PolarsResult<()> {
        let left = df![
            "a" => [-1, 2, 3, 3, 3, 4],
            "b" => ["a", "b", "c", "d", "e", "f"]
        ]?;

        let right = df![
            "a" => [1, 2, 3, 3],
            "b" => ["a", "b", "c", "d"],
            "right_vals" => [1, 2, 3, 4]
        ]?;

        let joined = left.join_asof_by(
            &right,
            "a",
            "a",
            ["b"],
            ["b"],
            AsofStrategy::Backward,
            None,
            true,
            true,
        )?;
        assert_eq!(joined.get_column_names(), &["a", "b", "right_vals"]);
        assert_eq!(
            i32_values(&joined, "right_vals"),
            &[None, Some(2), Some(3), Some(4), None, None]
        );
        Ok(())
    }

    #[test]
    fn test_asof_by2() -> PolarsResult<()> {
        let trades = trades_frame()?;

        let quotes = df![
            "time" => [23i64, 23, 30, 41, 48, 49, 72, 75],
            "ticker" => ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"],
            "groups_numeric" => [2, 1, 1, 1, 2, 3, 2, 1],
            "bid" => [720.5, 51.95, 51.97, 51.99, 720.5, 97.99, 720.5, 52.01]
        ]?;

        let expected = vec![Some(51.95), Some(51.97), Some(720.5), Some(720.5), None];

        // The string key and the numeric key describe the same groups.
        for by in ["ticker", "groups_numeric"] {
            let bids = joined_bids(&trades, &quotes, by, AsofStrategy::Backward)?;
            assert_eq!(bids, expected);
        }

        Ok(())
    }

    #[test]
    fn test_asof_by3() -> PolarsResult<()> {
        let left = df![
            "a" => [ -1,   2,   2,   3,   3,   3,   4],
            "b" => ["a", "a", "b", "c", "d", "e", "f"]
        ]?;

        let right = df![
            "a" => [  1,   3,   2,   3,   2],
            "b" => ["a", "a", "b", "c", "d"],
            "right_vals" => [ 1, 3, 2, 3, 4]
        ]?;

        // Without tolerance.
        let joined = left.join_asof_by(
            &right,
            "a",
            "a",
            ["b"],
            ["b"],
            AsofStrategy::Forward,
            None,
            true,
            true,
        )?;
        assert_eq!(joined.get_column_names(), &["a", "b", "right_vals"]);
        assert_eq!(
            i32_values(&joined, "right_vals"),
            &[Some(1), Some(3), Some(2), Some(3), None, None, None]
        );

        // With a tolerance of 1.
        let joined = left.join_asof_by(
            &right,
            "a",
            "a",
            ["b"],
            ["b"],
            AsofStrategy::Forward,
            Some(AnyValue::Int32(1)),
            true,
            true,
        )?;
        assert_eq!(joined.get_column_names(), &["a", "b", "right_vals"]);
        assert_eq!(
            i32_values(&joined, "right_vals"),
            &[None, Some(3), Some(2), Some(3), None, None, None]
        );

        Ok(())
    }

    #[test]
    fn test_asof_by4() -> PolarsResult<()> {
        let trades = trades_frame()?;

        let quotes = df![
            "time" => [23i64, 23, 30, 41, 48, 49, 72, 75],
            "ticker" => ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"],
            "bid" => [720.5, 51.95, 51.97, 51.99, 720.5, 97.99, 720.5, 52.01],
            "groups_numeric" => [2, 1, 1, 1, 2, 3, 2, 1]
        ]?;
        /*
        trades:
        shape: (5, 4)
        ┌──────┬────────┬────────────────┬────────┐
        │ time ┆ ticker ┆ groups_numeric ┆ bid    │
        │ ---  ┆ ---    ┆ ---            ┆ ---    │
        │ i64  ┆ str    ┆ i32            ┆ f64    │
        ╞══════╪════════╪════════════════╪════════╡
        │ 23   ┆ MSFT   ┆ 1              ┆ 51.95  │
        │ 38   ┆ MSFT   ┆ 1              ┆ 51.95  │
        │ 48   ┆ GOOG   ┆ 2              ┆ 720.77 │
        │ 48   ┆ GOOG   ┆ 2              ┆ 720.92 │
        │ 48   ┆ AAPL   ┆ 3              ┆ 98.0   │
        └──────┴────────┴────────────────┴────────┘
        quotes:
        shape: (8, 4)
        ┌──────┬────────┬───────┬────────────────┐
        │ time ┆ ticker ┆ bid   ┆ groups_numeric │
        │ ---  ┆ ---    ┆ ---   ┆ ---            │
        │ i64  ┆ str    ┆ f64   ┆ i32            │
        ╞══════╪════════╪═══════╪════════════════╡
        │ 23   ┆ GOOG   ┆ 720.5 ┆ 2              │
        │ 23   ┆ MSFT   ┆ 51.95 ┆ 1              │
        │ 30   ┆ MSFT   ┆ 51.97 ┆ 1              │
        │ 41   ┆ MSFT   ┆ 51.99 ┆ 1              │
        │ 48   ┆ GOOG   ┆ 720.5 ┆ 2              │
        │ 49   ┆ AAPL   ┆ 97.99 ┆ 3              │
        │ 72   ┆ GOOG   ┆ 720.5 ┆ 2              │
        │ 75   ┆ MSFT   ┆ 52.01 ┆ 1              │
        └──────┴────────┴───────┴────────────────┘
        */

        let expected = vec![
            Some(51.95),
            Some(51.99),
            Some(720.5),
            Some(720.5),
            Some(97.99),
        ];

        // The string key and the numeric key describe the same groups.
        for by in ["ticker", "groups_numeric"] {
            let bids = joined_bids(&trades, &quotes, by, AsofStrategy::Forward)?;
            assert_eq!(bids, expected);
        }

        Ok(())
    }
}
885
886