Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/hash_keys.rs
8418 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::hash::BuildHasher;
3
4
use arrow::array::{Array, BinaryArray, BinaryViewArray, PrimitiveArray, StaticArray, UInt64Array};
5
use arrow::bitmap::Bitmap;
6
use arrow::compute::utils::combine_validities_and_many;
7
use polars_core::frame::DataFrame;
8
use polars_core::prelude::row_encode::_get_rows_encoded_unordered;
9
use polars_core::prelude::{ChunkedArray, DataType, PlRandomState, PolarsDataType, *};
10
use polars_core::series::Series;
11
use polars_utils::IdxSize;
12
use polars_utils::cardinality_sketch::CardinalitySketch;
13
use polars_utils::hashing::HashPartitioner;
14
use polars_utils::itertools::Itertools;
15
use polars_utils::total_ord::{BuildHasherTotalExt, TotalHash};
16
use polars_utils::vec::PushUnchecked;
17
18
/// The physical strategy used to represent and hash a set of key columns.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum HashKeysVariant {
    /// Keys are row-encoded into a single binary column (used for multi-column
    /// keys and types without a cheaper representation).
    RowEncoded,
    /// A single fixed-width physical column; hashes are computed on the fly
    /// rather than precomputed.
    Single,
    /// A single string/binary column stored as a binary view array with
    /// precomputed hashes.
    Binview,
}
24
25
pub fn hash_keys_variant_for_dtype(dt: &DataType) -> HashKeysVariant {
26
match dt {
27
dt if dt.is_primitive_numeric() | dt.is_temporal() => HashKeysVariant::Single,
28
29
#[cfg(feature = "dtype-decimal")]
30
DataType::Decimal(_, _) => HashKeysVariant::Single,
31
#[cfg(feature = "dtype-categorical")]
32
DataType::Enum(_, _) | DataType::Categorical(_, _) => HashKeysVariant::Single,
33
34
DataType::String | DataType::Binary => HashKeysVariant::Binview,
35
36
// TODO: more efficient encoding for these.
37
DataType::Boolean | DataType::Null => HashKeysVariant::RowEncoded,
38
39
_ => HashKeysVariant::RowEncoded,
40
}
41
}
42
43
/// Downcasts `$self` (a `Series`-like value) to the physical `ChunkedArray`
/// matching its dtype, binds it to `$ca`, and evaluates `$body`.
///
/// Only dtypes classified as `HashKeysVariant::Single` by
/// `hash_keys_variant_for_dtype` are expected here; anything else hits the
/// `unreachable!()` arm.
macro_rules! downcast_single_key_ca {
    (
        $self:expr, | $ca:ident | $($body:tt)*
    ) => {{
        #[allow(unused_imports)]
        use polars_core::datatypes::DataType::*;
        match $self.dtype() {
            #[cfg(feature = "dtype-i8")]
            DataType::Int8 => { let $ca = $self.i8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i16")]
            DataType::Int16 => { let $ca = $self.i16().unwrap(); $($body)* },
            DataType::Int32 => { let $ca = $self.i32().unwrap(); $($body)* },
            DataType::Int64 => { let $ca = $self.i64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u8")]
            DataType::UInt8 => { let $ca = $self.u8().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u16")]
            DataType::UInt16 => { let $ca = $self.u16().unwrap(); $($body)* },
            DataType::UInt32 => { let $ca = $self.u32().unwrap(); $($body)* },
            DataType::UInt64 => { let $ca = $self.u64().unwrap(); $($body)* },
            #[cfg(feature = "dtype-i128")]
            DataType::Int128 => { let $ca = $self.i128().unwrap(); $($body)* },
            #[cfg(feature = "dtype-u128")]
            DataType::UInt128 => { let $ca = $self.u128().unwrap(); $($body)* },
            #[cfg(feature = "dtype-f16")]
            DataType::Float16 => { let $ca = $self.f16().unwrap(); $($body)* },
            DataType::Float32 => { let $ca = $self.f32().unwrap(); $($body)* },
            DataType::Float64 => { let $ca = $self.f64().unwrap(); $($body)* },

            // Temporal types are hashed through their physical representation.
            #[cfg(feature = "dtype-date")]
            DataType::Date => { let $ca = $self.date().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-time")]
            DataType::Time => { let $ca = $self.time().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-datetime")]
            DataType::Datetime(..) => { let $ca = $self.datetime().unwrap().physical(); $($body)* },
            #[cfg(feature = "dtype-duration")]
            DataType::Duration(..) => { let $ca = $self.duration().unwrap().physical(); $($body)* },

            #[cfg(feature = "dtype-decimal")]
            DataType::Decimal(..) => { let $ca = $self.decimal().unwrap().physical(); $($body)* },
            // Categoricals dispatch further on the width of their physical
            // key type (u8/u16/u32).
            #[cfg(feature = "dtype-categorical")]
            dt @ (DataType::Enum(_, _) | DataType::Categorical(_, _)) => {
                match dt.cat_physical().unwrap() {
                    CategoricalPhysical::U8 => { let $ca = $self.cat8().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U16 => { let $ca = $self.cat16().unwrap().physical(); $($body)* },
                    CategoricalPhysical::U32 => { let $ca = $self.cat32().unwrap().physical(); $($body)* },
                }
            },

            _ => unreachable!(),
        }
    }}
}
95
96
/// Represents a DataFrame plus a hash per row, intended for keys in grouping
/// or joining. The hashes may or may not actually be physically pre-computed,
/// this depends per type.
#[derive(Clone, Debug)]
pub enum HashKeys {
    /// Multi-column or nested keys, row-encoded with precomputed hashes.
    RowEncoded(RowEncodedKeys),
    /// String/binary keys with precomputed hashes.
    Binview(BinviewKeys),
    /// A single fixed-width column, hashed lazily.
    Single(SingleKeys),
}
105
106
impl HashKeys {
    /// Builds `HashKeys` from the key columns of `df`.
    ///
    /// Picks the representation from the dtype of the first column and the
    /// column count, unless `force_row_encoding` demands row encoding.
    /// `null_is_valid` decides whether null keys are matchable values or are
    /// reported as `None` by the `for_each_hash` family.
    ///
    /// NOTE(review): indexes `df[0]`, so `df` is assumed to have at least one
    /// column — confirm callers guarantee this.
    pub fn from_df(
        df: &DataFrame,
        random_state: PlRandomState,
        null_is_valid: bool,
        force_row_encoding: bool,
    ) -> Self {
        let first_col_variant = hash_keys_variant_for_dtype(df[0].dtype());
        // Multiple key columns always go through row encoding.
        let use_row_encoding = force_row_encoding
            || df.width() > 1
            || first_col_variant == HashKeysVariant::RowEncoded;
        if use_row_encoding {
            let keys = df.columns();
            let mut keys_encoded = _get_rows_encoded_unordered(keys).unwrap().into_array();

            // Row encoding encodes nulls as values; if nulls must not match,
            // mask out every row where at least one key column is null.
            if !null_is_valid {
                let validities = keys
                    .iter()
                    .map(|c| c.as_materialized_series().rechunk_validity())
                    .collect_vec();
                let combined = combine_validities_and_many(&validities);
                keys_encoded.set_validity(combined);
            }

            // TODO: use vechash? Not supported yet for lists.
            // let mut hashes = Vec::with_capacity(df.height());
            // columns_to_hashes(df.columns(), Some(random_state), &mut hashes).unwrap();

            // Hash the encoded bytes of every row (masked rows included; the
            // validity bitmap set above is what marks them as null).
            let hashes = keys_encoded
                .values_iter()
                .map(|k| random_state.hash_one(k))
                .collect();
            Self::RowEncoded(RowEncodedKeys {
                hashes: PrimitiveArray::from_vec(hashes),
                keys: keys_encoded,
            })
        } else if first_col_variant == HashKeysVariant::Binview {
            // Strings are reinterpreted as binary so both dtypes share a path.
            let keys = if let Ok(ca_str) = df[0].str() {
                ca_str.as_binary()
            } else {
                df[0].binary().unwrap().clone()
            };
            let keys = keys.rechunk().downcast_as_array().clone();

            // Null slots get an arbitrary hash (0); their validity bit decides
            // whether they are later reported as null.
            let hashes = if keys.has_nulls() {
                keys.iter()
                    .map(|opt_k| opt_k.map(|k| random_state.hash_one(k)).unwrap_or(0))
                    .collect()
            } else {
                keys.values_iter()
                    .map(|k| random_state.hash_one(k))
                    .collect()
            };

            Self::Binview(BinviewKeys {
                hashes: PrimitiveArray::from_vec(hashes),
                keys,
                null_is_valid,
            })
        } else {
            // Single fixed-width physical column: no precomputed hashes, the
            // random state is kept so hashes can be computed on demand.
            Self::Single(SingleKeys {
                random_state,
                keys: df[0].as_materialized_series().rechunk(),
                null_is_valid,
            })
        }
    }

    /// Number of keys (rows).
    pub fn len(&self) -> usize {
        match self {
            HashKeys::RowEncoded(s) => s.keys.len(),
            HashKeys::Single(s) => s.keys.len(),
            HashKeys::Binview(s) => s.keys.len(),
        }
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The validity bitmap of the keys, if any rows are masked as null.
    pub fn validity(&self) -> Option<&Bitmap> {
        match self {
            HashKeys::RowEncoded(s) => s.keys.validity(),
            // Single keys are rechunked on construction (see from_df /
            // gather_unchecked), so chunk 0 is the only chunk.
            HashKeys::Single(s) => s.keys.chunks()[0].validity(),
            HashKeys::Binview(s) => s.keys.validity(),
        }
    }

    /// Whether null keys should be treated as valid (matchable) values.
    ///
    /// Row-encoded keys always return false here: when nulls are valid they
    /// are already baked into the encoding and no rows are masked.
    pub fn null_is_valid(&self) -> bool {
        match self {
            HashKeys::RowEncoded(_) => false,
            HashKeys::Single(s) => s.null_is_valid,
            HashKeys::Binview(s) => s.null_is_valid,
        }
    }

    /// Calls f with the index of and hash of each element in this HashKeys.
    ///
    /// If the element is null and null_is_valid is false the respective hash
    /// will be None.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        match self {
            HashKeys::RowEncoded(s) => s.for_each_hash(f),
            HashKeys::Single(s) => s.for_each_hash(f),
            HashKeys::Binview(s) => s.for_each_hash(f),
        }
    }

    /// Calls f with the index of and hash of each element in the given
    /// subset of indices of the HashKeys.
    ///
    /// If the element is null and null_is_valid is false the respective hash
    /// will be None.
    ///
    /// # Safety
    /// The indices in the subset must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        match self {
            HashKeys::RowEncoded(s) => s.for_each_hash_subset(subset, f),
            HashKeys::Single(s) => s.for_each_hash_subset(subset, f),
            HashKeys::Binview(s) => s.for_each_hash_subset(subset, f),
        }
    }

    /// After this call partitions will be extended with the partition for each
    /// hash. Nulls are assigned IdxSize::MAX or a specific partition depending
    /// on whether partition_nulls is true.
    pub fn gen_partitions(
        &self,
        partitioner: &HashPartitioner,
        partitions: &mut Vec<IdxSize>,
        partition_nulls: bool,
    ) {
        unsafe {
            // Nulls that should participate go to the designated null
            // partition; otherwise they get the IdxSize::MAX sentinel.
            let null_p = if partition_nulls | self.null_is_valid() {
                partitioner.null_partition() as IdxSize
            } else {
                IdxSize::MAX
            };
            partitions.reserve(self.len());
            self.for_each_hash(|_idx, opt_h| {
                // SAFETY: we reserved self.len() slots above and for_each_hash
                // invokes the callback once per key.
                partitions.push_unchecked(
                    opt_h
                        .map(|h| partitioner.hash_to_partition(h) as IdxSize)
                        .unwrap_or(null_p),
                );
            });
        }
    }

    /// After this call partition_idxs[p] will be extended with the indices of
    /// hashes that belong to partition p, and the cardinality sketches are
    /// updated accordingly.
    pub fn gen_idxs_per_partition(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        // Monomorphize on whether sketches are maintained so the hot loop
        // has no per-element branch for it.
        if sketches.is_empty() {
            self.gen_idxs_per_partition_impl::<false>(
                partitioner,
                partition_idxs,
                sketches,
                partition_nulls | self.null_is_valid(),
            );
        } else {
            self.gen_idxs_per_partition_impl::<true>(
                partitioner,
                partition_idxs,
                sketches,
                partition_nulls | self.null_is_valid(),
            );
        }
    }

    fn gen_idxs_per_partition_impl<const BUILD_SKETCHES: bool>(
        &self,
        partitioner: &HashPartitioner,
        partition_idxs: &mut [Vec<IdxSize>],
        sketches: &mut [CardinalitySketch],
        partition_nulls: bool,
    ) {
        assert!(partition_idxs.len() == partitioner.num_partitions());
        assert!(!BUILD_SKETCHES || sketches.len() == partitioner.num_partitions());

        let null_p = partitioner.null_partition();
        self.for_each_hash(|idx, opt_h| {
            if let Some(h) = opt_h {
                unsafe {
                    // SAFETY: we assured the number of partitions matches.
                    let p = partitioner.hash_to_partition(h);
                    partition_idxs.get_unchecked_mut(p).push(idx);
                    if BUILD_SKETCHES {
                        sketches.get_unchecked_mut(p).insert(h);
                    }
                }
            } else if partition_nulls {
                unsafe {
                    // SAFETY: null_p comes from the same partitioner, so it
                    // is within partition_idxs per the assert above.
                    partition_idxs.get_unchecked_mut(null_p).push(idx);
                }
            }
        });
    }

    /// Inserts every key's hash into the given cardinality sketch.
    /// Null keys contribute the constant hash 0.
    pub fn sketch_cardinality(&self, sketch: &mut CardinalitySketch) {
        self.for_each_hash(|_idx, opt_h| {
            sketch.insert(opt_h.unwrap_or(0));
        })
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        match self {
            HashKeys::RowEncoded(s) => Self::RowEncoded(s.gather_unchecked(idxs)),
            HashKeys::Single(s) => Self::Single(s.gather_unchecked(idxs)),
            HashKeys::Binview(s) => Self::Binview(s.gather_unchecked(idxs)),
        }
    }
}
332
333
/// Keys row-encoded into a single binary column, with precomputed hashes.
#[derive(Clone, Debug)]
pub struct RowEncodedKeys {
    /// One hash per row. Always non-null, we use the validity of keys.
    pub hashes: UInt64Array,
    /// Row-encoded key bytes; its validity marks masked-out null rows.
    pub keys: BinaryArray<i64>,
}
338
339
impl RowEncodedKeys {
    /// Calls `f` with the index and precomputed hash of every key;
    /// masked (null) rows get `None`.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        for_each_hash_subset_prehashed(
            self.hashes.values().as_slice(),
            self.keys.validity(),
            subset,
            f,
        );
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        // Zero-copy view of idxs as an arrow array so the take kernels can
        // be reused without copying the index buffer.
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        Self {
            hashes: polars_compute::gather::primitive::take_primitive_unchecked(
                &self.hashes,
                &idx_arr,
            ),
            keys: polars_compute::gather::binary::take_unchecked(&self.keys, &idx_arr),
        }
    }
}
372
373
/// Single keys without prehashing.
#[derive(Clone, Debug)]
pub struct SingleKeys {
    /// Hasher state kept so hashes can be computed lazily per traversal.
    pub random_state: PlRandomState,
    /// The key column; rechunked to a single chunk on construction.
    pub keys: Series,
    /// Whether null keys count as valid (matchable) values.
    pub null_is_valid: bool,
}
380
381
impl SingleKeys {
    /// Calls `f` with the index and hash of every key, hashing on the fly
    /// through the dtype-specific physical ChunkedArray.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        downcast_single_key_ca!(self.keys, |keys| {
            for_each_hash_single(keys, &self.random_state, f);
        })
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        downcast_single_key_ca!(self.keys, |keys| {
            for_each_hash_subset_single(keys, subset, &self.random_state, f);
        })
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        Self {
            // Cloning the random state keeps hashes consistent with self.
            random_state: self.random_state.clone(),
            keys: self.keys.take_slice_unchecked(idxs),
            null_is_valid: self.null_is_valid,
        }
    }
}
410
411
/// Binary view keys with precomputed hashes.
#[derive(Clone, Debug)]
pub struct BinviewKeys {
    /// One hash per row; null rows hold an arbitrary hash (0) and are
    /// distinguished via the validity of `keys`.
    pub hashes: UInt64Array,
    /// The key values (strings are stored here as binary).
    pub keys: BinaryViewArray,
    /// Whether null keys count as valid (matchable) values.
    pub null_is_valid: bool,
}
418
419
impl BinviewKeys {
    /// Calls `f` with the index and precomputed hash of every key;
    /// null rows get `None`.
    pub fn for_each_hash<F: FnMut(IdxSize, Option<u64>)>(&self, f: F) {
        for_each_hash_prehashed(self.hashes.values().as_slice(), self.keys.validity(), f);
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn for_each_hash_subset<F: FnMut(IdxSize, Option<u64>)>(
        &self,
        subset: &[IdxSize],
        f: F,
    ) {
        for_each_hash_subset_prehashed(
            self.hashes.values().as_slice(),
            self.keys.validity(),
            subset,
            f,
        );
    }

    /// # Safety
    /// The indices must be in-bounds.
    pub unsafe fn gather_unchecked(&self, idxs: &[IdxSize]) -> Self {
        // Zero-copy view of idxs as an arrow array for the take kernels.
        let idx_arr = arrow::ffi::mmap::slice(idxs);
        Self {
            hashes: polars_compute::gather::primitive::take_primitive_unchecked(
                &self.hashes,
                &idx_arr,
            ),
            keys: polars_compute::gather::binview::take_binview_unchecked(&self.keys, &idx_arr),
            null_is_valid: self.null_is_valid,
        }
    }
}
453
454
fn for_each_hash_prehashed<F: FnMut(IdxSize, Option<u64>)>(
455
hashes: &[u64],
456
opt_v: Option<&Bitmap>,
457
mut f: F,
458
) {
459
if let Some(validity) = opt_v {
460
for (idx, (is_v, hash)) in validity.iter().zip(hashes).enumerate_idx() {
461
if is_v {
462
f(idx, Some(*hash))
463
} else {
464
f(idx, None)
465
}
466
}
467
} else {
468
for (idx, h) in hashes.iter().enumerate_idx() {
469
f(idx, Some(*h));
470
}
471
}
472
}
473
474
/// # Safety
475
/// The indices must be in-bounds.
476
unsafe fn for_each_hash_subset_prehashed<F: FnMut(IdxSize, Option<u64>)>(
477
hashes: &[u64],
478
opt_v: Option<&Bitmap>,
479
subset: &[IdxSize],
480
mut f: F,
481
) {
482
if let Some(validity) = opt_v {
483
for idx in subset {
484
let hash = *hashes.get_unchecked(*idx as usize);
485
let is_v = validity.get_bit_unchecked(*idx as usize);
486
if is_v {
487
f(*idx, Some(hash))
488
} else {
489
f(*idx, None)
490
}
491
}
492
} else {
493
for idx in subset {
494
f(*idx, Some(*hashes.get_unchecked(*idx as usize)));
495
}
496
}
497
}
498
499
pub fn for_each_hash_single<T, F>(keys: &ChunkedArray<T>, random_state: &PlRandomState, mut f: F)
500
where
501
T: PolarsDataType,
502
for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
503
F: FnMut(IdxSize, Option<u64>),
504
{
505
let mut idx = 0;
506
if keys.has_nulls() {
507
for arr in keys.downcast_iter() {
508
for opt_k in arr.iter() {
509
f(idx, opt_k.map(|k| random_state.tot_hash_one(k)));
510
idx += 1;
511
}
512
}
513
} else {
514
for arr in keys.downcast_iter() {
515
for k in arr.values_iter() {
516
f(idx, Some(random_state.tot_hash_one(k)));
517
idx += 1;
518
}
519
}
520
}
521
}
522
523
/// # Safety
524
/// The indices must be in-bounds.
525
unsafe fn for_each_hash_subset_single<T, F>(
526
keys: &ChunkedArray<T>,
527
subset: &[IdxSize],
528
random_state: &PlRandomState,
529
mut f: F,
530
) where
531
T: PolarsDataType,
532
for<'a> <T as PolarsDataType>::Physical<'a>: TotalHash,
533
F: FnMut(IdxSize, Option<u64>),
534
{
535
let keys_arr = keys.downcast_as_array();
536
537
if keys_arr.has_nulls() {
538
for idx in subset {
539
let opt_k = keys_arr.get_unchecked(*idx as usize);
540
f(*idx, opt_k.map(|k| random_state.tot_hash_one(k)));
541
}
542
} else {
543
for idx in subset {
544
let k = keys_arr.value_unchecked(*idx as usize);
545
f(*idx, Some(random_state.tot_hash_one(k)));
546
}
547
}
548
}
549
550