Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/asof/groups.rs
6940 views
1
use std::hash::Hash;
2
3
use num_traits::Zero;
4
use polars_core::hashing::_HASHMAP_INIT_SIZE;
5
use polars_core::prelude::*;
6
use polars_core::series::BitRepr;
7
use polars_core::utils::flatten::flatten_nullable;
8
use polars_core::utils::split_and_flatten;
9
use polars_core::{POOL, with_match_physical_float_polars_type};
10
use polars_utils::abs_diff::AbsDiff;
11
use polars_utils::hashing::{DirtyHash, hash_to_partition};
12
use polars_utils::nulls::IsNull;
13
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
14
use rayon::prelude::*;
15
16
use super::*;
17
use crate::frame::join::{prepare_binary, prepare_keys_multiple};
18
19
/// Returns the exclusive prefix sum of the given lengths.
///
/// Element `i` of the result is the sum of all lengths that come before
/// position `i`, i.e. the offset at which piece `i` starts when the pieces are
/// laid out back to back. The total length itself is not appended, so the
/// output has exactly as many elements as the input.
fn compute_len_offsets<I: IntoIterator<Item = usize>>(iter: I) -> Vec<usize> {
    let lens = iter.into_iter();
    let mut offsets = Vec::with_capacity(lens.size_hint().0);
    let mut running_total = 0;
    for len in lens {
        // Record where this piece starts before accounting for its length.
        offsets.push(running_total);
        running_total += len;
    }
    offsets
}
29
30
#[inline(always)]
31
fn materialize_nullable(idx: Option<IdxSize>) -> NullableIdxSize {
32
match idx {
33
Some(t) => NullableIdxSize::from(t),
34
None => NullableIdxSize::null(),
35
}
36
}
37
38
fn asof_in_group<'a, T, A, F>(
39
left_val: T::Physical<'a>,
40
right_val_arr: &'a T::Array,
41
right_grp_idxs: &[IdxSize],
42
group_states: &mut PlHashMap<IdxSize, A>,
43
filter: F,
44
allow_eq: bool,
45
) -> Option<IdxSize>
46
where
47
T: PolarsDataType,
48
A: AsofJoinState<T::Physical<'a>>,
49
F: Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
50
{
51
// We use the index of the first element in a group as an identifier to
52
// associate with the group state.
53
let id = right_grp_idxs.first()?;
54
let grp_state = group_states.entry(*id).or_insert_with(|| A::new(allow_eq));
55
56
unsafe {
57
let r_grp_idx = grp_state.next(
58
&left_val,
59
|i| {
60
// SAFETY: the group indices are valid, and next() only calls with
61
// i < right_grp_idxs.len().
62
right_val_arr.get_unchecked(*right_grp_idxs.get_unchecked(i as usize) as usize)
63
},
64
right_grp_idxs.len() as IdxSize,
65
)?;
66
67
// SAFETY: r_grp_idx is valid, as is r_idx (which must be non-null) if
68
// we get here.
69
let r_idx = *right_grp_idxs.get_unchecked(r_grp_idx as usize);
70
let right_val = right_val_arr.value_unchecked(r_idx as usize);
71
filter(left_val, right_val).then_some(r_idx)
72
}
73
}
74
75
fn asof_join_by_numeric<T, S, A, F>(
76
by_left: &ChunkedArray<S>,
77
by_right: &ChunkedArray<S>,
78
left_asof: &ChunkedArray<T>,
79
right_asof: &ChunkedArray<T>,
80
filter: F,
81
allow_eq: bool,
82
) -> PolarsResult<IdxArr>
83
where
84
T: PolarsDataType,
85
S: PolarsNumericType,
86
S::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
87
<S::Native as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
88
A: for<'a> AsofJoinState<T::Physical<'a>>,
89
F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
90
{
91
let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
92
let left_val_arr = left_asof.downcast_as_array();
93
let right_val_arr = right_asof.downcast_as_array();
94
95
let n_threads = POOL.current_num_threads();
96
// `strict` is false so that we always flatten. Even if there are more chunks than threads.
97
let split_by_left = split_and_flatten(by_left, n_threads);
98
let split_by_right = split_and_flatten(by_right, n_threads);
99
let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len()));
100
101
// TODO: handle nulls more efficiently. Right now we just join on the value
102
// ignoring the validity mask, and ignore the nulls later.
103
let right_slices = split_by_right
104
.iter()
105
.map(|ca| {
106
assert_eq!(ca.chunks().len(), 1);
107
ca.downcast_iter().next().unwrap().values_iter().copied()
108
})
109
.collect();
110
let hash_tbls = build_tables(right_slices, false);
111
let n_tables = hash_tbls.len();
112
113
// Now we probe the right hand side for each left hand side.
114
let out = split_by_left
115
.into_par_iter()
116
.zip(offsets)
117
.map(|(by_left, offset)| {
118
let mut results = Vec::with_capacity(by_left.len());
119
let mut group_states: PlHashMap<IdxSize, A> =
120
PlHashMap::with_capacity(_HASHMAP_INIT_SIZE);
121
122
assert_eq!(by_left.chunks().len(), 1);
123
let by_left_chunk = by_left.downcast_iter().next().unwrap();
124
for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() {
125
let Some(by_left_k) = opt_by_left_k else {
126
results.push(NullableIdxSize::null());
127
continue;
128
};
129
let by_left_k = by_left_k.to_total_ord();
130
let idx_left = (rel_idx_left + offset) as IdxSize;
131
let Some(left_val) = left_val_arr.get(idx_left as usize) else {
132
results.push(NullableIdxSize::null());
133
continue;
134
};
135
136
let group_probe_table = unsafe {
137
hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables))
138
};
139
let Some(right_grp_idxs) = group_probe_table.get(&by_left_k) else {
140
results.push(NullableIdxSize::null());
141
continue;
142
};
143
let id = asof_in_group::<T, A, &F>(
144
left_val,
145
right_val_arr,
146
right_grp_idxs.as_slice(),
147
&mut group_states,
148
&filter,
149
allow_eq,
150
);
151
results.push(materialize_nullable(id));
152
}
153
results
154
});
155
156
let bufs = POOL.install(|| out.collect::<Vec<_>>());
157
Ok(flatten_nullable(&bufs))
158
}
159
160
fn asof_join_by_binary<B, T, A, F>(
161
by_left: &ChunkedArray<B>,
162
by_right: &ChunkedArray<B>,
163
left_asof: &ChunkedArray<T>,
164
right_asof: &ChunkedArray<T>,
165
filter: F,
166
allow_eq: bool,
167
) -> IdxArr
168
where
169
B: PolarsDataType,
170
for<'b> <B::Array as StaticArray>::ValueT<'b>: AsRef<[u8]>,
171
T: PolarsDataType,
172
A: for<'a> AsofJoinState<T::Physical<'a>>,
173
F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
174
{
175
let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
176
let left_val_arr = left_asof.downcast_as_array();
177
let right_val_arr = right_asof.downcast_as_array();
178
179
let (prep_by_left, prep_by_right, _, _) = prepare_binary::<B>(by_left, by_right, false);
180
let offsets = compute_len_offsets(prep_by_left.iter().map(|s| s.len()));
181
let hash_tbls = build_tables(prep_by_right, false);
182
let n_tables = hash_tbls.len();
183
184
// Now we probe the right hand side for each left hand side.
185
let iter = prep_by_left
186
.into_par_iter()
187
.zip(offsets)
188
.map(|(by_left, offset)| {
189
let mut results = Vec::with_capacity(by_left.len());
190
let mut group_states: PlHashMap<_, A> = PlHashMap::with_capacity(_HASHMAP_INIT_SIZE);
191
192
for (rel_idx_left, by_left_k) in by_left.iter().enumerate() {
193
let idx_left = (rel_idx_left + offset) as IdxSize;
194
let Some(left_val) = left_val_arr.get(idx_left as usize) else {
195
results.push(NullableIdxSize::null());
196
continue;
197
};
198
199
let group_probe_table = unsafe {
200
hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables))
201
};
202
let Some(right_grp_idxs) = group_probe_table.get(by_left_k) else {
203
results.push(NullableIdxSize::null());
204
continue;
205
};
206
let id = asof_in_group::<T, A, &F>(
207
left_val,
208
right_val_arr,
209
right_grp_idxs.as_slice(),
210
&mut group_states,
211
&filter,
212
allow_eq,
213
);
214
215
results.push(materialize_nullable(id));
216
}
217
results
218
});
219
let bufs = POOL.install(|| iter.collect::<Vec<_>>());
220
flatten_nullable(&bufs)
221
}
222
223
#[allow(clippy::too_many_arguments)]
224
fn dispatch_join_by_type<T, A, F>(
225
left_asof: &ChunkedArray<T>,
226
right_asof: &ChunkedArray<T>,
227
left_by: &mut DataFrame,
228
right_by: &mut DataFrame,
229
filter: F,
230
allow_eq: bool,
231
) -> PolarsResult<IdxArr>
232
where
233
T: PolarsDataType,
234
A: for<'a> AsofJoinState<T::Physical<'a>>,
235
F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
236
{
237
let out = if left_by.width() == 1 {
238
let left_by_s = left_by.get_columns()[0].to_physical_repr();
239
let right_by_s = right_by.get_columns()[0].to_physical_repr();
240
let left_dtype = left_by_s.dtype();
241
let right_dtype = right_by_s.dtype();
242
polars_ensure!(left_dtype == right_dtype,
243
ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{left_dtype}` and `{right_dtype}`",
244
);
245
match left_dtype {
246
DataType::String => {
247
let left_by = &left_by_s.str().unwrap().as_binary();
248
let right_by = right_by_s.str().unwrap().as_binary();
249
asof_join_by_binary::<BinaryType, T, A, F>(
250
left_by, &right_by, left_asof, right_asof, filter, allow_eq,
251
)
252
},
253
DataType::Binary => {
254
let left_by = &left_by_s.binary().unwrap();
255
let right_by = right_by_s.binary().unwrap();
256
asof_join_by_binary::<BinaryType, T, A, F>(
257
left_by, right_by, left_asof, right_asof, filter, allow_eq,
258
)
259
},
260
x if x.is_float() => {
261
with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| {
262
let left_by: &ChunkedArray<$T> = left_by_s.as_materialized_series().as_ref().as_ref().as_ref();
263
let right_by: &ChunkedArray<$T> = right_by_s.as_materialized_series().as_ref().as_ref().as_ref();
264
asof_join_by_numeric::<T, $T, A, F>(
265
left_by, right_by, left_asof, right_asof, filter, allow_eq
266
)?
267
})
268
},
269
_ => {
270
let left_by = left_by_s.bit_repr();
271
let right_by = right_by_s.bit_repr();
272
273
let (Some(left_by), Some(right_by)) = (left_by, right_by) else {
274
polars_bail!(nyi = "Dispatch join for {left_dtype} and {right_dtype}");
275
};
276
277
use BitRepr as B;
278
match (left_by, right_by) {
279
(B::U32(left_by), B::U32(right_by)) => {
280
asof_join_by_numeric::<T, UInt32Type, A, F>(
281
&left_by, &right_by, left_asof, right_asof, filter, allow_eq,
282
)?
283
},
284
(B::U64(left_by), B::U64(right_by)) => {
285
asof_join_by_numeric::<T, UInt64Type, A, F>(
286
&left_by, &right_by, left_asof, right_asof, filter, allow_eq,
287
)?
288
},
289
#[cfg(feature = "dtype-i128")]
290
(B::I128(left_by), B::I128(right_by)) => {
291
asof_join_by_numeric::<T, Int128Type, A, F>(
292
&left_by, &right_by, left_asof, right_asof, filter, allow_eq,
293
)?
294
},
295
// We have already asserted that the datatypes are the same.
296
_ => unreachable!(),
297
}
298
},
299
}
300
} else {
301
for (lhs, rhs) in left_by.get_columns().iter().zip(right_by.get_columns()) {
302
polars_ensure!(lhs.dtype() == rhs.dtype(),
303
ComputeError: "mismatching dtypes in 'by' parameter of asof-join: `{}` and `{}`", lhs.dtype(), rhs.dtype()
304
);
305
}
306
307
// TODO: @scalar-opt.
308
let left_by_series: Vec<_> = left_by.materialized_column_iter().cloned().collect();
309
let right_by_series: Vec<_> = right_by.materialized_column_iter().cloned().collect();
310
let lhs_keys = prepare_keys_multiple(&left_by_series, false)?;
311
let rhs_keys = prepare_keys_multiple(&right_by_series, false)?;
312
asof_join_by_binary::<BinaryOffsetType, T, A, F>(
313
&lhs_keys, &rhs_keys, left_asof, right_asof, filter, allow_eq,
314
)
315
};
316
Ok(out)
317
}
318
319
#[allow(clippy::too_many_arguments)]
320
fn dispatch_join_strategy<T: PolarsDataType>(
321
left_asof: &ChunkedArray<T>,
322
right_asof: &Series,
323
left_by: &mut DataFrame,
324
right_by: &mut DataFrame,
325
strategy: AsofStrategy,
326
allow_eq: bool,
327
) -> PolarsResult<IdxArr>
328
where
329
for<'a> T::Physical<'a>: PartialOrd,
330
{
331
let right_asof = left_asof.unpack_series_matching_type(right_asof)?;
332
333
let filter = |_a: T::Physical<'_>, _b: T::Physical<'_>| true;
334
match strategy {
335
AsofStrategy::Backward => dispatch_join_by_type::<T, AsofJoinBackwardState, _>(
336
left_asof, right_asof, left_by, right_by, filter, allow_eq,
337
),
338
AsofStrategy::Forward => dispatch_join_by_type::<T, AsofJoinForwardState, _>(
339
left_asof, right_asof, left_by, right_by, filter, allow_eq,
340
),
341
AsofStrategy::Nearest => unimplemented!(),
342
}
343
}
344
345
#[allow(clippy::too_many_arguments)]
346
fn dispatch_join_strategy_numeric<T: PolarsNumericType>(
347
left_asof: &ChunkedArray<T>,
348
right_asof: &Series,
349
left_by: &mut DataFrame,
350
right_by: &mut DataFrame,
351
strategy: AsofStrategy,
352
tolerance: Option<AnyValue<'static>>,
353
allow_eq: bool,
354
) -> PolarsResult<IdxArr> {
355
let right_ca = left_asof.unpack_series_matching_type(right_asof)?;
356
357
if let Some(tol) = tolerance {
358
let native_tolerance: T::Native = tol.try_extract()?;
359
let abs_tolerance = native_tolerance.abs_diff(T::Native::zero());
360
let filter = |a: T::Native, b: T::Native| a.abs_diff(b) <= abs_tolerance;
361
match strategy {
362
AsofStrategy::Backward => dispatch_join_by_type::<T, AsofJoinBackwardState, _>(
363
left_asof, right_ca, left_by, right_by, filter, allow_eq,
364
),
365
AsofStrategy::Forward => dispatch_join_by_type::<T, AsofJoinForwardState, _>(
366
left_asof, right_ca, left_by, right_by, filter, allow_eq,
367
),
368
AsofStrategy::Nearest => dispatch_join_by_type::<T, AsofJoinNearestState, _>(
369
left_asof, right_ca, left_by, right_by, filter, allow_eq,
370
),
371
}
372
} else {
373
let filter = |_a: T::Physical<'_>, _b: T::Physical<'_>| true;
374
match strategy {
375
AsofStrategy::Backward => dispatch_join_by_type::<T, AsofJoinBackwardState, _>(
376
left_asof, right_ca, left_by, right_by, filter, allow_eq,
377
),
378
AsofStrategy::Forward => dispatch_join_by_type::<T, AsofJoinForwardState, _>(
379
left_asof, right_ca, left_by, right_by, filter, allow_eq,
380
),
381
AsofStrategy::Nearest => dispatch_join_by_type::<T, AsofJoinNearestState, _>(
382
left_asof, right_ca, left_by, right_by, filter, allow_eq,
383
),
384
}
385
}
386
}
387
388
#[allow(clippy::too_many_arguments)]
389
fn dispatch_join_type(
390
left_asof: &Series,
391
right_asof: &Series,
392
left_by: &mut DataFrame,
393
right_by: &mut DataFrame,
394
strategy: AsofStrategy,
395
tolerance: Option<AnyValue<'static>>,
396
allow_eq: bool,
397
) -> PolarsResult<IdxArr> {
398
match left_asof.dtype() {
399
DataType::Int64 => {
400
let ca = left_asof.i64().unwrap();
401
dispatch_join_strategy_numeric(
402
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
403
)
404
},
405
DataType::Int32 => {
406
let ca = left_asof.i32().unwrap();
407
dispatch_join_strategy_numeric(
408
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
409
)
410
},
411
DataType::UInt64 => {
412
let ca = left_asof.u64().unwrap();
413
dispatch_join_strategy_numeric(
414
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
415
)
416
},
417
DataType::UInt32 => {
418
let ca = left_asof.u32().unwrap();
419
dispatch_join_strategy_numeric(
420
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
421
)
422
},
423
#[cfg(feature = "dtype-i128")]
424
DataType::Int128 => {
425
let ca = left_asof.i128().unwrap();
426
dispatch_join_strategy_numeric(
427
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
428
)
429
},
430
DataType::Float32 => {
431
let ca = left_asof.f32().unwrap();
432
dispatch_join_strategy_numeric(
433
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
434
)
435
},
436
DataType::Float64 => {
437
let ca = left_asof.f64().unwrap();
438
dispatch_join_strategy_numeric(
439
ca, right_asof, left_by, right_by, strategy, tolerance, allow_eq,
440
)
441
},
442
DataType::Boolean => {
443
let ca = left_asof.bool().unwrap();
444
dispatch_join_strategy::<BooleanType>(
445
ca, right_asof, left_by, right_by, strategy, allow_eq,
446
)
447
},
448
DataType::Binary => {
449
let ca = left_asof.binary().unwrap();
450
dispatch_join_strategy::<BinaryType>(
451
ca, right_asof, left_by, right_by, strategy, allow_eq,
452
)
453
},
454
DataType::String => {
455
let ca = left_asof.str().unwrap();
456
let right_binary = right_asof.cast(&DataType::Binary).unwrap();
457
dispatch_join_strategy::<BinaryType>(
458
&ca.as_binary(),
459
&right_binary,
460
left_by,
461
right_by,
462
strategy,
463
allow_eq,
464
)
465
},
466
DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 => {
467
let left_asof = left_asof.cast(&DataType::Int32).unwrap();
468
let right_asof = right_asof.cast(&DataType::Int32).unwrap();
469
let ca = left_asof.i32().unwrap();
470
dispatch_join_strategy_numeric(
471
ca,
472
&right_asof,
473
left_by,
474
right_by,
475
strategy,
476
tolerance,
477
allow_eq,
478
)
479
},
480
dt => polars_bail!(opq = asof_join, dt),
481
}
482
}
483
484
pub trait AsofJoinBy: IntoDf {
485
#[allow(clippy::too_many_arguments)]
486
#[doc(hidden)]
487
fn _join_asof_by(
488
&self,
489
other: &DataFrame,
490
left_on: &Series,
491
right_on: &Series,
492
left_by: Vec<PlSmallStr>,
493
right_by: Vec<PlSmallStr>,
494
strategy: AsofStrategy,
495
tolerance: Option<AnyValue<'static>>,
496
suffix: Option<PlSmallStr>,
497
slice: Option<(i64, usize)>,
498
coalesce: bool,
499
allow_eq: bool,
500
check_sortedness: bool,
501
) -> PolarsResult<DataFrame> {
502
let (self_sliced_slot, left_slice_s); // Keeps temporaries alive.
503
let (self_df, other_df, left_key, right_key);
504
if let Some((offset, len)) = slice {
505
self_sliced_slot = self.to_df().slice(offset, len);
506
left_slice_s = left_on.slice(offset, len);
507
left_key = &left_slice_s;
508
right_key = right_on;
509
self_df = &self_sliced_slot;
510
other_df = other;
511
} else {
512
self_df = self.to_df();
513
other_df = other;
514
left_key = left_on;
515
right_key = right_on;
516
}
517
518
let left_asof = left_key.to_physical_repr();
519
let right_asof = right_key.to_physical_repr();
520
let right_asof_name = right_asof.name();
521
let left_asof_name = left_asof.name();
522
check_asof_columns(
523
&left_asof,
524
&right_asof,
525
tolerance.is_some(),
526
check_sortedness,
527
!(left_by.is_empty() && right_by.is_empty()),
528
)?;
529
530
let mut left_by = self_df.select(left_by)?;
531
let mut right_by = other_df.select(right_by)?;
532
533
unsafe {
534
for (l, r) in left_by
535
.get_columns_mut()
536
.iter_mut()
537
.zip(right_by.get_columns_mut().iter_mut())
538
{
539
*l = l.to_physical_repr();
540
*r = r.to_physical_repr();
541
}
542
}
543
544
let right_join_tuples = dispatch_join_type(
545
&left_asof,
546
&right_asof,
547
&mut left_by,
548
&mut right_by,
549
strategy,
550
tolerance,
551
allow_eq,
552
)?;
553
554
let mut drop_these = right_by.get_column_names();
555
if coalesce && left_asof_name == right_asof_name {
556
drop_these.push(right_asof_name);
557
}
558
559
let cols = other_df
560
.get_columns()
561
.iter()
562
.filter(|s| !drop_these.contains(&s.name()))
563
.cloned()
564
.collect();
565
let proj_other_df = unsafe { DataFrame::new_no_checks(other_df.height(), cols) };
566
567
let left = self_df.clone();
568
569
// SAFETY: join tuples are in bounds.
570
let right_df = unsafe {
571
proj_other_df.take_unchecked(&IdxCa::with_chunk(PlSmallStr::EMPTY, right_join_tuples))
572
};
573
574
_finish_join(left, right_df, suffix)
575
}
576
577
/// This is similar to a left-join except that we match on nearest key
578
/// rather than equal keys. The keys must be sorted to perform an asof join.
579
/// This is a special implementation of an asof join that searches for the
580
/// nearest keys within a subgroup set by `by`.
581
#[allow(clippy::too_many_arguments)]
582
fn join_asof_by<I, S>(
583
&self,
584
other: &DataFrame,
585
left_on: &str,
586
right_on: &str,
587
left_by: I,
588
right_by: I,
589
strategy: AsofStrategy,
590
tolerance: Option<AnyValue<'static>>,
591
allow_eq: bool,
592
check_sortedness: bool,
593
) -> PolarsResult<DataFrame>
594
where
595
I: IntoIterator<Item = S>,
596
S: AsRef<str>,
597
{
598
let self_df = self.to_df();
599
let left_by = left_by.into_iter().map(|s| s.as_ref().into()).collect();
600
let right_by = right_by.into_iter().map(|s| s.as_ref().into()).collect();
601
let left_key = self_df.column(left_on)?.as_materialized_series();
602
let right_key = other.column(right_on)?.as_materialized_series();
603
self_df._join_asof_by(
604
other,
605
left_key,
606
right_key,
607
left_by,
608
right_by,
609
strategy,
610
tolerance,
611
None,
612
None,
613
true,
614
allow_eq,
615
check_sortedness,
616
)
617
}
618
}
619
620
// `DataFrame` uses the trait's default method implementations as-is; nothing
// is overridden here.
impl AsofJoinBy for DataFrame {}
621
622
#[cfg(test)]
mod test {
    use super::*;

    /// Trades fixture shared by the ticker / numeric-group tests.
    fn trades_frame() -> PolarsResult<DataFrame> {
        df![
            "time" => [23i64, 38, 48, 48, 48],
            "ticker" => ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"],
            "groups_numeric" => [1, 1, 2, 2, 3],
            "bid" => [51.95, 51.95, 720.77, 720.92, 98.0]
        ]
    }

    /// Extracts the `bid_right` column of a join result as optional floats.
    fn bid_right(joined: &DataFrame) -> Vec<Option<f64>> {
        let col = joined.column("bid_right").unwrap();
        Vec::from(col.f64().unwrap())
    }

    /// Extracts the `right_vals` column of a join result as optional ints.
    fn right_vals(joined: &DataFrame) -> Vec<Option<i32>> {
        let col = joined.column("right_vals").unwrap();
        Vec::from(col.i32().unwrap())
    }

    /// Backward strategy, grouped by a string column.
    #[test]
    fn test_asof_by() -> PolarsResult<()> {
        let left = df![
            "a" => [-1, 2, 3, 3, 3, 4],
            "b" => ["a", "b", "c", "d", "e", "f"]
        ]?;

        let right = df![
            "a" => [1, 2, 3, 3],
            "b" => ["a", "b", "c", "d"],
            "right_vals" => [1, 2, 3, 4]
        ]?;

        let joined = left.join_asof_by(
            &right,
            "a",
            "a",
            ["b"],
            ["b"],
            AsofStrategy::Backward,
            None,
            true,
            true,
        )?;
        assert_eq!(joined.get_column_names(), &["a", "b", "right_vals"]);
        assert_eq!(
            right_vals(&joined),
            &[None, Some(2), Some(3), Some(4), None, None]
        );
        Ok(())
    }

    /// Backward strategy; grouping by the string ticker and by the numeric
    /// group id must give the same result.
    #[test]
    fn test_asof_by2() -> PolarsResult<()> {
        let trades = trades_frame()?;

        let quotes = df![
            "time" => [23i64, 23, 30, 41, 48, 49, 72, 75],
            "ticker" => ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"],
            "groups_numeric" => [2, 1, 1, 1, 2, 3, 2, 1],
            "bid" => [720.5, 51.95, 51.97, 51.99, 720.5, 97.99, 720.5, 52.01]
        ]?;

        let expected = &[Some(51.95), Some(51.97), Some(720.5), Some(720.5), None];

        let by_ticker = trades.join_asof_by(
            &quotes,
            "time",
            "time",
            ["ticker"],
            ["ticker"],
            AsofStrategy::Backward,
            None,
            true,
            true,
        )?;
        assert_eq!(bid_right(&by_ticker), expected);

        let by_group_id = trades.join_asof_by(
            &quotes,
            "time",
            "time",
            ["groups_numeric"],
            ["groups_numeric"],
            AsofStrategy::Backward,
            None,
            true,
            true,
        )?;
        assert_eq!(bid_right(&by_group_id), expected);

        Ok(())
    }

    /// Forward strategy, first without and then with a tolerance of 1.
    #[test]
    fn test_asof_by3() -> PolarsResult<()> {
        let left = df![
            "a" => [-1, 2, 2, 3, 3, 3, 4],
            "b" => ["a", "a", "b", "c", "d", "e", "f"]
        ]?;

        let right = df![
            "a" => [1, 3, 2, 3, 2],
            "b" => ["a", "a", "b", "c", "d"],
            "right_vals" => [1, 3, 2, 3, 4]
        ]?;

        let unbounded = left.join_asof_by(
            &right,
            "a",
            "a",
            ["b"],
            ["b"],
            AsofStrategy::Forward,
            None,
            true,
            true,
        )?;
        assert_eq!(unbounded.get_column_names(), &["a", "b", "right_vals"]);
        assert_eq!(
            right_vals(&unbounded),
            &[Some(1), Some(3), Some(2), Some(3), None, None, None]
        );

        let bounded = left.join_asof_by(
            &right,
            "a",
            "a",
            ["b"],
            ["b"],
            AsofStrategy::Forward,
            Some(AnyValue::Int32(1)),
            true,
            true,
        )?;
        assert_eq!(bounded.get_column_names(), &["a", "b", "right_vals"]);
        assert_eq!(
            right_vals(&bounded),
            &[None, Some(3), Some(2), Some(3), None, None, None]
        );

        Ok(())
    }

    /// Forward strategy; grouping by the string ticker and by the numeric
    /// group id must give the same result.
    #[test]
    fn test_asof_by4() -> PolarsResult<()> {
        let trades = trades_frame()?;

        let quotes = df![
            "time" => [23i64, 23, 30, 41, 48, 49, 72, 75],
            "ticker" => ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"],
            "bid" => [720.5, 51.95, 51.97, 51.99, 720.5, 97.99, 720.5, 52.01],
            "groups_numeric" => [2, 1, 1, 1, 2, 3, 2, 1],
        ]?;
        /*
        trades:
        shape: (5, 4)
        ┌──────┬────────┬────────────────┬────────┐
        │ time ┆ ticker ┆ groups_numeric ┆ bid    │
        │ ---  ┆ ---    ┆ ---            ┆ ---    │
        │ i64  ┆ str    ┆ i32            ┆ f64    │
        ╞══════╪════════╪════════════════╪════════╡
        │ 23   ┆ MSFT   ┆ 1              ┆ 51.95  │
        │ 38   ┆ MSFT   ┆ 1              ┆ 51.95  │
        │ 48   ┆ GOOG   ┆ 2              ┆ 720.77 │
        │ 48   ┆ GOOG   ┆ 2              ┆ 720.92 │
        │ 48   ┆ AAPL   ┆ 3              ┆ 98.0   │
        └──────┴────────┴────────────────┴────────┘
        quotes:
        shape: (8, 4)
        ┌──────┬────────┬───────┬────────────────┐
        │ time ┆ ticker ┆ bid   ┆ groups_numeric │
        │ ---  ┆ ---    ┆ ---   ┆ ---            │
        │ i64  ┆ str    ┆ f64   ┆ i32            │
        ╞══════╪════════╪═══════╪════════════════╡
        │ 23   ┆ GOOG   ┆ 720.5 ┆ 2              │
        │ 23   ┆ MSFT   ┆ 51.95 ┆ 1              │
        │ 30   ┆ MSFT   ┆ 51.97 ┆ 1              │
        │ 41   ┆ MSFT   ┆ 51.99 ┆ 1              │
        │ 48   ┆ GOOG   ┆ 720.5 ┆ 2              │
        │ 49   ┆ AAPL   ┆ 97.99 ┆ 3              │
        │ 72   ┆ GOOG   ┆ 720.5 ┆ 2              │
        │ 75   ┆ MSFT   ┆ 52.01 ┆ 1              │
        └──────┴────────┴───────┴────────────────┘
        */

        let expected = &[
            Some(51.95),
            Some(51.99),
            Some(720.5),
            Some(720.5),
            Some(97.99),
        ];

        let by_ticker = trades.join_asof_by(
            &quotes,
            "time",
            "time",
            ["ticker"],
            ["ticker"],
            AsofStrategy::Forward,
            None,
            true,
            true,
        )?;
        assert_eq!(bid_right(&by_ticker), expected);

        let by_group_id = trades.join_asof_by(
            &quotes,
            "time",
            "time",
            ["groups_numeric"],
            ["groups_numeric"],
            AsofStrategy::Forward,
            None,
            true,
            true,
        )?;
        assert_eq!(bid_right(&by_group_id), expected);

        Ok(())
    }
}
864
865