Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/idx_table/row_encoded.rs
6940 views
1
#![allow(clippy::unnecessary_cast)] // Clippy doesn't recognize that IdxSize and u64 can be different.
2
#![allow(unsafe_op_in_unsafe_fn)]
3
4
use arrow::array::Array;
5
use polars_compute::binview_index_map::{BinaryViewIndexMap, Entry};
6
use polars_utils::idx_vec::UnitVec;
7
use polars_utils::itertools::Itertools;
8
use polars_utils::relaxed_cell::RelaxedCell;
9
use polars_utils::unitvec;
10
11
use super::*;
12
use crate::hash_keys::HashKeys;
13
14
#[derive(Default)]
15
pub struct RowEncodedIdxTable {
16
// These AtomicU64s actually are IdxSizes, but we use the top bit of the
17
// first index in each to mark keys during probing.
18
idx_map: BinaryViewIndexMap<UnitVec<RelaxedCell<u64>>>,
19
idx_offset: IdxSize,
20
null_keys: Vec<IdxSize>,
21
}
22
23
impl RowEncodedIdxTable {
24
pub fn new() -> Self {
25
Self {
26
idx_map: BinaryViewIndexMap::new(),
27
idx_offset: 0,
28
null_keys: Vec::new(),
29
}
30
}
31
}
32
33
impl RowEncodedIdxTable {
34
#[inline(always)]
35
fn probe_one<const MARK_MATCHES: bool>(
36
&self,
37
key_idx: IdxSize,
38
hash: u64,
39
key: &[u8],
40
table_match: &mut Vec<IdxSize>,
41
probe_match: &mut Vec<IdxSize>,
42
) -> bool {
43
if let Some(idxs) = self.idx_map.get(hash, key) {
44
for idx in &idxs[..] {
45
// Create matches, making sure to clear top bit.
46
table_match.push((idx.load() & !(1 << 63)) as IdxSize);
47
probe_match.push(key_idx);
48
}
49
50
// Mark if necessary. This action is idempotent so doesn't
51
// fetch_or to do it atomically.
52
if MARK_MATCHES {
53
let first_idx = unsafe { idxs.get_unchecked(0) };
54
let first_idx_val = first_idx.load();
55
if first_idx_val >> 63 == 0 {
56
first_idx.store(first_idx_val | (1 << 63));
57
}
58
}
59
true
60
} else {
61
false
62
}
63
}
64
65
fn probe_impl<'a, const MARK_MATCHES: bool, const EMIT_UNMATCHED: bool>(
66
&self,
67
hash_keys: impl Iterator<Item = (IdxSize, u64, Option<&'a [u8]>)>,
68
table_match: &mut Vec<IdxSize>,
69
probe_match: &mut Vec<IdxSize>,
70
limit: IdxSize,
71
) -> IdxSize {
72
let mut keys_processed = 0;
73
for (key_idx, hash, key) in hash_keys {
74
let found_match = if let Some(key) = key {
75
self.probe_one::<MARK_MATCHES>(key_idx, hash, key, table_match, probe_match)
76
} else {
77
false
78
};
79
80
if EMIT_UNMATCHED && !found_match {
81
table_match.push(IdxSize::MAX);
82
probe_match.push(key_idx);
83
}
84
85
keys_processed += 1;
86
if table_match.len() >= limit as usize {
87
break;
88
}
89
}
90
keys_processed
91
}
92
93
fn probe_dispatch<'a>(
94
&self,
95
hash_keys: impl Iterator<Item = (IdxSize, u64, Option<&'a [u8]>)>,
96
table_match: &mut Vec<IdxSize>,
97
probe_match: &mut Vec<IdxSize>,
98
mark_matches: bool,
99
emit_unmatched: bool,
100
limit: IdxSize,
101
) -> IdxSize {
102
match (mark_matches, emit_unmatched) {
103
(false, false) => {
104
self.probe_impl::<false, false>(hash_keys, table_match, probe_match, limit)
105
},
106
(false, true) => {
107
self.probe_impl::<false, true>(hash_keys, table_match, probe_match, limit)
108
},
109
(true, false) => {
110
self.probe_impl::<true, false>(hash_keys, table_match, probe_match, limit)
111
},
112
(true, true) => {
113
self.probe_impl::<true, true>(hash_keys, table_match, probe_match, limit)
114
},
115
}
116
}
117
}
118
119
impl IdxTable for RowEncodedIdxTable {
120
fn new_empty(&self) -> Box<dyn IdxTable> {
121
Box::new(Self::new())
122
}
123
124
fn reserve(&mut self, additional: usize) {
125
self.idx_map.reserve(additional);
126
}
127
128
fn num_keys(&self) -> IdxSize {
129
self.idx_map.len()
130
}
131
132
fn insert_keys(&mut self, hash_keys: &HashKeys, track_unmatchable: bool) {
133
let HashKeys::RowEncoded(hash_keys) = hash_keys else {
134
unreachable!()
135
};
136
let new_idx_offset = (self.idx_offset as usize)
137
.checked_add(hash_keys.keys.len())
138
.unwrap();
139
assert!(
140
new_idx_offset < IdxSize::MAX as usize,
141
"overly large index in RowEncodedIdxTable"
142
);
143
144
for (i, (hash, key)) in hash_keys
145
.hashes
146
.values_iter()
147
.zip(hash_keys.keys.iter())
148
.enumerate_idx()
149
{
150
let idx = self.idx_offset + i;
151
if let Some(key) = key {
152
match self.idx_map.entry(*hash, key) {
153
Entry::Occupied(o) => {
154
o.into_mut().push(RelaxedCell::from(idx as u64));
155
},
156
Entry::Vacant(v) => {
157
v.insert(unitvec![RelaxedCell::from(idx as u64)]);
158
},
159
}
160
} else if track_unmatchable {
161
self.null_keys.push(idx);
162
}
163
}
164
165
self.idx_offset = new_idx_offset as IdxSize;
166
}
167
168
unsafe fn insert_keys_subset(
169
&mut self,
170
hash_keys: &HashKeys,
171
subset: &[IdxSize],
172
track_unmatchable: bool,
173
) {
174
let HashKeys::RowEncoded(hash_keys) = hash_keys else {
175
unreachable!()
176
};
177
let new_idx_offset = (self.idx_offset as usize)
178
.checked_add(subset.len())
179
.unwrap();
180
assert!(
181
new_idx_offset < IdxSize::MAX as usize,
182
"overly large index in RowEncodedIdxTable"
183
);
184
185
for (i, subset_idx) in subset.iter().enumerate_idx() {
186
let hash = unsafe { hash_keys.hashes.value_unchecked(*subset_idx as usize) };
187
let key = unsafe { hash_keys.keys.get_unchecked(*subset_idx as usize) };
188
let idx = self.idx_offset + i;
189
if let Some(key) = key {
190
match self.idx_map.entry(hash, key) {
191
Entry::Occupied(o) => {
192
o.into_mut().push(RelaxedCell::from(idx as u64));
193
},
194
Entry::Vacant(v) => {
195
v.insert(unitvec![RelaxedCell::from(idx as u64)]);
196
},
197
}
198
} else if track_unmatchable {
199
self.null_keys.push(idx);
200
}
201
}
202
203
self.idx_offset = new_idx_offset as IdxSize;
204
}
205
206
fn probe(
207
&self,
208
hash_keys: &HashKeys,
209
table_match: &mut Vec<IdxSize>,
210
probe_match: &mut Vec<IdxSize>,
211
mark_matches: bool,
212
emit_unmatched: bool,
213
limit: IdxSize,
214
) -> IdxSize {
215
let HashKeys::RowEncoded(hash_keys) = hash_keys else {
216
unreachable!()
217
};
218
219
if hash_keys.keys.has_nulls() {
220
let iter = hash_keys
221
.hashes
222
.values_iter()
223
.copied()
224
.zip(hash_keys.keys.iter())
225
.enumerate_idx()
226
.map(|(i, (h, k))| (i, h, k));
227
self.probe_dispatch(
228
iter,
229
table_match,
230
probe_match,
231
mark_matches,
232
emit_unmatched,
233
limit,
234
)
235
} else {
236
let iter = hash_keys
237
.hashes
238
.values_iter()
239
.copied()
240
.zip(hash_keys.keys.values_iter().map(Some))
241
.enumerate_idx()
242
.map(|(i, (h, k))| (i, h, k));
243
self.probe_dispatch(
244
iter,
245
table_match,
246
probe_match,
247
mark_matches,
248
emit_unmatched,
249
limit,
250
)
251
}
252
}
253
254
unsafe fn probe_subset(
255
&self,
256
hash_keys: &HashKeys,
257
subset: &[IdxSize],
258
table_match: &mut Vec<IdxSize>,
259
probe_match: &mut Vec<IdxSize>,
260
mark_matches: bool,
261
emit_unmatched: bool,
262
limit: IdxSize,
263
) -> IdxSize {
264
let HashKeys::RowEncoded(hash_keys) = hash_keys else {
265
unreachable!()
266
};
267
268
if hash_keys.keys.has_nulls() {
269
let iter = subset.iter().map(|i| {
270
(
271
*i,
272
hash_keys.hashes.value_unchecked(*i as usize),
273
hash_keys.keys.get_unchecked(*i as usize),
274
)
275
});
276
self.probe_dispatch(
277
iter,
278
table_match,
279
probe_match,
280
mark_matches,
281
emit_unmatched,
282
limit,
283
)
284
} else {
285
let iter = subset.iter().map(|i| {
286
(
287
*i,
288
hash_keys.hashes.value_unchecked(*i as usize),
289
Some(hash_keys.keys.value_unchecked(*i as usize)),
290
)
291
});
292
self.probe_dispatch(
293
iter,
294
table_match,
295
probe_match,
296
mark_matches,
297
emit_unmatched,
298
limit,
299
)
300
}
301
}
302
303
fn unmarked_keys(
304
&self,
305
out: &mut Vec<IdxSize>,
306
mut offset: IdxSize,
307
limit: IdxSize,
308
) -> IdxSize {
309
out.clear();
310
311
let mut keys_processed = 0;
312
if (offset as usize) < self.null_keys.len() {
313
out.extend(
314
self.null_keys[offset as usize..]
315
.iter()
316
.copied()
317
.take(limit as usize),
318
);
319
keys_processed += out.len() as IdxSize;
320
offset += out.len() as IdxSize;
321
if out.len() >= limit as usize {
322
return keys_processed;
323
}
324
}
325
326
offset -= self.null_keys.len() as IdxSize;
327
328
while let Some((_, _, idxs)) = self.idx_map.get_index(offset) {
329
let first_idx = unsafe { idxs.get_unchecked(0) };
330
let first_idx_val = first_idx.load();
331
if first_idx_val >> 63 == 0 {
332
for idx in &idxs[..] {
333
out.push((idx.load() & !(1 << 63)) as IdxSize);
334
}
335
}
336
337
keys_processed += 1;
338
offset += 1;
339
if out.len() >= limit as usize {
340
break;
341
}
342
}
343
344
keys_processed
345
}
346
}
347
348