Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/asof/default.rs
8446 views
1
use arrow::array::Array;
2
use arrow::bitmap::Bitmap;
3
use num_traits::Zero;
4
use polars_core::prelude::*;
5
use polars_utils::abs_diff::AbsDiff;
6
use polars_utils::total_ord::TotalOrd;
7
8
use super::{
9
AsofJoinBackwardState, AsofJoinForwardState, AsofJoinNearestState, AsofJoinState, AsofStrategy,
10
};
11
12
fn join_asof_impl<'a, T, S, F>(
13
left: &'a T::Array,
14
right: &'a T::Array,
15
mut filter: F,
16
allow_eq: bool,
17
) -> IdxCa
18
where
19
T: PolarsDataType,
20
S: AsofJoinState<T::Physical<'a>>,
21
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
22
{
23
if left.len() == left.null_count() || right.len() == right.null_count() {
24
return IdxCa::full_null(PlSmallStr::EMPTY, left.len());
25
}
26
27
let mut out = vec![0; left.len()];
28
let mut mask = vec![0; left.len().div_ceil(8)];
29
let mut state = S::new(allow_eq);
30
31
if left.null_count() == 0 && right.null_count() == 0 {
32
for (i, val_l) in left.values_iter().enumerate() {
33
if let Some(r_idx) = state.next(
34
&val_l,
35
// SAFETY: next() only calls with indices < right.len().
36
|j| Some(unsafe { right.value_unchecked(j as usize) }),
37
right.len() as IdxSize,
38
) {
39
// SAFETY: r_idx is non-null and valid.
40
unsafe {
41
let val_r = right.value_unchecked(r_idx as usize);
42
*out.get_unchecked_mut(i) = r_idx;
43
*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);
44
}
45
}
46
}
47
} else {
48
for (i, opt_val_l) in left.iter().enumerate() {
49
if let Some(val_l) = opt_val_l {
50
if let Some(r_idx) = state.next(
51
&val_l,
52
// SAFETY: next() only calls with indices < right.len().
53
|j| unsafe { right.get_unchecked(j as usize) },
54
right.len() as IdxSize,
55
) {
56
// SAFETY: r_idx is non-null and valid.
57
unsafe {
58
let val_r = right.value_unchecked(r_idx as usize);
59
*out.get_unchecked_mut(i) = r_idx;
60
*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);
61
}
62
}
63
}
64
}
65
}
66
67
let bitmap = Bitmap::try_new(mask, out.len()).unwrap();
68
IdxCa::from_vec_validity(PlSmallStr::EMPTY, out, Some(bitmap))
69
}
70
71
fn join_asof_forward<'a, T, F>(
72
left: &'a T::Array,
73
right: &'a T::Array,
74
filter: F,
75
allow_eq: bool,
76
) -> IdxCa
77
where
78
T: PolarsDataType,
79
T::Physical<'a>: TotalOrd,
80
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
81
{
82
join_asof_impl::<'a, T, AsofJoinForwardState, _>(left, right, filter, allow_eq)
83
}
84
85
fn join_asof_backward<'a, T, F>(
86
left: &'a T::Array,
87
right: &'a T::Array,
88
filter: F,
89
allow_eq: bool,
90
) -> IdxCa
91
where
92
T: PolarsDataType,
93
T::Physical<'a>: TotalOrd,
94
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
95
{
96
join_asof_impl::<'a, T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)
97
}
98
99
fn join_asof_nearest<'a, T, F>(
100
left: &'a T::Array,
101
right: &'a T::Array,
102
filter: F,
103
allow_eq: bool,
104
) -> IdxCa
105
where
106
T: PolarsDataType,
107
T::Physical<'a>: NumericNative,
108
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
109
{
110
join_asof_impl::<'a, T, AsofJoinNearestState, _>(left, right, filter, allow_eq)
111
}
112
113
pub(crate) fn join_asof_numeric<T: PolarsNumericType>(
114
input_ca: &ChunkedArray<T>,
115
other: &Series,
116
strategy: AsofStrategy,
117
tolerance: Option<AnyValue<'static>>,
118
allow_eq: bool,
119
) -> PolarsResult<IdxCa> {
120
let other = input_ca.unpack_series_matching_type(other)?;
121
122
let ca = input_ca.rechunk();
123
let other = other.rechunk();
124
let left = ca.downcast_as_array();
125
let right = other.downcast_as_array();
126
127
let out = if let Some(t) = tolerance {
128
let native_tolerance = t.try_extract::<T::Native>()?;
129
let abs_tolerance = native_tolerance.abs_diff(T::Native::zero());
130
let filter = |l: T::Native, r: T::Native| l.abs_diff(r) <= abs_tolerance;
131
match strategy {
132
AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),
133
AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),
134
AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),
135
}
136
} else {
137
let filter = |_l: T::Native, _r: T::Native| true;
138
match strategy {
139
AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),
140
AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),
141
AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),
142
}
143
};
144
Ok(out)
145
}
146
147
pub(crate) fn join_asof<T>(
148
input_ca: &ChunkedArray<T>,
149
other: &Series,
150
strategy: AsofStrategy,
151
allow_eq: bool,
152
) -> PolarsResult<IdxCa>
153
where
154
T: PolarsDataType,
155
for<'a> T::Physical<'a>: TotalOrd,
156
{
157
let other = input_ca.unpack_series_matching_type(other)?;
158
159
let ca = input_ca.rechunk();
160
let other = other.rechunk();
161
let left = ca.downcast_iter().next().unwrap();
162
let right = other.downcast_iter().next().unwrap();
163
164
let filter = |_l: T::Physical<'_>, _r: T::Physical<'_>| true;
165
Ok(match strategy {
166
AsofStrategy::Forward => {
167
join_asof_impl::<T, AsofJoinForwardState, _>(left, right, filter, allow_eq)
168
},
169
AsofStrategy::Backward => {
170
join_asof_impl::<T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)
171
},
172
AsofStrategy::Nearest => polars_bail!(InvalidOperation:
173
"AsOf strategy \"nearest\" is not supported for {} data type",
174
T::get_static_dtype()
175
),
176
})
177
}
178
179
#[cfg(test)]
mod test {
    use arrow::array::PrimitiveArray;

    use super::*;

    #[test]
    fn test_asof_backward() {
        let a = PrimitiveArray::from_slice([-1, 2, 3, 3, 3, 4]);
        let b = PrimitiveArray::from_slice([1, 2, 3, 3]);

        let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);
        assert_eq!(tuples.len(), a.len());
        assert_eq!(
            tuples.to_vec(),
            &[None, Some(1), Some(3), Some(3), Some(3), Some(3)]
        );

        let b = PrimitiveArray::from_slice([1, 2, 4, 5]);
        let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);
        assert_eq!(
            tuples.to_vec(),
            &[None, Some(1), Some(1), Some(1), Some(1), Some(2)]
        );

        let a = PrimitiveArray::from_slice([2, 4, 4, 4]);
        let b = PrimitiveArray::from_slice([1, 2, 3, 3]);
        let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);
        assert_eq!(tuples.to_vec(), &[Some(1), Some(3), Some(3), Some(3)]);
    }

    #[test]
    fn test_asof_backward_tolerance() {
        let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40]);
        let b = PrimitiveArray::from_slice([10, 20, 30, 30]);
        let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true);
        assert_eq!(
            tuples.to_vec(),
            &[None, Some(1), None, Some(3), Some(3), None]
        );
    }

    #[test]
    fn test_asof_forward_tolerance() {
        let a = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40, 52]);
        let b = PrimitiveArray::from_slice([10, 20, 33, 55]);
        let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |l, r| l.abs_diff(r) <= 4u32, true);
        assert_eq!(
            tuples.to_vec(),
            &[None, Some(1), None, Some(2), Some(2), None, Some(3)]
        );
    }

    #[test]
    fn test_asof_forward() {
        let a = PrimitiveArray::from_slice([-1, 1, 2, 4, 6]);
        let b = PrimitiveArray::from_slice([1, 2, 4, 5]);

        let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |_, _| true, true);
        assert_eq!(tuples.len(), a.len());
        assert_eq!(tuples.to_vec(), &[Some(0), Some(0), Some(1), Some(2), None]);
    }

    /// If either side is entirely null, the early-return path yields an
    /// all-null result with the left length.
    #[test]
    fn test_asof_all_null_input() {
        let a = PrimitiveArray::<i32>::from([None, None]);
        let b = PrimitiveArray::from_slice([1, 2]);
        let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);
        assert_eq!(tuples.len(), 2);
        assert_eq!(tuples.null_count(), 2);

        let a = PrimitiveArray::from_slice([1, 2, 3]);
        let b = PrimitiveArray::<i32>::from([None, None]);
        let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |_, _| true, true);
        assert_eq!(tuples.len(), 3);
        assert_eq!(tuples.null_count(), 3);
    }

    /// Exercises the nullable (non-fast) path: null left keys produce null
    /// output, and null right values are skipped over.
    #[test]
    fn test_asof_with_nulls() {
        let a = PrimitiveArray::from([Some(1), None, Some(3)]);
        let b = PrimitiveArray::from([Some(1), None, Some(3)]);

        let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, true);
        assert_eq!(tuples.to_vec(), &[Some(0), None, Some(2)]);

        let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |_, _| true, true);
        assert_eq!(tuples.to_vec(), &[Some(0), None, Some(2)]);
    }

    /// With `allow_eq == false`, an exactly equal right value must not be chosen.
    #[test]
    fn test_asof_strict() {
        let a = PrimitiveArray::from_slice([3]);
        let b = PrimitiveArray::from_slice([1, 2, 3]);
        let tuples = join_asof_backward::<Int32Type, _>(&a, &b, |_, _| true, false);
        assert_eq!(tuples.to_vec(), &[Some(1)]);

        let a = PrimitiveArray::from_slice([2]);
        let tuples = join_asof_forward::<Int32Type, _>(&a, &b, |_, _| true, false);
        assert_eq!(tuples.to_vec(), &[Some(2)]);
    }

    #[test]
    fn test_asof_nearest() {
        // Inputs chosen so that no left value is equidistant from two right values.
        let a = PrimitiveArray::from_slice([1, 4, 9]);
        let b = PrimitiveArray::from_slice([0, 3, 10]);
        let tuples = join_asof_nearest::<Int32Type, _>(&a, &b, |_, _| true, true);
        assert_eq!(tuples.to_vec(), &[Some(0), Some(1), Some(2)]);
    }
}
242
243