Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/hash_keys.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::hash::BuildHasher;
3
4
use arrow::array::{Array, BinaryArray, BinaryViewArray, PrimitiveArray, StaticArray, UInt64Array};
5
use arrow::bitmap::Bitmap;
6
use arrow::compute::utils::combine_validities_and_many;
7
use polars_core::frame::DataFrame;
8
use polars_core::prelude::row_encode::_get_rows_encoded_unordered;
9
use polars_core::prelude::{ChunkedArray, DataType, PlRandomState, PolarsDataType, *};
10
use polars_core::series::Series;
11
use polars_utils::IdxSize;
12
use polars_utils::cardinality_sketch::CardinalitySketch;
13
use polars_utils::hashing::HashPartitioner;
14
use polars_utils::itertools::Itertools;
15
use polars_utils::total_ord::{BuildHasherTotalExt, TotalHash};
16
use polars_utils::vec::PushUnchecked;
17
18
/// Discriminates which key representation [`HashKeys`] uses for a set of key
/// columns; the mapping from dtype is [`hash_keys_variant_for_dtype`].
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum HashKeysVariant {
    /// One or more columns row-encoded into a single binary key per row,
    /// with pre-computed hashes.
    RowEncoded,
    /// A single column with a primitive physical representation; hashes are
    /// computed on the fly (see `SingleKeys`).
    Single,
    /// A single string/binary column stored as a binary view array, with
    /// pre-computed hashes.
    Binview,
}
24
25
pub fn hash_keys_variant_for_dtype(dt: &DataType) -> HashKeysVariant {
26
match dt {
27
dt if dt.is_primitive_numeric() | dt.is_temporal() => HashKeysVariant::Single,
28
29
#[cfg(feature = "dtype-decimal")]
30
DataType::Decimal(_, _) => HashKeysVariant::Single,
31
#[cfg(feature = "dtype-categorical")]
32
DataType::Enum(_, _) | DataType::Categorical(_, _) => HashKeysVariant::Single,
33
34
DataType::String | DataType::Binary => HashKeysVariant::Binview,
35
36
// TODO: more efficient encoding for these.
37
DataType::Boolean | DataType::Null => HashKeysVariant::RowEncoded,
38
39
_ => HashKeysVariant::RowEncoded,
40
}
41
}
42
43
/// Dispatches on `$self.dtype()`, binds `$ca` to the dtype's *physical*
/// `ChunkedArray`, and expands `$body` with that binding.
///
/// Intended for columns classified as [`HashKeysVariant::Single`]; any other
/// dtype hits `unreachable!()`.
macro_rules! downcast_single_key_ca {
    (
        $self:expr, | $ca:ident | $($body:tt)*
    ) => {{
        #[allow(unused_imports)]
        use polars_core::datatypes::DataType::*;
        match $self.dtype() {
            #[cfg(feature = "dtype-i8")]
            DataType::Int8 => { let $ca = $self.i8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i16")]
            DataType::Int16 => { let $ca = $self.i16().unwrap(); $($body)* },
            DataType::Int32 => { let $ca = $self.i32().unwrap(); $($body)* },
            DataType::Int64 => { let $ca = $self.i64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u8")]
            DataType::UInt8 => { let $ca = $self.u8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u16")]
            DataType::UInt16 => { let $ca = $self.u16().unwrap(); $($body)* },
            DataType::UInt32 => { let $ca = $self.u32().unwrap(); $($body)* },
            DataType::UInt64 => { let $ca = $self.u64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i128")]
            DataType::Int128 => { let $ca = $self.i128().unwrap(); $($body)* },
            DataType::Float32 => { let $ca = $self.f32().unwrap(); $($body)* },
            DataType::Float64 => { let $ca = $self.f64().unwrap(); $($body)* },

            // Temporal types hash via their physical integer representation.
            #[cfg(feature = "dtype-date")]
            DataType::Date => { let $ca = $self.date().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-time")]
            DataType::Time => { let $ca = $self.time().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-datetime")]
            DataType::Datetime(..) => { let $ca = $self.datetime().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-duration")]
            DataType::Duration(..) => { let $ca = $self.duration().unwrap().physical(); $($body)* },

            #[cfg(feature = "dtype-decimal")]
            DataType::Decimal(..) => { let $ca = $self.decimal().unwrap().physical(); $($body)* },
            // Categoricals dispatch further on the width of their physical keys.
            #[cfg(feature = "dtype-categorical")]
            dt @ (DataType::Enum(_, _) | DataType::Categorical(_, _)) => {
                match dt.cat_physical().unwrap() {
                    CategoricalPhysical::U8 => { let $ca = $self.cat8().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U16 => { let $ca = $self.cat16().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U32 => { let $ca = $self.cat32().unwrap().physical(); $($body)* },
                }
            },

            // Callers must only pass dtypes mapped to HashKeysVariant::Single.
            _ => unreachable!(),
        }
    }}
}
pub(crate) use downcast_single_key_ca;
92
93
/// Represents a DataFrame plus a hash per row, intended for keys in grouping
/// or joining. The hashes may or may not actually be physically pre-computed,
/// this depends per type.
#[derive(Clone, Debug)]
pub enum HashKeys {
    /// Key columns row-encoded into one binary key per row; pre-hashed.
    RowEncoded(RowEncodedKeys),
    /// A single string/binary key column; pre-hashed.
    Binview(BinviewKeys),
    /// A single physically-primitive key column; hashed on the fly.
    Single(SingleKeys),
}
102
103
impl HashKeys {
    /// Builds `HashKeys` for the key columns of `df`.
    ///
    /// Single-column frames whose dtype supports it get the cheaper `Single`
    /// or `Binview` variants; multi-column frames, unsupported dtypes, or
    /// `force_row_encoding` fall back to row-encoding everything into one
    /// binary key per row.
    ///
    /// If `null_is_valid` is false, a row containing any null is marked
    /// invalid, and its hash is reported as `None` by `for_each_hash*`.
    ///
    /// NOTE(review): assumes `df` has at least one column — `df[0]` panics
    /// otherwise; confirm callers guarantee this.
    pub fn from_df(
        df: &DataFrame,
        random_state: PlRandomState,
        null_is_valid: bool,
        force_row_encoding: bool,
    ) -> Self {
        let first_col_variant = hash_keys_variant_for_dtype(df[0].dtype());
        let use_row_encoding = force_row_encoding
            || df.width() > 1
            || first_col_variant == HashKeysVariant::RowEncoded;
        if use_row_encoding {
            let keys = df.get_columns();
            let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array();

            if !null_is_valid {
                // A row is invalid if any key column is null there: AND-combine
                // the per-column validities onto the encoded key array.
                let validities = keys
                    .iter()
                    .map(|c| c.as_materialized_series().rechunk_validity())
                    .collect_vec();
                let combined = combine_validities_and_many(&validities);
                keys_encoded.set_validity(combined);
            }

            // TODO: use vechash? Not supported yet for lists.
            // let mut hashes = Vec::with_capacity(df.height());
            // columns_to_hashes(df.get_columns(), Some(random_state), &mut hashes).unwrap();

            // Hash every encoded row, including invalid ones; the validity
            // bitmap decides later whether a hash is surfaced.
            let hashes = keys_encoded
                .values_iter()
                .map(|k| random_state.hash_one(k))
                .collect();
            Self::RowEncoded(RowEncodedKeys {
                hashes: PrimitiveArray::from_vec(hashes),
                keys: keys_encoded,
            })
        } else if first_col_variant == HashKeysVariant::Binview {
            // View strings as binary so both dtypes share one code path.
            let keys = if let Ok(ca_str) = df[0].str() {
                ca_str.as_binary()
            } else {
                df[0].binary().unwrap().clone()
            };
            let keys = keys.rechunk().downcast_as_array().clone();

            let hashes = if keys.has_nulls() {
                // Null slots get hash 0; the keys' validity marks them null.
                keys.iter()
                    .map(|opt_k| opt_k.map(|k| random_state.hash_one(k)).unwrap_or(0))
                    .collect()
            } else {
                keys.values_iter()
                    .map(|k| random_state.hash_one(k))
                    .collect()
            };

            Self::Binview(BinviewKeys {
                hashes: PrimitiveArray::from_vec(hashes),
                keys,
                null_is_valid,
            })
        } else {
            // Single physical-primitive column; hashes are computed lazily
            // per use by SingleKeys, so only rechunk here.
            Self::Single(SingleKeys {
                random_state,
                keys: df[0].as_materialized_series().rechunk(),
                null_is_valid,
            })
        }
    }

    /// Number of rows (keys).
    pub fn len(&self) -> usize {
        match self {
            HashKeys::RowEncoded(s) => s.keys.len(),
            HashKeys::Single(s) => s.keys.len(),
            HashKeys::Binview(s) => s.keys.len(),
        }
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Validity bitmap of the keys, if any rows are null.
    ///
    /// NOTE(review): for `Single` this reads `chunks()[0]` only — relies on
    /// the keys being a single chunk (rechunked in `from_df`); confirm
    /// gathered keys uphold this too.
    pub fn validity(&self) -> Option<&Bitmap> {
        match self {
            HashKeys::RowEncoded(s) => s.keys.validity(),
            HashKeys::Single(s) => s.keys.chunks()[0].validity(),
            HashKeys::Binview(s) => s.keys.validity(),
        }
    }

    /// Whether null keys should participate as ordinary keys.
    ///
    /// Row-encoded keys always report false: when nulls are valid `from_df`
    /// leaves the encoded array fully valid, so nothing is null there.
    pub fn null_is_valid(&self) -> bool {
        match self {
            HashKeys::RowEncoded(_) => false,
            HashKeys::Single(s) => s.null_is_valid,
            HashKeys::Binview(s) => s.null_is_valid,
        }
    }

    /// Calls f with the index of and hash of each element in this HashKeys.
    ///
    /// If the element is null and null_is_valid is false the respective hash
    /// will be None.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        match self {
            HashKeys::RowEncoded(s) => s.for_each_hash(f),
            HashKeys::Single(s) => s.for_each_hash(f),
            HashKeys::Binview(s) => s.for_each_hash(f),
        }
    }

    /// Calls f with the index of and hash of each element in the given
    /// subset of indices of the HashKeys.
    ///
    /// If the element is null and null_is_valid is false the respective hash
    /// will be None.
    ///
    /// # Safety
    /// The indices in the subset must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        match self {
            HashKeys::RowEncoded(s) => s.for_each_hash_subset(subset, f),
            HashKeys::Single(s) => s.for_each_hash_subset(subset, f),
            HashKeys::Binview(s) => s.for_each_hash_subset(subset, f),
        }
    }

    /// After this call partitions will be extended with the partition for each
    /// hash. Nulls are assigned IdxSize::MAX or a specific partition depending
    /// on whether partition_nulls is true.
    pub fn gen_partitions(
        &self,
        partitioner: &HashPartitioner,
        partitions: &mut Vec<IdxSize>,
        partition_nulls: bool,
    ) {
        unsafe {
            // Nulls that should participate go to the dedicated null
            // partition; otherwise they are marked with the IdxSize::MAX
            // sentinel.
            let null_p = if partition_nulls | self.null_is_valid() {
                partitioner.null_partition() as IdxSize
            } else {
                IdxSize::MAX
            };
            partitions.reserve(self.len());
            self.for_each_hash(|_idx, opt_h| {
                // SAFETY: we reserved self.len() slots above and for_each_hash
                // yields exactly self.len() elements.
                partitions.push_unchecked(
                    opt_h
                        .map(|h| partitioner.hash_to_partition(h) as IdxSize)
                        .unwrap_or(null_p),
                );
            });
        }
    }

    /// After this call partition_idxs[p] will be extended with the indices of
    /// hashes that belong to partition p, and the cardinality sketches are
    /// updated accordingly.
    pub fn gen_idxs_per_partition(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        // Monomorphize on whether sketches are built so the hot loop pays
        // nothing for the unused case.
        if sketches.is_empty() {
            self.gen_idxs_per_partition_impl::<false>(
                partitioner,
                partition_idxs,
                sketches,
                partition_nulls | self.null_is_valid(),
            );
        } else {
            self.gen_idxs_per_partition_impl::<true>(
                partitioner,
                partition_idxs,
                sketches,
                partition_nulls | self.null_is_valid(),
            );
        }
    }

    /// Shared implementation of [`HashKeys::gen_idxs_per_partition`];
    /// `BUILD_SKETCHES` selects whether cardinality sketches are updated.
    fn gen_idxs_per_partition_impl<const BUILD_SKETCHES: bool>(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        assert!(partition_idxs.len() == partitioner.num_partitions());
        assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions());

        let null_p = partitioner.null_partition();
        self.for_each_hash(|idx, opt_h| {
            if let Some(h) = opt_h {
                unsafe {
                    // SAFETY: we assured the number of partitions matches.
                    let p = partitioner.hash_to_partition(h);
                    partition_idxs.get_unchecked_mut(p).push(idx);
                    if BUILD_SKETCHES {
                        sketches.get_unchecked_mut(p).insert(h);
                    }
                }
            } else if partition_nulls {
                unsafe {
                    partition_idxs.get_unchecked_mut(null_p).push(idx);
                }
            }
        });
    }

    /// Feeds every row's hash into the cardinality sketch (null rows as 0).
    pub fn sketch_cardinality(&self, sketch: &mut CardinalitySketch) {
        self.for_each_hash(|_idx, opt_h| {
            sketch.insert(opt_h.unwrap_or(0));
        })
    }

    /// Returns a new `HashKeys` with the rows at `idxs`, in that order.
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        match self {
            HashKeys::RowEncoded(s) => Self::RowEncoded(s.gather_unchecked(idxs)),
            HashKeys::Single(s) => Self::Single(s.gather_unchecked(idxs)),
            HashKeys::Binview(s) => Self::Binview(s.gather_unchecked(idxs)),
        }
    }
}
329
330
/// Keys row-encoded into a binary array, with one pre-computed hash per row.
#[derive(Clone, Debug)]
pub struct RowEncodedKeys {
    pub hashes: UInt64Array, // Always non-null, we use the validity of keys.
    pub keys: BinaryArray<i64>,
}
335
336
impl RowEncodedKeys {
    /// See [`HashKeys::for_each_hash`]; uses the pre-computed hashes and the
    /// keys' validity.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
    }

    /// See [`HashKeys::for_each_hash_subset`].
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        for_each_hash_subset_prehashed(
            self.hashes.values().as_slice(),
            self.keys.validity(),
            subset,
            f,
        );
    }

    /// Gathers both hashes and encoded keys at `idxs`.
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        // Zero-copy view of idxs as an arrow index array for the take kernels.
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        Self {
            hashes: polars_compute::gather::primitive::take_primitive_unchecked(
                &self.hashes,
                &idx_arr,
            ),
            keys: polars_compute::gather::binary::take_unchecked(&self.keys, &idx_arr),
        }
    }
}
369
370
/// Single keys without prehashing.
#[derive(Clone, Debug)]
pub struct SingleKeys {
    // Hasher used to compute hashes on the fly in for_each_hash*.
    pub random_state: PlRandomState,
    // NOTE(review): expected to be a single chunk (rechunked in
    // HashKeys::from_df); HashKeys::validity reads chunks()[0] — confirm.
    pub keys: Series,
    pub null_is_valid: bool,
}
377
378
impl SingleKeys {
    /// See [`HashKeys::for_each_hash`]; hashes are computed on the fly from
    /// the physical values.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        downcast_single_key_ca!(self.keys, |keys| {
            for_each_hash_single(keys, &self.random_state, f);
        })
    }

    /// See [`HashKeys::for_each_hash_subset`].
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        downcast_single_key_ca!(self.keys, |keys| {
            for_each_hash_subset_single(keys, subset, &self.random_state, f);
        })
    }

    /// Gathers the keys at `idxs`; the hasher and null policy carry over.
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        Self {
            random_state: self.random_state,
            keys: self.keys.take_slice_unchecked(idxs),
            null_is_valid: self.null_is_valid,
        }
    }
}
407
408
/// Pre-hashed binary-view keys (strings are viewed as binary).
#[derive(Clone, Debug)]
pub struct BinviewKeys {
    // Hash per row; null rows get hash 0 in from_df, nullity is tracked by
    // the keys' validity bitmap.
    pub hashes: UInt64Array,
    pub keys: BinaryViewArray,
    pub null_is_valid: bool,
}
415
416
impl BinviewKeys {
    /// See [`HashKeys::for_each_hash`]; uses the pre-computed hashes and the
    /// keys' validity.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
    }

    /// See [`HashKeys::for_each_hash_subset`].
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        for_each_hash_subset_prehashed(
            self.hashes.values().as_slice(),
            self.keys.validity(),
            subset,
            f,
        );
    }

    /// Gathers both hashes and keys at `idxs`.
    ///
    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        // Zero-copy view of idxs as an arrow index array for the take kernels.
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        Self {
            hashes: polars_compute::gather::primitive::take_primitive_unchecked(
                &self.hashes,
                &idx_arr,
            ),
            keys: polars_compute::gather::binview::take_binview_unchecked(&self.keys, &idx_arr),
            null_is_valid: self.null_is_valid,
        }
    }
}
450
451
fn for_each_hash_prehashed<F: FnMut(IdxSize, Option<u64>)>(
452
hashes: &[u64],
453
opt_v: Option<&Bitmap>,
454
mut f: F,
455
) {
456
if let Some(validity) = opt_v {
457
for (idx, (is_v, hash)) in validity.iter().zip(hashes).enumerate_idx() {
458
if is_v {
459
f(idx, Some(*hash))
460
} else {
461
f(idx, None)
462
}
463
}
464
} else {
465
for (idx, h) in hashes.iter().enumerate_idx() {
466
f(idx, Some(*h));
467
}
468
}
469
}
470
471
/// # Safety
472
/// The indices must be in-bounds.
473
unsafe fn for_each_hash_subset_prehashed<F: FnMut(IdxSize, Option<u64>)>(
474
hashes: &[u64],
475
opt_v: Option<&Bitmap>,
476
subset: &[IdxSize],
477
mut f: F,
478
) {
479
if let Some(validity) = opt_v {
480
for idx in subset {
481
let hash = *hashes.get_unchecked(*idx as usize);
482
let is_v = validity.get_bit_unchecked(*idx as usize);
483
if is_v {
484
f(*idx, Some(hash))
485
} else {
486
f(*idx, None)
487
}
488
}
489
} else {
490
for idx in subset {
491
f(*idx, Some(*hashes.get_unchecked(*idx as usize)));
492
}
493
}
494
}
495
496
pub fn for_each_hash_single<T, F>(keys: &ChunkedArray<T>, random_state: &PlRandomState, mut f: F)
497
where
498
T: PolarsDataType,
499
for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
500
F: FnMut(IdxSize, Option<u64>),
501
{
502
let mut idx = 0;
503
if keys.has_nulls() {
504
for arr in keys.downcast_iter() {
505
for opt_k in arr.iter() {
506
f(idx, opt_k.map(|k| random_state.tot_hash_one(k)));
507
idx += 1;
508
}
509
}
510
} else {
511
for arr in keys.downcast_iter() {
512
for k in arr.values_iter() {
513
f(idx, Some(random_state.tot_hash_one(k)));
514
idx += 1;
515
}
516
}
517
}
518
}
519
520
/// # Safety
521
/// The indices must be in-bounds.
522
unsafe fn for_each_hash_subset_single<T, F>(
523
keys: &ChunkedArray<T>,
524
subset: &[IdxSize],
525
random_state: &PlRandomState,
526
mut f: F,
527
) where
528
T: PolarsDataType,
529
for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
530
F: FnMut(IdxSize, Option<u64>),
531
{
532
let keys_arr = keys.downcast_as_array();
533
534
if keys_arr.has_nulls() {
535
for idx in subset {
536
let opt_k = keys_arr.get_unchecked(*idx as usize);
537
f(*idx, opt_k.map(|k| random_state.tot_hash_one(k)));
538
}
539
} else {
540
for idx in subset {
541
let k = keys_arr.value_unchecked(*idx as usize);
542
f(*idx, Some(random_state.tot_hash_one(k)));
543
}
544
}
545
}
546
547