Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/idx_table/single_key.rs
6940 views
1
#![allow(clippy::unnecessary_cast)] // Clippy doesn't recognize that IdxSize and u64 can be different.
2
#![allow(unsafe_op_in_unsafe_fn)]
3
4
use polars_utils::idx_map::total_idx_map::{Entry, TotalIndexMap};
5
use polars_utils::idx_vec::UnitVec;
6
use polars_utils::itertools::Itertools;
7
use polars_utils::relaxed_cell::RelaxedCell;
8
use polars_utils::total_ord::{TotalEq, TotalHash};
9
use polars_utils::unitvec;
10
11
use super::*;
12
use crate::hash_keys::HashKeys;
13
14
pub struct SingleKeyIdxTable<T: PolarsDataType> {
15
// These AtomicU64s actually are IdxSizes, but we use the top bit of the
16
// first index in each to mark keys during probing.
17
idx_map: TotalIndexMap<T::Physical<'static>, UnitVec<RelaxedCell<u64>>>,
18
idx_offset: IdxSize,
19
null_keys: Vec<IdxSize>,
20
nulls_emitted: RelaxedCell<bool>,
21
}
22
23
impl<T: PolarsDataType> SingleKeyIdxTable<T> {
24
pub fn new() -> Self {
25
Self {
26
idx_map: TotalIndexMap::default(),
27
idx_offset: 0,
28
null_keys: Vec::new(),
29
nulls_emitted: RelaxedCell::from(false),
30
}
31
}
32
}
33
34
impl<T, K> SingleKeyIdxTable<T>
35
where
36
for<'a> T: PolarsDataType<Physical<'a> = K>,
37
K: TotalHash + TotalEq + Send + Sync + 'static,
38
{
39
#[inline(always)]
40
fn probe_one<const MARK_MATCHES: bool>(
41
&self,
42
key_idx: IdxSize,
43
key: &K,
44
table_match: &mut Vec<IdxSize>,
45
probe_match: &mut Vec<IdxSize>,
46
) -> bool {
47
if let Some(idxs) = self.idx_map.get(key) {
48
for idx in &idxs[..] {
49
// Create matches, making sure to clear top bit.
50
table_match.push((idx.load() & !(1 << 63)) as IdxSize);
51
probe_match.push(key_idx);
52
}
53
54
// Mark if necessary. This action is idempotent so doesn't need
55
// atomic fetch_or to do it atomically.
56
if MARK_MATCHES {
57
let first_idx = unsafe { idxs.get_unchecked(0) };
58
let first_idx_val = first_idx.load();
59
if first_idx_val >> 63 == 0 {
60
first_idx.store(first_idx_val | (1 << 63));
61
}
62
}
63
true
64
} else {
65
false
66
}
67
}
68
69
fn probe_impl<
70
const MARK_MATCHES: bool,
71
const EMIT_UNMATCHED: bool,
72
const NULL_IS_VALID: bool,
73
>(
74
&self,
75
keys: impl Iterator<Item = (IdxSize, Option<K>)>,
76
table_match: &mut Vec<IdxSize>,
77
probe_match: &mut Vec<IdxSize>,
78
limit: IdxSize,
79
) -> IdxSize {
80
let mut keys_processed = 0;
81
for (key_idx, key) in keys {
82
let found_match = if let Some(key) = key {
83
self.probe_one::<MARK_MATCHES>(key_idx, &key, table_match, probe_match)
84
} else if NULL_IS_VALID {
85
for idx in &self.null_keys {
86
table_match.push(*idx);
87
probe_match.push(key_idx);
88
}
89
if MARK_MATCHES && !self.nulls_emitted.load() {
90
self.nulls_emitted.store(true);
91
}
92
!self.null_keys.is_empty()
93
} else {
94
false
95
};
96
97
if EMIT_UNMATCHED && !found_match {
98
table_match.push(IdxSize::MAX);
99
probe_match.push(key_idx);
100
}
101
102
keys_processed += 1;
103
if table_match.len() >= limit as usize {
104
break;
105
}
106
}
107
keys_processed
108
}
109
110
#[allow(clippy::too_many_arguments)]
111
fn probe_dispatch(
112
&self,
113
keys: impl Iterator<Item = (IdxSize, Option<K>)>,
114
table_match: &mut Vec<IdxSize>,
115
probe_match: &mut Vec<IdxSize>,
116
mark_matches: bool,
117
emit_unmatched: bool,
118
null_is_valid: bool,
119
limit: IdxSize,
120
) -> IdxSize {
121
match (mark_matches, emit_unmatched, null_is_valid) {
122
(false, false, false) => {
123
self.probe_impl::<false, false, false>(keys, table_match, probe_match, limit)
124
},
125
(false, false, true) => {
126
self.probe_impl::<false, false, true>(keys, table_match, probe_match, limit)
127
},
128
(false, true, false) => {
129
self.probe_impl::<false, true, false>(keys, table_match, probe_match, limit)
130
},
131
(false, true, true) => {
132
self.probe_impl::<false, true, true>(keys, table_match, probe_match, limit)
133
},
134
(true, false, false) => {
135
self.probe_impl::<true, false, false>(keys, table_match, probe_match, limit)
136
},
137
(true, false, true) => {
138
self.probe_impl::<true, false, true>(keys, table_match, probe_match, limit)
139
},
140
(true, true, false) => {
141
self.probe_impl::<true, true, false>(keys, table_match, probe_match, limit)
142
},
143
(true, true, true) => {
144
self.probe_impl::<true, true, true>(keys, table_match, probe_match, limit)
145
},
146
}
147
}
148
}
149
150
impl<T, K> IdxTable for SingleKeyIdxTable<T>
151
where
152
for<'a> T: PolarsDataType<Physical<'a> = K>,
153
K: TotalHash + TotalEq + Send + Sync + 'static,
154
{
155
fn new_empty(&self) -> Box<dyn IdxTable> {
156
Box::new(Self::new())
157
}
158
159
fn reserve(&mut self, additional: usize) {
160
self.idx_map.reserve(additional);
161
}
162
163
fn num_keys(&self) -> IdxSize {
164
self.idx_map.len()
165
}
166
167
fn insert_keys(&mut self, _hash_keys: &HashKeys, _track_unmatchable: bool) {
168
// Isn't needed anymore, but also don't want to remove the code from the other implementations.
169
unimplemented!()
170
}
171
172
unsafe fn insert_keys_subset(
173
&mut self,
174
hash_keys: &HashKeys,
175
subset: &[IdxSize],
176
track_unmatchable: bool,
177
) {
178
let HashKeys::Single(hash_keys) = hash_keys else {
179
unreachable!()
180
};
181
let new_idx_offset = (self.idx_offset as usize)
182
.checked_add(subset.len())
183
.unwrap();
184
assert!(
185
new_idx_offset < IdxSize::MAX as usize,
186
"overly large index in SingleKeyIdxTable"
187
);
188
189
let keys: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();
190
for (i, subset_idx) in subset.iter().enumerate_idx() {
191
let key = unsafe { keys.get_unchecked(*subset_idx as usize) };
192
let idx = self.idx_offset + i;
193
if let Some(key) = key {
194
match self.idx_map.entry(key) {
195
Entry::Occupied(o) => {
196
o.into_mut().push(RelaxedCell::from(idx as u64));
197
},
198
Entry::Vacant(v) => {
199
v.insert(unitvec![RelaxedCell::from(idx as u64)]);
200
},
201
}
202
} else if track_unmatchable | hash_keys.null_is_valid {
203
self.null_keys.push(idx);
204
}
205
}
206
207
self.idx_offset = new_idx_offset as IdxSize;
208
}
209
210
fn probe(
211
&self,
212
_hash_keys: &HashKeys,
213
_table_match: &mut Vec<IdxSize>,
214
_probe_match: &mut Vec<IdxSize>,
215
_mark_matches: bool,
216
_emit_unmatched: bool,
217
_limit: IdxSize,
218
) -> IdxSize {
219
// Isn't needed anymore, but also don't want to remove the code from the other implementations.
220
unimplemented!()
221
}
222
223
unsafe fn probe_subset(
224
&self,
225
hash_keys: &HashKeys,
226
subset: &[IdxSize],
227
table_match: &mut Vec<IdxSize>,
228
probe_match: &mut Vec<IdxSize>,
229
mark_matches: bool,
230
emit_unmatched: bool,
231
limit: IdxSize,
232
) -> IdxSize {
233
let HashKeys::Single(hash_keys) = hash_keys else {
234
unreachable!()
235
};
236
237
let keys: &ChunkedArray<T> = hash_keys.keys.as_phys_any().downcast_ref().unwrap();
238
if keys.has_nulls() {
239
let iter = subset.iter().map(|i| (*i, keys.get_unchecked(*i as usize)));
240
self.probe_dispatch(
241
iter,
242
table_match,
243
probe_match,
244
mark_matches,
245
emit_unmatched,
246
hash_keys.null_is_valid,
247
limit,
248
)
249
} else {
250
let iter = subset
251
.iter()
252
.map(|i| (*i, Some(keys.value_unchecked(*i as usize))));
253
self.probe_dispatch(
254
iter,
255
table_match,
256
probe_match,
257
mark_matches,
258
emit_unmatched,
259
false, // Whether or not nulls are valid doesn't matter.
260
limit,
261
)
262
}
263
}
264
265
fn unmarked_keys(
266
&self,
267
out: &mut Vec<IdxSize>,
268
mut offset: IdxSize,
269
limit: IdxSize,
270
) -> IdxSize {
271
out.clear();
272
273
let mut keys_processed = 0;
274
if !self.nulls_emitted.load() {
275
if (offset as usize) < self.null_keys.len() {
276
out.extend(
277
self.null_keys[offset as usize..]
278
.iter()
279
.copied()
280
.take(limit as usize),
281
);
282
keys_processed += out.len() as IdxSize;
283
offset += out.len() as IdxSize;
284
if out.len() >= limit as usize {
285
return keys_processed;
286
}
287
}
288
offset -= self.null_keys.len() as IdxSize;
289
}
290
291
while let Some((_, idxs)) = self.idx_map.get_index(offset) {
292
let first_idx = unsafe { idxs.get_unchecked(0) };
293
let first_idx_val = first_idx.load();
294
if first_idx_val >> 63 == 0 {
295
for idx in &idxs[..] {
296
out.push((idx.load() & !(1 << 63)) as IdxSize);
297
}
298
}
299
300
keys_processed += 1;
301
offset += 1;
302
if out.len() >= limit as usize {
303
break;
304
}
305
}
306
307
keys_processed
308
}
309
}
310
311