Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/asof/mod.rs
8458 views
1
mod default;
2
mod groups;
3
use std::borrow::Cow;
4
use std::cmp::Ordering;
5
6
use default::*;
7
pub use groups::AsofJoinBy;
8
use polars_core::prelude::*;
9
use polars_utils::pl_str::PlSmallStr;
10
use polars_utils::total_ord::TotalOrd;
11
#[cfg(feature = "serde")]
12
use serde::{Deserialize, Serialize};
13
14
use super::{_finish_join, build_tables};
15
use crate::frame::IntoDf;
16
use crate::series::SeriesMethods;
17
18
#[inline]
19
fn ge_allow_eq<T: TotalOrd>(l: &T, r: &T, allow_eq: bool) -> bool {
20
match l.tot_cmp(r) {
21
Ordering::Equal => allow_eq,
22
Ordering::Greater => true,
23
Ordering::Less => false,
24
}
25
}
26
27
#[inline]
28
fn lt_allow_eq<T: TotalOrd>(l: &T, r: &T, allow_eq: bool) -> bool {
29
match l.tot_cmp(r) {
30
Ordering::Equal => allow_eq,
31
Ordering::Less => true,
32
Ordering::Greater => false,
33
}
34
}
35
36
trait AsofJoinState<T> {
37
fn next<F: FnMut(IdxSize) -> Option<T>>(
38
&mut self,
39
left_val: &T,
40
right: F,
41
n_right: IdxSize,
42
) -> Option<IdxSize>;
43
44
fn new(allow_eq: bool) -> Self;
45
}
46
47
struct AsofJoinForwardState {
48
scan_offset: IdxSize,
49
allow_eq: bool,
50
}
51
52
impl<T: TotalOrd> AsofJoinState<T> for AsofJoinForwardState {
53
fn new(allow_eq: bool) -> Self {
54
AsofJoinForwardState {
55
scan_offset: Default::default(),
56
allow_eq,
57
}
58
}
59
#[inline]
60
fn next<F: FnMut(IdxSize) -> Option<T>>(
61
&mut self,
62
left_val: &T,
63
mut right: F,
64
n_right: IdxSize,
65
) -> Option<IdxSize> {
66
while (self.scan_offset) < n_right {
67
if let Some(right_val) = right(self.scan_offset) {
68
if ge_allow_eq(&right_val, left_val, self.allow_eq) {
69
return Some(self.scan_offset);
70
}
71
}
72
self.scan_offset += 1;
73
}
74
None
75
}
76
}
77
78
struct AsofJoinBackwardState {
79
// best_bound is the greatest right index <= left_val.
80
best_bound: Option<IdxSize>,
81
scan_offset: IdxSize,
82
allow_eq: bool,
83
}
84
85
impl<T: TotalOrd> AsofJoinState<T> for AsofJoinBackwardState {
86
fn new(allow_eq: bool) -> Self {
87
AsofJoinBackwardState {
88
scan_offset: Default::default(),
89
best_bound: Default::default(),
90
allow_eq,
91
}
92
}
93
#[inline]
94
fn next<F: FnMut(IdxSize) -> Option<T>>(
95
&mut self,
96
left_val: &T,
97
mut right: F,
98
n_right: IdxSize,
99
) -> Option<IdxSize> {
100
while self.scan_offset < n_right {
101
if let Some(right_val) = right(self.scan_offset) {
102
if lt_allow_eq(&right_val, left_val, self.allow_eq) {
103
self.best_bound = Some(self.scan_offset);
104
} else {
105
break;
106
}
107
}
108
self.scan_offset += 1;
109
}
110
self.best_bound
111
}
112
}
113
114
#[derive(Default)]
115
struct AsofJoinNearestState {
116
/// The last value that is strictly smaller than the current
117
/// left value.
118
strictly_smaller: Option<IdxSize>,
119
/// If `allow_eq == false`: the first value strictly greater than the
120
/// current left value.
121
/// If `allow_eq == true`: the last value of the first chunk of equal
122
/// values that are strictly greater than the current left value.
123
upper_candidate: IdxSize,
124
allow_eq: bool,
125
}
126
127
impl<T: NumericNative> AsofJoinState<T> for AsofJoinNearestState {
128
fn new(allow_eq: bool) -> Self {
129
AsofJoinNearestState {
130
allow_eq,
131
..Default::default()
132
}
133
}
134
#[inline]
135
fn next<F: FnMut(IdxSize) -> Option<T>>(
136
&mut self,
137
left_val: &T,
138
mut right: F,
139
n_right: IdxSize,
140
) -> Option<IdxSize> {
141
// Skipping ahead to the first value greater than left_val. This is
142
// cheaper than computing differences.
143
while self.upper_candidate < n_right {
144
let Some(scan_right_val) = right(self.upper_candidate) else {
145
self.upper_candidate += 1;
146
continue;
147
};
148
if scan_right_val > *left_val {
149
break;
150
}
151
self.upper_candidate += 1;
152
}
153
154
if self.allow_eq
155
&& self.upper_candidate > 0
156
&& right(self.upper_candidate - 1) == Some(*left_val)
157
{
158
return Some(self.upper_candidate - 1);
159
}
160
161
// It is possible there are later elements equal to our
162
// scan, so keep going on.
163
while self.upper_candidate + 1 < n_right
164
&& right(self.upper_candidate + 1) == right(self.upper_candidate)
165
{
166
self.upper_candidate += 1;
167
}
168
169
let mut cursor = self.strictly_smaller.unwrap_or(0);
170
while cursor < self.upper_candidate {
171
let Some(scan_right_val) = right(cursor) else {
172
cursor += 1;
173
continue;
174
};
175
if scan_right_val >= *left_val {
176
break;
177
}
178
self.strictly_smaller = Some(cursor);
179
cursor += 1;
180
}
181
182
let mut right_get = |idx: IdxSize| (idx < n_right).then(|| right(idx)).flatten();
183
let lower = self.strictly_smaller.and_then(&mut right_get);
184
let upper = right_get(self.upper_candidate);
185
match (lower, upper) {
186
(None, None) => None,
187
(Some(_), None) => self.strictly_smaller,
188
(None, Some(_)) => Some(self.upper_candidate),
189
(Some(lo), Some(hi)) => {
190
let lo_diff = left_val.abs_diff(lo);
191
let hi_diff = left_val.abs_diff(hi);
192
if hi_diff <= lo_diff {
193
Some(self.upper_candidate)
194
} else {
195
self.strictly_smaller
196
}
197
},
198
}
199
}
200
}
201
202
#[derive(Clone, Debug, PartialEq, Default, Hash)]
203
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
204
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
205
pub struct AsOfOptions {
206
pub strategy: AsofStrategy,
207
/// A tolerance in the same unit as the asof column
208
pub tolerance: Option<Scalar>,
209
/// A time duration specified as a string, for example:
210
/// - "5m"
211
/// - "2h15m"
212
/// - "1d6h"
213
pub tolerance_str: Option<PlSmallStr>,
214
pub left_by: Option<Vec<PlSmallStr>>,
215
pub right_by: Option<Vec<PlSmallStr>>,
216
/// Allow equal matches
217
pub allow_eq: bool,
218
pub check_sortedness: bool,
219
}
220
221
fn check_asof_columns(
222
a: &Series,
223
b: &Series,
224
has_tolerance: bool,
225
check_sortedness: bool,
226
by_groups_present: bool,
227
) -> PolarsResult<()> {
228
let dtype_a = a.dtype();
229
let dtype_b = b.dtype();
230
if has_tolerance {
231
polars_ensure!(
232
dtype_a.to_physical().is_primitive_numeric() && dtype_b.to_physical().is_primitive_numeric(),
233
InvalidOperation:
234
"asof join with tolerance is only supported on numeric/temporal keys"
235
);
236
} else {
237
polars_ensure!(
238
dtype_a.to_physical().is_primitive() && dtype_b.to_physical().is_primitive(),
239
InvalidOperation:
240
"asof join is only supported on primitive key types"
241
);
242
}
243
polars_ensure!(
244
dtype_a == dtype_b,
245
ComputeError: "mismatching key dtypes in asof-join: `{}` and `{}`",
246
a.dtype(), b.dtype()
247
);
248
if check_sortedness {
249
if by_groups_present {
250
polars_warn!("Sortedness of columns cannot be checked when 'by' groups provided");
251
} else {
252
a.ensure_sorted_arg("asof_join")?;
253
b.ensure_sorted_arg("asof_join")?;
254
}
255
}
256
Ok(())
257
}
258
259
/// Which neighbouring row of the right DataFrame an asof join selects.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum AsofStrategy {
    /// Selects the last row in the right DataFrame whose 'on' key is less
    /// than the left's key (or equal to it, when equal matches are allowed).
    #[default]
    Backward,
    /// Selects the first row in the right DataFrame whose 'on' key is greater
    /// than the left's key (or equal to it, when equal matches are allowed).
    Forward,
    /// Selects the row in the right DataFrame whose 'on' key is nearest to
    /// the left's key.
    Nearest,
}
271
272
pub trait AsofJoin: IntoDf {
273
#[doc(hidden)]
274
#[allow(clippy::too_many_arguments)]
275
fn _join_asof(
276
&self,
277
other: &DataFrame,
278
left_key: &Series,
279
right_key: &Series,
280
strategy: AsofStrategy,
281
tolerance: Option<AnyValue<'static>>,
282
suffix: Option<PlSmallStr>,
283
slice: Option<(i64, usize)>,
284
coalesce: bool,
285
allow_eq: bool,
286
check_sortedness: bool,
287
) -> PolarsResult<DataFrame> {
288
let self_df = self.to_df();
289
290
check_asof_columns(
291
left_key,
292
right_key,
293
tolerance.is_some(),
294
check_sortedness,
295
false,
296
)?;
297
let left_key = left_key.to_physical_repr();
298
let right_key = right_key.to_physical_repr();
299
300
let mut take_idx = match left_key.dtype() {
301
#[cfg(feature = "dtype-i128")]
302
DataType::Int128 => {
303
let ca = left_key.i128().unwrap();
304
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
305
},
306
DataType::Int64 => {
307
let ca = left_key.i64().unwrap();
308
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
309
},
310
DataType::Int32 => {
311
let ca = left_key.i32().unwrap();
312
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
313
},
314
#[cfg(feature = "dtype-u128")]
315
DataType::UInt128 => {
316
let ca = left_key.u128().unwrap();
317
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
318
},
319
DataType::UInt64 => {
320
let ca = left_key.u64().unwrap();
321
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
322
},
323
DataType::UInt32 => {
324
let ca = left_key.u32().unwrap();
325
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
326
},
327
#[cfg(feature = "dtype-f16")]
328
DataType::Float16 => {
329
let ca = left_key.f16().unwrap();
330
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
331
},
332
DataType::Float32 => {
333
let ca = left_key.f32().unwrap();
334
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
335
},
336
DataType::Float64 => {
337
let ca = left_key.f64().unwrap();
338
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
339
},
340
DataType::Boolean => {
341
let ca = left_key.bool().unwrap();
342
join_asof::<BooleanType>(ca, &right_key, strategy, allow_eq)
343
},
344
DataType::Binary => {
345
let ca = left_key.binary().unwrap();
346
join_asof::<BinaryType>(ca, &right_key, strategy, allow_eq)
347
},
348
DataType::String => {
349
let ca = left_key.str().unwrap();
350
let right_binary = right_key.cast(&DataType::Binary).unwrap();
351
join_asof::<BinaryType>(&ca.as_binary(), &right_binary, strategy, allow_eq)
352
},
353
DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 => {
354
let left_key = left_key.cast(&DataType::Int32).unwrap();
355
let right_key = right_key.cast(&DataType::Int32).unwrap();
356
let ca = left_key.i32().unwrap();
357
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
358
},
359
dt => polars_bail!(opq = asof_join, dt),
360
}?;
361
try_raise_keyboard_interrupt();
362
363
// Drop right join column.
364
let other = if coalesce && left_key.name() == right_key.name() {
365
Cow::Owned(other.drop(right_key.name())?)
366
} else {
367
Cow::Borrowed(other)
368
};
369
370
let mut left = self_df.clone();
371
if let Some((offset, len)) = slice {
372
left = left.slice(offset, len);
373
take_idx = take_idx.slice(offset, len);
374
}
375
376
// SAFETY: join tuples are in bounds.
377
let right_df = unsafe { other.take_unchecked(&take_idx) };
378
379
_finish_join(left, right_df, suffix)
380
}
381
}
382
383
// `DataFrame` uses the provided `_join_asof` method of `AsofJoin` unchanged.
impl AsofJoin for DataFrame {}
384
385