Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-parquet/src/arrow/write/dictionary.rs
6940 views
1
use arrow::array::{
2
Array, BinaryViewArray, DictionaryArray, DictionaryKey, PrimitiveArray, Utf8ViewArray,
3
};
4
use arrow::bitmap::{Bitmap, MutableBitmap};
5
use arrow::buffer::Buffer;
6
use arrow::datatypes::{ArrowDataType, IntegerType, PhysicalType};
7
use arrow::legacy::utils::CustomIterTools;
8
use arrow::trusted_len::TrustMyLength;
9
use arrow::types::NativeType;
10
use polars_compute::min_max::MinMaxKernel;
11
use polars_error::{PolarsResult, polars_bail};
12
13
use super::binary::{
14
build_statistics as binary_build_statistics, encode_plain as binary_encode_plain,
15
};
16
use super::fixed_size_binary::{
17
build_statistics as fixed_binary_build_statistics, encode_plain as fixed_binary_encode_plain,
18
};
19
use super::pages::PrimitiveNested;
20
use super::primitive::{
21
build_statistics as primitive_build_statistics, encode_plain as primitive_encode_plain,
22
};
23
use super::{EncodeNullability, Nested, WriteOptions, binview, nested};
24
use crate::arrow::read::schema::is_nullable;
25
use crate::arrow::write::{slice_nested_leaf, utils};
26
use crate::parquet::CowBuffer;
27
use crate::parquet::encoding::Encoding;
28
use crate::parquet::encoding::hybrid_rle::encode;
29
use crate::parquet::page::{DictPage, Page};
30
use crate::parquet::schema::types::PrimitiveType;
31
use crate::parquet::statistics::ParquetStatistics;
32
use crate::write::DynIter;
33
34
/// Per-integer-type tuning knobs for the min/max dictionary-encoding fast path.
trait MinMaxThreshold {
    /// Distinct-value count at or below which dictionary encoding is considered worth it.
    const DELTA_THRESHOLD: usize;
    /// Largest `max - min` range for which the fast path builds a seen-values bitmap.
    const BITMASK_THRESHOLD: usize;

    /// Returns `start` advanced by `offset`.
    fn from_start_and_offset(start: Self, offset: usize) -> Self;
}

/// Implements [`MinMaxThreshold`] for each signed/unsigned integer pair, giving both types of a
/// pair the same thresholds.
macro_rules! minmaxthreshold_impls {
    ($($signed:ty, $unsigned:ty => $threshold:literal, $bm_threshold:expr,)+) => {
        $(
            impl MinMaxThreshold for $unsigned {
                const DELTA_THRESHOLD: usize = $threshold;
                const BITMASK_THRESHOLD: usize = $bm_threshold;

                fn from_start_and_offset(start: Self, offset: usize) -> Self {
                    let step = offset as $unsigned;
                    start + step
                }
            }

            impl MinMaxThreshold for $signed {
                const DELTA_THRESHOLD: usize = $threshold;
                const BITMASK_THRESHOLD: usize = $bm_threshold;

                fn from_start_and_offset(start: Self, offset: usize) -> Self {
                    // Narrow to the unsigned width first, then reinterpret as signed.
                    let step = (offset as $unsigned) as $signed;
                    start + step
                }
            }
        )+
    };
}

minmaxthreshold_impls! {
    // signed, unsigned => DELTA_THRESHOLD, BITMASK_THRESHOLD,
    i8, u8 => 16, u8::MAX as usize,
    i16, u16 => 256, u16::MAX as usize,
    i32, u32 => 512, u16::MAX as usize,
    i64, u64 => 2048, u16::MAX as usize,
}
70
71
enum DictionaryDecision {
72
NotWorth,
73
TryAgain,
74
Found(DictionaryArray<u32>),
75
}
76
77
fn min_max_integer_encode_as_dictionary_optional<'a, E, T>(
78
array: &'a dyn Array,
79
) -> DictionaryDecision
80
where
81
E: std::fmt::Debug,
82
T: NativeType
83
+ MinMaxThreshold
84
+ std::cmp::Ord
85
+ TryInto<u32, Error = E>
86
+ std::ops::Sub<T, Output = T>
87
+ num_traits::CheckedSub
88
+ num_traits::cast::AsPrimitive<usize>,
89
std::ops::RangeInclusive<T>: Iterator<Item = T>,
90
PrimitiveArray<T>: MinMaxKernel<Scalar<'a> = T>,
91
{
92
let min_max = <PrimitiveArray<T> as MinMaxKernel>::min_max_ignore_nan_kernel(
93
array.as_any().downcast_ref().unwrap(),
94
);
95
96
let Some((min, max)) = min_max else {
97
return DictionaryDecision::TryAgain;
98
};
99
100
debug_assert!(max >= min, "{max} >= {min}");
101
let Some(diff) = max.checked_sub(&min) else {
102
return DictionaryDecision::TryAgain;
103
};
104
105
let diff = diff.as_();
106
107
if diff > T::BITMASK_THRESHOLD {
108
return DictionaryDecision::TryAgain;
109
}
110
111
let mut seen_mask = MutableBitmap::from_len_zeroed(diff + 1);
112
113
let array = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
114
115
if array.has_nulls() {
116
for v in array.non_null_values_iter() {
117
let offset = (v - min).as_();
118
debug_assert!(offset <= diff);
119
120
unsafe {
121
seen_mask.set_unchecked(offset, true);
122
}
123
}
124
} else {
125
for v in array.values_iter() {
126
let offset = (*v - min).as_();
127
debug_assert!(offset <= diff);
128
129
unsafe {
130
seen_mask.set_unchecked(offset, true);
131
}
132
}
133
}
134
135
let cardinality = seen_mask.set_bits();
136
137
let mut is_worth_it = false;
138
139
is_worth_it |= cardinality <= T::DELTA_THRESHOLD;
140
is_worth_it |= (cardinality as f64) / (array.len() as f64) < 0.75;
141
142
if !is_worth_it {
143
return DictionaryDecision::NotWorth;
144
}
145
146
let seen_mask = seen_mask.freeze();
147
148
// SAFETY: We just did the calculation for this.
149
let indexes = seen_mask
150
.true_idx_iter()
151
.map(|idx| T::from_start_and_offset(min, idx));
152
let indexes = unsafe { TrustMyLength::new(indexes, cardinality) };
153
let indexes = indexes.collect_trusted::<Vec<_>>();
154
155
let mut lookup = vec![0u16; diff + 1];
156
157
for (i, &idx) in indexes.iter().enumerate() {
158
lookup[(idx - min).as_()] = i as u16;
159
}
160
161
use ArrowDataType as DT;
162
let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), indexes.into(), None);
163
let values = Box::new(values);
164
165
let keys: Buffer<u32> = array
166
.as_any()
167
.downcast_ref::<PrimitiveArray<T>>()
168
.unwrap()
169
.values()
170
.iter()
171
.map(|v| {
172
// @NOTE:
173
// Since the values might contain nulls which have a undefined value. We just
174
// clamp the values to between the min and max value. This way, they will still
175
// be valid dictionary keys.
176
let idx = *v.clamp(&min, &max) - min;
177
let value = unsafe { lookup.get_unchecked(idx.as_()) };
178
(*value).into()
179
})
180
.collect();
181
182
let keys = PrimitiveArray::new(DT::UInt32, keys, array.validity().cloned());
183
DictionaryDecision::Found(
184
DictionaryArray::<u32>::try_new(
185
ArrowDataType::Dictionary(
186
IntegerType::UInt32,
187
Box::new(DT::from(T::PRIMITIVE)),
188
false, // @TODO: This might be able to be set to true?
189
),
190
keys,
191
values,
192
)
193
.unwrap(),
194
)
195
}
196
197
pub(crate) fn encode_as_dictionary_optional(
198
array: &dyn Array,
199
nested: &[Nested],
200
type_: PrimitiveType,
201
options: WriteOptions,
202
) -> Option<PolarsResult<DynIter<'static, PolarsResult<Page>>>> {
203
if array.is_empty() {
204
let array = DictionaryArray::<u32>::new_empty(ArrowDataType::Dictionary(
205
IntegerType::UInt32,
206
Box::new(array.dtype().clone()),
207
false, // @TODO: This might be able to be set to true?
208
));
209
210
return Some(array_to_pages(
211
&array,
212
type_,
213
nested,
214
options,
215
Encoding::RleDictionary,
216
));
217
}
218
219
use arrow::types::PrimitiveType as PT;
220
let fast_dictionary = match array.dtype().to_physical_type() {
221
PhysicalType::Primitive(pt) => match pt {
222
PT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array),
223
PT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array),
224
PT::Int32 => min_max_integer_encode_as_dictionary_optional::<_, i32>(array),
225
PT::Int64 => min_max_integer_encode_as_dictionary_optional::<_, i64>(array),
226
PT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array),
227
PT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array),
228
PT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array),
229
PT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array),
230
_ => DictionaryDecision::TryAgain,
231
},
232
_ => DictionaryDecision::TryAgain,
233
};
234
235
match fast_dictionary {
236
DictionaryDecision::NotWorth => return None,
237
DictionaryDecision::Found(dictionary_array) => {
238
return Some(array_to_pages(
239
&dictionary_array,
240
type_,
241
nested,
242
options,
243
Encoding::RleDictionary,
244
));
245
},
246
DictionaryDecision::TryAgain => {},
247
}
248
249
let dtype = Box::new(array.dtype().clone());
250
251
let estimated_cardinality = polars_compute::cardinality::estimate_cardinality(array);
252
253
if array.len() > 128 && (estimated_cardinality as f64) / (array.len() as f64) > 0.75 {
254
return None;
255
}
256
257
// This does the group by.
258
let array = polars_compute::cast::cast(
259
array,
260
&ArrowDataType::Dictionary(IntegerType::UInt32, dtype, false),
261
Default::default(),
262
)
263
.ok()?;
264
265
let array = array
266
.as_any()
267
.downcast_ref::<DictionaryArray<u32>>()
268
.unwrap();
269
270
Some(array_to_pages(
271
array,
272
type_,
273
nested,
274
options,
275
Encoding::RleDictionary,
276
))
277
}
278
279
fn serialize_def_levels_simple(
280
validity: Option<&Bitmap>,
281
length: usize,
282
is_optional: bool,
283
options: WriteOptions,
284
buffer: &mut Vec<u8>,
285
) -> PolarsResult<()> {
286
utils::write_def_levels(buffer, is_optional, validity, length, options.version)
287
}
288
289
fn serialize_keys_values<K: DictionaryKey>(
290
array: &DictionaryArray<K>,
291
validity: Option<&Bitmap>,
292
buffer: &mut Vec<u8>,
293
) -> PolarsResult<()> {
294
let keys = array.keys_values_iter().map(|x| x as u32);
295
if let Some(validity) = validity {
296
// discard indices whose values are null.
297
let keys = keys
298
.zip(validity.iter())
299
.filter(|&(_key, is_valid)| is_valid)
300
.map(|(key, _is_valid)| key);
301
let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64);
302
303
let keys = utils::ExactSizedIter::new(keys, array.len() - validity.unset_bits());
304
305
// num_bits as a single byte
306
buffer.push(num_bits as u8);
307
308
// followed by the encoded indices.
309
Ok(encode::<u32, _, _>(buffer, keys, num_bits)?)
310
} else {
311
let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64);
312
313
// num_bits as a single byte
314
buffer.push(num_bits as u8);
315
316
// followed by the encoded indices.
317
Ok(encode::<u32, _, _>(buffer, keys, num_bits)?)
318
}
319
}
320
321
fn serialize_levels(
322
validity: Option<&Bitmap>,
323
length: usize,
324
type_: &PrimitiveType,
325
nested: &[Nested],
326
options: WriteOptions,
327
buffer: &mut Vec<u8>,
328
) -> PolarsResult<(usize, usize)> {
329
if nested.len() == 1 {
330
let is_optional = is_nullable(&type_.field_info);
331
serialize_def_levels_simple(validity, length, is_optional, options, buffer)?;
332
let definition_levels_byte_length = buffer.len();
333
Ok((0, definition_levels_byte_length))
334
} else {
335
nested::write_rep_and_def(options.version, nested, buffer)
336
}
337
}
338
339
fn normalized_validity<K: DictionaryKey>(array: &DictionaryArray<K>) -> Option<Bitmap> {
340
match (array.keys().validity(), array.values().validity()) {
341
(None, None) => None,
342
(keys, None) => keys.cloned(),
343
// The values can have a different length than the keys
344
(_, Some(_values)) => {
345
let iter = (0..array.len()).map(|i| unsafe { !array.is_null_unchecked(i) });
346
MutableBitmap::from_trusted_len_iter(iter).into()
347
},
348
}
349
}
350
351
fn serialize_keys<K: DictionaryKey>(
352
array: &DictionaryArray<K>,
353
type_: PrimitiveType,
354
nested: &[Nested],
355
statistics: Option<ParquetStatistics>,
356
options: WriteOptions,
357
) -> PolarsResult<Page> {
358
let mut buffer = vec![];
359
360
let (start, len) = slice_nested_leaf(nested);
361
362
let mut nested = nested.to_vec();
363
let array = array.clone().sliced(start, len);
364
if let Some(Nested::Primitive(PrimitiveNested { length, .. })) = nested.last_mut() {
365
*length = len;
366
} else {
367
unreachable!("")
368
}
369
// Parquet only accepts a single validity - we "&" the validities into a single one
370
// and ignore keys whose _value_ is null.
371
// It's important that we slice before normalizing.
372
let validity = normalized_validity(&array);
373
374
let (repetition_levels_byte_length, definition_levels_byte_length) = serialize_levels(
375
validity.as_ref(),
376
array.len(),
377
&type_,
378
&nested,
379
options,
380
&mut buffer,
381
)?;
382
383
serialize_keys_values(&array, validity.as_ref(), &mut buffer)?;
384
385
let (num_values, num_rows) = if nested.len() == 1 {
386
(array.len(), array.len())
387
} else {
388
(nested::num_values(&nested), nested[0].len())
389
};
390
391
utils::build_plain_page(
392
buffer,
393
num_values,
394
num_rows,
395
array.null_count(),
396
repetition_levels_byte_length,
397
definition_levels_byte_length,
398
statistics,
399
type_,
400
options,
401
Encoding::RleDictionary,
402
)
403
.map(Page::Data)
404
}
405
406
/// Plain-encodes the primitive dictionary values of `$array` into a [`DictPage`] and,
/// when `$options.statistics` is non-empty, builds serialized statistics for them.
///
/// `$from` is the Arrow value type and `$to` the type the values are encoded as (e.g. `u8`
/// values are written as `i32`). Evaluates to `(DictPage, Option<ParquetStatistics>)`.
macro_rules! dyn_prim {
    ($from:ty, $to:ty, $array:expr, $options:expr, $type_:expr) => {{
        let dict_values = $array.values().as_any().downcast_ref().unwrap();

        let encoded = primitive_encode_plain::<$from, $to>(
            dict_values,
            EncodeNullability::new(false),
            vec![],
        );

        let stats: Option<ParquetStatistics> = if $options.statistics.is_empty() {
            None
        } else {
            let mut built = primitive_build_statistics::<$from, $to>(
                dict_values,
                $type_.clone(),
                &$options.statistics,
            );
            built.null_count = Some($array.null_count() as i64);
            Some(built.serialize())
        };

        let dict_page = DictPage::new(CowBuffer::Owned(encoded), dict_values.len(), false);
        (dict_page, stats)
    }};
}
430
431
pub fn array_to_pages<K: DictionaryKey>(
432
array: &DictionaryArray<K>,
433
type_: PrimitiveType,
434
nested: &[Nested],
435
options: WriteOptions,
436
encoding: Encoding,
437
) -> PolarsResult<DynIter<'static, PolarsResult<Page>>> {
438
match encoding {
439
Encoding::PlainDictionary | Encoding::RleDictionary => {
440
// write DictPage
441
let (dict_page, mut statistics): (_, Option<ParquetStatistics>) = match array
442
.values()
443
.dtype()
444
.to_logical_type()
445
{
446
ArrowDataType::Int8 => dyn_prim!(i8, i32, array, options, type_),
447
ArrowDataType::Int16 => dyn_prim!(i16, i32, array, options, type_),
448
ArrowDataType::Int32 | ArrowDataType::Date32 | ArrowDataType::Time32(_) => {
449
dyn_prim!(i32, i32, array, options, type_)
450
},
451
ArrowDataType::Int64
452
| ArrowDataType::Date64
453
| ArrowDataType::Time64(_)
454
| ArrowDataType::Timestamp(_, _)
455
| ArrowDataType::Duration(_) => dyn_prim!(i64, i64, array, options, type_),
456
ArrowDataType::UInt8 => dyn_prim!(u8, i32, array, options, type_),
457
ArrowDataType::UInt16 => dyn_prim!(u16, i32, array, options, type_),
458
ArrowDataType::UInt32 => dyn_prim!(u32, i32, array, options, type_),
459
ArrowDataType::UInt64 => dyn_prim!(u64, i64, array, options, type_),
460
ArrowDataType::Float32 => dyn_prim!(f32, f32, array, options, type_),
461
ArrowDataType::Float64 => dyn_prim!(f64, f64, array, options, type_),
462
ArrowDataType::LargeUtf8 => {
463
let array = polars_compute::cast::cast(
464
array.values().as_ref(),
465
&ArrowDataType::LargeBinary,
466
Default::default(),
467
)
468
.unwrap();
469
let array = array.as_any().downcast_ref().unwrap();
470
471
let mut buffer = vec![];
472
binary_encode_plain::<i64>(array, EncodeNullability::Required, &mut buffer);
473
let stats = if options.has_statistics() {
474
Some(binary_build_statistics(
475
array,
476
type_.clone(),
477
&options.statistics,
478
))
479
} else {
480
None
481
};
482
(
483
DictPage::new(CowBuffer::Owned(buffer), array.len(), false),
484
stats,
485
)
486
},
487
ArrowDataType::BinaryView => {
488
let array = array
489
.values()
490
.as_any()
491
.downcast_ref::<BinaryViewArray>()
492
.unwrap();
493
let mut buffer = vec![];
494
binview::encode_plain(array, EncodeNullability::Required, &mut buffer);
495
496
let stats = if options.has_statistics() {
497
Some(binview::build_statistics(
498
array,
499
type_.clone(),
500
&options.statistics,
501
))
502
} else {
503
None
504
};
505
(
506
DictPage::new(CowBuffer::Owned(buffer), array.len(), false),
507
stats,
508
)
509
},
510
ArrowDataType::Utf8View => {
511
let array = array
512
.values()
513
.as_any()
514
.downcast_ref::<Utf8ViewArray>()
515
.unwrap()
516
.to_binview();
517
let mut buffer = vec![];
518
binview::encode_plain(&array, EncodeNullability::Required, &mut buffer);
519
520
let stats = if options.has_statistics() {
521
Some(binview::build_statistics(
522
&array,
523
type_.clone(),
524
&options.statistics,
525
))
526
} else {
527
None
528
};
529
(
530
DictPage::new(CowBuffer::Owned(buffer), array.len(), false),
531
stats,
532
)
533
},
534
ArrowDataType::LargeBinary => {
535
let values = array.values().as_any().downcast_ref().unwrap();
536
537
let mut buffer = vec![];
538
binary_encode_plain::<i64>(values, EncodeNullability::Required, &mut buffer);
539
let stats = if options.has_statistics() {
540
Some(binary_build_statistics(
541
values,
542
type_.clone(),
543
&options.statistics,
544
))
545
} else {
546
None
547
};
548
(
549
DictPage::new(CowBuffer::Owned(buffer), values.len(), false),
550
stats,
551
)
552
},
553
ArrowDataType::FixedSizeBinary(_) => {
554
let mut buffer = vec![];
555
let array = array.values().as_any().downcast_ref().unwrap();
556
fixed_binary_encode_plain(array, EncodeNullability::Required, &mut buffer);
557
let stats = if options.has_statistics() {
558
let stats = fixed_binary_build_statistics(
559
array,
560
type_.clone(),
561
&options.statistics,
562
);
563
Some(stats.serialize())
564
} else {
565
None
566
};
567
(
568
DictPage::new(CowBuffer::Owned(buffer), array.len(), false),
569
stats,
570
)
571
},
572
other => {
573
polars_bail!(
574
nyi =
575
"Writing dictionary arrays to parquet only support data type {other:?}"
576
)
577
},
578
};
579
580
if let Some(stats) = &mut statistics {
581
stats.null_count = Some(array.null_count() as i64)
582
}
583
584
// write DataPage pointing to DictPage
585
let data_page =
586
serialize_keys(array, type_, nested, statistics, options)?.unwrap_data();
587
588
Ok(DynIter::new(
589
[Page::Dict(dict_page), Page::Data(data_page)]
590
.into_iter()
591
.map(Ok),
592
))
593
},
594
_ => polars_bail!(nyi = "Dictionary arrays only support dictionary encoding"),
595
}
596
}
597
598