Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/asof/default.rs
6940 views
1
use arrow::array::Array;
2
use arrow::bitmap::Bitmap;
3
use num_traits::Zero;
4
use polars_core::prelude::*;
5
use polars_utils::abs_diff::AbsDiff;
6
7
use super::{
8
AsofJoinBackwardState, AsofJoinForwardState, AsofJoinNearestState, AsofJoinState, AsofStrategy,
9
};
10
11
fn join_asof_impl<'a, T, S, F>(
12
left: &'a T::Array,
13
right: &'a T::Array,
14
mut filter: F,
15
allow_eq: bool,
16
) -> IdxCa
17
where
18
T: PolarsDataType,
19
S: AsofJoinState<T::Physical<'a>>,
20
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
21
{
22
if left.len() == left.null_count() || right.len() == right.null_count() {
23
return IdxCa::full_null(PlSmallStr::EMPTY, left.len());
24
}
25
26
let mut out = vec![0; left.len()];
27
let mut mask = vec![0; left.len().div_ceil(8)];
28
let mut state = S::new(allow_eq);
29
30
if left.null_count() == 0 && right.null_count() == 0 {
31
for (i, val_l) in left.values_iter().enumerate() {
32
if let Some(r_idx) = state.next(
33
&val_l,
34
// SAFETY: next() only calls with indices < right.len().
35
|j| Some(unsafe { right.value_unchecked(j as usize) }),
36
right.len() as IdxSize,
37
) {
38
// SAFETY: r_idx is non-null and valid.
39
unsafe {
40
let val_r = right.value_unchecked(r_idx as usize);
41
*out.get_unchecked_mut(i) = r_idx;
42
*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);
43
}
44
}
45
}
46
} else {
47
for (i, opt_val_l) in left.iter().enumerate() {
48
if let Some(val_l) = opt_val_l {
49
if let Some(r_idx) = state.next(
50
&val_l,
51
// SAFETY: next() only calls with indices < right.len().
52
|j| unsafe { right.get_unchecked(j as usize) },
53
right.len() as IdxSize,
54
) {
55
// SAFETY: r_idx is non-null and valid.
56
unsafe {
57
let val_r = right.value_unchecked(r_idx as usize);
58
*out.get_unchecked_mut(i) = r_idx;
59
*mask.get_unchecked_mut(i / 8) |= (filter(val_l, val_r) as u8) << (i % 8);
60
}
61
}
62
}
63
}
64
}
65
66
let bitmap = Bitmap::try_new(mask, out.len()).unwrap();
67
IdxCa::from_vec_validity(PlSmallStr::EMPTY, out, Some(bitmap))
68
}
69
70
fn join_asof_forward<'a, T, F>(
71
left: &'a T::Array,
72
right: &'a T::Array,
73
filter: F,
74
allow_eq: bool,
75
) -> IdxCa
76
where
77
T: PolarsDataType,
78
T::Physical<'a>: PartialOrd,
79
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
80
{
81
join_asof_impl::<'a, T, AsofJoinForwardState, _>(left, right, filter, allow_eq)
82
}
83
84
fn join_asof_backward<'a, T, F>(
85
left: &'a T::Array,
86
right: &'a T::Array,
87
filter: F,
88
allow_eq: bool,
89
) -> IdxCa
90
where
91
T: PolarsDataType,
92
T::Physical<'a>: PartialOrd,
93
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
94
{
95
join_asof_impl::<'a, T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)
96
}
97
98
fn join_asof_nearest<'a, T, F>(
99
left: &'a T::Array,
100
right: &'a T::Array,
101
filter: F,
102
allow_eq: bool,
103
) -> IdxCa
104
where
105
T: PolarsDataType,
106
T::Physical<'a>: NumericNative,
107
F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool,
108
{
109
join_asof_impl::<'a, T, AsofJoinNearestState, _>(left, right, filter, allow_eq)
110
}
111
112
pub(crate) fn join_asof_numeric<T: PolarsNumericType>(
113
input_ca: &ChunkedArray<T>,
114
other: &Series,
115
strategy: AsofStrategy,
116
tolerance: Option<AnyValue<'static>>,
117
allow_eq: bool,
118
) -> PolarsResult<IdxCa> {
119
let other = input_ca.unpack_series_matching_type(other)?;
120
121
let ca = input_ca.rechunk();
122
let other = other.rechunk();
123
let left = ca.downcast_as_array();
124
let right = other.downcast_as_array();
125
126
let out = if let Some(t) = tolerance {
127
let native_tolerance = t.try_extract::<T::Native>()?;
128
let abs_tolerance = native_tolerance.abs_diff(T::Native::zero());
129
let filter = |l: T::Native, r: T::Native| l.abs_diff(r) <= abs_tolerance;
130
match strategy {
131
AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),
132
AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),
133
AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),
134
}
135
} else {
136
let filter = |_l: T::Native, _r: T::Native| true;
137
match strategy {
138
AsofStrategy::Forward => join_asof_forward::<T, _>(left, right, filter, allow_eq),
139
AsofStrategy::Backward => join_asof_backward::<T, _>(left, right, filter, allow_eq),
140
AsofStrategy::Nearest => join_asof_nearest::<T, _>(left, right, filter, allow_eq),
141
}
142
};
143
Ok(out)
144
}
145
146
pub(crate) fn join_asof<T>(
147
input_ca: &ChunkedArray<T>,
148
other: &Series,
149
strategy: AsofStrategy,
150
allow_eq: bool,
151
) -> PolarsResult<IdxCa>
152
where
153
T: PolarsDataType,
154
for<'a> T::Physical<'a>: PartialOrd,
155
{
156
let other = input_ca.unpack_series_matching_type(other)?;
157
158
let ca = input_ca.rechunk();
159
let other = other.rechunk();
160
let left = ca.downcast_iter().next().unwrap();
161
let right = other.downcast_iter().next().unwrap();
162
163
let filter = |_l: T::Physical<'_>, _r: T::Physical<'_>| true;
164
Ok(match strategy {
165
AsofStrategy::Forward => {
166
join_asof_impl::<T, AsofJoinForwardState, _>(left, right, filter, allow_eq)
167
},
168
AsofStrategy::Backward => {
169
join_asof_impl::<T, AsofJoinBackwardState, _>(left, right, filter, allow_eq)
170
},
171
AsofStrategy::Nearest => unimplemented!(),
172
})
173
}
174
175
#[cfg(test)]
mod test {
    use arrow::array::PrimitiveArray;

    use super::*;

    /// Backward strategy without tolerance: each left key pairs with the
    /// last right key that is `<=` it (equality allowed).
    #[test]
    fn test_asof_backward() {
        let left = PrimitiveArray::from_slice([-1, 2, 3, 3, 3, 4]);

        // Right side ending in a run of duplicates.
        let right = PrimitiveArray::from_slice([1, 2, 3, 3]);
        let matched = join_asof_backward::<Int32Type, _>(&left, &right, |_, _| true, true);
        assert_eq!(matched.len(), left.len());
        assert_eq!(
            matched.to_vec(),
            &[None, Some(1), Some(3), Some(3), Some(3), Some(3)]
        );

        // Right side with a gap around the left keys equal to 3.
        let right = PrimitiveArray::from_slice([1, 2, 4, 5]);
        let matched = join_asof_backward::<Int32Type, _>(&left, &right, |_, _| true, true);
        assert_eq!(
            matched.to_vec(),
            &[None, Some(1), Some(1), Some(1), Some(1), Some(2)]
        );

        // Repeated left keys that all lie beyond the last right key.
        let left = PrimitiveArray::from_slice([2, 4, 4, 4]);
        let right = PrimitiveArray::from_slice([1, 2, 3, 3]);
        let matched = join_asof_backward::<Int32Type, _>(&left, &right, |_, _| true, true);
        assert_eq!(matched.to_vec(), &[Some(1), Some(3), Some(3), Some(3)]);
    }

    /// Backward strategy with a tolerance of 4: candidates further away
    /// than that are nulled out.
    #[test]
    fn test_asof_backward_tolerance() {
        let left = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40]);
        let right = PrimitiveArray::from_slice([10, 20, 30, 30]);
        let matched =
            join_asof_backward::<Int32Type, _>(&left, &right, |l, r| l.abs_diff(r) <= 4u32, true);
        assert_eq!(
            matched.to_vec(),
            &[None, Some(1), None, Some(3), Some(3), None]
        );
    }

    /// Forward strategy with a tolerance of 4.
    #[test]
    fn test_asof_forward_tolerance() {
        let left = PrimitiveArray::from_slice([-1, 20, 25, 30, 30, 40, 52]);
        let right = PrimitiveArray::from_slice([10, 20, 33, 55]);
        let matched =
            join_asof_forward::<Int32Type, _>(&left, &right, |l, r| l.abs_diff(r) <= 4u32, true);
        assert_eq!(
            matched.to_vec(),
            &[None, Some(1), None, Some(2), Some(2), None, Some(3)]
        );
    }

    /// Forward strategy without tolerance: each left key pairs with the
    /// first right key that is `>=` it; keys past the end get no match.
    #[test]
    fn test_asof_forward() {
        let left = PrimitiveArray::from_slice([-1, 1, 2, 4, 6]);
        let right = PrimitiveArray::from_slice([1, 2, 4, 5]);

        let matched = join_asof_forward::<Int32Type, _>(&left, &right, |_, _| true, true);
        assert_eq!(matched.len(), left.len());
        assert_eq!(matched.to_vec(), &[Some(0), Some(0), Some(1), Some(2), None]);
    }
}
238
239