GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/frame/join/iejoin/l1_l2.rs
#![allow(unsafe_op_in_unsafe_fn)]
use polars_core::chunked_array::ChunkedArray;
use polars_core::datatypes::{IdxCa, PolarsNumericType};
use polars_core::prelude::Series;
use polars_core::with_match_physical_numeric_polars_type;
use polars_error::PolarsResult;
use polars_utils::IdxSize;
use polars_utils::total_ord::TotalOrd;

use super::*;

/// Create a vector of L1 items from the array of LHS x values concatenated with RHS x values
/// and their ordering.
pub(super) fn build_l1_array<T>(
    ca: &ChunkedArray<T>,
    order: &IdxCa,
    right_df_offset: IdxSize,
) -> PolarsResult<Vec<L1Item<T::Native>>>
where
    T: PolarsNumericType,
{
    assert_eq!(order.null_count(), 0);
    assert_eq!(ca.chunks().len(), 1);
    let arr = ca.downcast_get(0).unwrap();
    // Even if there are nulls, they will not be selected by order.
    let values = arr.values().as_slice();

    let mut array: Vec<L1Item<T::Native>> = Vec::with_capacity(ca.len());

    for order_arr in order.downcast_iter() {
        for index in order_arr.values().as_slice().iter().copied() {
            debug_assert!(arr.get(index as usize).is_some());
            let value = unsafe { *values.get_unchecked(index as usize) };
            let row_index = if index < right_df_offset {
                // Row from LHS
                index as i64 + 1
            } else {
                // Row from RHS
                -((index - right_df_offset) as i64) - 1
            };
            array.push(L1Item { row_index, value });
        }
    }

    Ok(array)
}
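
// A minimal, std-only sketch (test-only, hypothetical values) of the signed
// row-index encoding produced above: indices below `right_df_offset` come from
// the LHS DataFrame and become 1-based positive values, the remaining indices
// come from the RHS and become -1-based negative values, so a single i64 later
// tells us both which side a row came from and its original 0-based position.
#[cfg(test)]
#[test]
fn l1_row_index_encoding_sketch() {
    let right_df_offset: i64 = 3; // hypothetical: 3 LHS rows precede the RHS rows
    let encode = |index: i64| -> i64 {
        if index < right_df_offset {
            index + 1 // LHS: 1-based positive
        } else {
            -(index - right_df_offset) - 1 // RHS: -1-based negative
        }
    };
    // Encoding.
    assert_eq!(encode(0), 1); // first LHS row
    assert_eq!(encode(2), 3); // last LHS row
    assert_eq!(encode(3), -1); // first RHS row
    assert_eq!(encode(5), -3); // third RHS row
    // Decoding back to 0-based row ids, as done when emitting matches.
    assert_eq!(encode(2) - 1, 2); // LHS: row_index - 1
    assert_eq!(-encode(5) - 1, 2); // RHS: -row_index - 1
}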

pub(super) fn build_l2_array(s: &Series, order: &[IdxSize]) -> PolarsResult<Vec<L2Item>> {
    with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
        build_l2_array_impl::<$T>(s.as_ref().as_ref(), order)
    })
}

/// Create a vector of L2 items from the array of y values ordered according to the L1 order,
/// and their ordering. We don't need to store actual y values but only track whether we're at
/// the end of a run of equal values.
fn build_l2_array_impl<T>(ca: &ChunkedArray<T>, order: &[IdxSize]) -> PolarsResult<Vec<L2Item>>
where
    T: PolarsNumericType,
    T::Native: TotalOrd,
{
    assert_eq!(ca.chunks().len(), 1);

    let mut array = Vec::with_capacity(ca.len());
    let mut prev_index = 0;
    let mut prev_value = T::Native::default();

    let arr = ca.downcast_get(0).unwrap();
    // Even if there are nulls, they will not be selected by order.
    let values = arr.values().as_slice();

    for (i, l1_index) in order.iter().copied().enumerate() {
        debug_assert!(arr.get(l1_index as usize).is_some());
        let value = unsafe { *values.get_unchecked(l1_index as usize) };
        if i > 0 {
            array.push(L2Item {
                l1_index: prev_index,
                run_end: value.tot_ne(&prev_value),
            });
        }
        prev_index = l1_index;
        prev_value = value;
    }
    if !order.is_empty() {
        array.push(L2Item {
            l1_index: prev_index,
            run_end: true,
        });
    }
    Ok(array)
}
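
// A minimal std-only sketch (test-only, hypothetical values) of the run-end
// rule applied above: an entry's `run_end` flag is true exactly when its y
// value differs from the next y value in the traversal order (and the final
// entry always ends a run), so only the last element of each run of equal y
// values is flagged.
#[cfg(test)]
#[test]
fn l2_run_end_sketch() {
    let ys = [1_i64, 1, 2, 3, 3]; // y values in the order they are traversed
    let run_ends: Vec<bool> = ys
        .windows(2)
        .map(|w| w[0] != w[1])
        .chain(std::iter::once(true)) // the final entry always ends a run
        .collect();
    assert_eq!(run_ends, vec![false, true, true, false, true]);
}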

/// Item in L1 array used in the IEJoin algorithm
#[derive(Clone, Copy, Debug)]
pub(super) struct L1Item<T> {
    /// 1-based index for entries from the LHS df, or -1-based index for entries from the RHS
    pub(super) row_index: i64,
    /// X value
    pub(super) value: T,
}

/// Item in L2 array used in the IEJoin algorithm
#[derive(Clone, Copy, Debug)]
pub(super) struct L2Item {
    /// Corresponding index into the L1 array
    pub(super) l1_index: IdxSize,
    /// Whether this is the end of a run of equal y values
    pub(super) run_end: bool,
}

pub(super) trait L1Array {
    unsafe fn process_entry(
        &self,
        l1_index: usize,
        bit_array: &mut FilteredBitArray,
        op1: InequalityOperator,
        left_row_ids: &mut Vec<IdxSize>,
        right_row_ids: &mut Vec<IdxSize>,
    ) -> i64;

    unsafe fn process_lhs_entry(
        &self,
        l1_index: usize,
        bit_array: &FilteredBitArray,
        op1: InequalityOperator,
        left_row_ids: &mut Vec<IdxSize>,
        right_row_ids: &mut Vec<IdxSize>,
    ) -> i64;

    unsafe fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray);
}
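
// Hypothetical usage sketch (comments only; the real driver lives in the parent
// iejoin module and is not shown here): a caller holding the L2 array would
// typically walk it in sorted-y order and hand each entry's L1 position to
// `process_entry`, which either emits joined row ids (for LHS rows) or marks
// the row as visited in the bit array (for RHS rows):
//
//     for l2 in &l2_array {
//         match_count += unsafe {
//             l1_array.process_entry(
//                 l2.l1_index as usize,
//                 &mut bit_array,
//                 op1,
//                 &mut left_row_ids,
//                 &mut right_row_ids,
//             )
//         };
//     }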

/// Find the position in the L1 array where we should begin checking for matches,
/// given the index in L1 corresponding to the current position in L2.
unsafe fn find_search_start_index<T>(
    l1_array: &[L1Item<T>],
    index: usize,
    operator: InequalityOperator,
) -> usize
where
    T: NumericNative,
    T: TotalOrd,
{
    let sub_l1 = l1_array.get_unchecked(index..);
    let value = l1_array.get_unchecked(index).value;

    match operator {
        InequalityOperator::Gt => {
            sub_l1.partition_point_exponential(|a| a.value.tot_ge(&value)) + index
        },
        InequalityOperator::Lt => {
            sub_l1.partition_point_exponential(|a| a.value.tot_le(&value)) + index
        },
        InequalityOperator::GtEq => {
            sub_l1.partition_point_exponential(|a| value.tot_lt(&a.value)) + index
        },
        InequalityOperator::LtEq => {
            sub_l1.partition_point_exponential(|a| value.tot_gt(&a.value)) + index
        },
    }
}
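
// A minimal std-only sketch (test-only, hypothetical values) of the
// partition-point search used above: for a slice that is partitioned so a
// predicate holds for a prefix and fails for the rest, `partition_point`
// returns the length of that prefix. The `partition_point_exponential` call
// above computes the same boundary, presumably probing exponentially from the
// front since the boundary is often close to `index`.
#[cfg(test)]
#[test]
fn search_start_partition_point_sketch() {
    // With a strict `>` comparison, entries whose x is >= the current value can
    // never match, so (assuming the L1 x values are laid out so the predicate
    // is monotone) the match scan starts right after them.
    let xs = [9_i64, 7, 7, 5, 3];
    let value = 7;
    let start = xs.partition_point(|&a| a >= value);
    assert_eq!(start, 3); // candidate matches can only begin at xs[3..]
}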

fn find_matches_in_l1<T>(
    l1_array: &[L1Item<T>],
    l1_index: usize,
    row_index: i64,
    bit_array: &FilteredBitArray,
    op1: InequalityOperator,
    left_row_ids: &mut Vec<IdxSize>,
    right_row_ids: &mut Vec<IdxSize>,
) -> i64
where
    T: NumericNative,
    T: TotalOrd,
{
    debug_assert!(row_index > 0);
    let mut match_count = 0;

    // This entry comes from the left hand side DataFrame.
    // Find all following entries in L1 (meaning they satisfy the first operator)
    // that have already been visited (so satisfy the second operator).
    // Because we use a stable sort for l2, we know that we won't find any
    // matches for duplicate y values when traversing forwards in l1.
    let start_index = unsafe { find_search_start_index(l1_array, l1_index, op1) };
    unsafe {
        bit_array.on_set_bits_from(start_index, |set_bit: usize| {
            // SAFETY
            // set bit is within bounds.
            let right_row_index = l1_array.get_unchecked(set_bit).row_index;
            debug_assert!(right_row_index < 0);
            left_row_ids.push((row_index - 1) as IdxSize);
            right_row_ids.push((-right_row_index) as IdxSize - 1);
            match_count += 1;
        })
    };

    match_count
}

impl<T> L1Array for Vec<L1Item<T>>
where
    T: NumericNative,
{
    unsafe fn process_entry(
        &self,
        l1_index: usize,
        bit_array: &mut FilteredBitArray,
        op1: InequalityOperator,
        left_row_ids: &mut Vec<IdxSize>,
        right_row_ids: &mut Vec<IdxSize>,
    ) -> i64 {
        let row_index = self.get_unchecked(l1_index).row_index;
        let from_lhs = row_index > 0;
        if from_lhs {
            find_matches_in_l1(
                self,
                l1_index,
                row_index,
                bit_array,
                op1,
                left_row_ids,
                right_row_ids,
            )
        } else {
            bit_array.set_bit_unchecked(l1_index);
            0
        }
    }

    unsafe fn process_lhs_entry(
        &self,
        l1_index: usize,
        bit_array: &FilteredBitArray,
        op1: InequalityOperator,
        left_row_ids: &mut Vec<IdxSize>,
        right_row_ids: &mut Vec<IdxSize>,
    ) -> i64 {
        let row_index = self.get_unchecked(l1_index).row_index;
        let from_lhs = row_index > 0;
        if from_lhs {
            find_matches_in_l1(
                self,
                l1_index,
                row_index,
                bit_array,
                op1,
                left_row_ids,
                right_row_ids,
            )
        } else {
            0
        }
    }

    unsafe fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray) {
        let from_lhs = self.get_unchecked(index).row_index > 0;
        // We only mark RHS entries as visited,
        // so that we don't try to match LHS entries with other LHS entries.
        if !from_lhs {
            bit_array.set_bit_unchecked(index);
        }
    }
}