Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-parquet/src/arrow/write/dictionary.rs
8475 views
1
use arrow::array::{
2
Array, BinaryViewArray, DictionaryArray, DictionaryKey, PrimitiveArray, Utf8ViewArray,
3
};
4
use arrow::bitmap::{Bitmap, MutableBitmap};
5
use arrow::compute::aggregate::estimated_bytes_size;
6
use arrow::datatypes::{ArrowDataType, IntegerType, PhysicalType};
7
use arrow::legacy::utils::CustomIterTools;
8
use arrow::trusted_len::TrustMyLength;
9
use arrow::types::NativeType;
10
use polars_buffer::Buffer;
11
use polars_compute::min_max::MinMaxKernel;
12
use polars_error::{PolarsResult, polars_bail};
13
use polars_utils::float16::pf16;
14
15
use super::binary::{
16
build_statistics as binary_build_statistics, encode_plain as binary_encode_plain,
17
};
18
use super::fixed_size_binary::{
19
build_statistics as fixed_binary_build_statistics, encode_plain as fixed_binary_encode_plain,
20
};
21
use super::primitive::{
22
build_statistics as primitive_build_statistics, encode_plain as primitive_encode_plain,
23
};
24
use super::{
25
EncodeNullability, Nested, WriteOptions, binview, nested, row_slice_ranges, slice_parquet_array,
26
};
27
use crate::arrow::read::schema::is_nullable;
28
use crate::arrow::write::utils;
29
use crate::parquet::CowBuffer;
30
use crate::parquet::encoding::Encoding;
31
use crate::parquet::encoding::hybrid_rle::encode;
32
use crate::parquet::page::{DictPage, Page};
33
use crate::parquet::schema::types::PrimitiveType;
34
use crate::parquet::statistics::ParquetStatistics;
35
use crate::write::DynIter;
36
37
/// Per-integer-type tuning constants for the fast min/max-based dictionary
/// encoding path (see `min_max_integer_encode_as_dictionary_optional`).
trait MinMaxThreshold {
    /// Cardinality at or below which dictionary encoding is always considered
    /// worth it for this type.
    const DELTA_THRESHOLD: usize;
    /// Largest `max - min` range for which we are willing to allocate a
    /// seen-value bitmask of `range + 1` bits.
    const BITMASK_THRESHOLD: usize;

    /// Reconstruct the value at `start + offset`, where `offset` is a bit
    /// position in the seen-value bitmask.
    fn from_start_and_offset(start: Self, offset: usize) -> Self;
}
43
44
/// Implements `MinMaxThreshold` for a signed/unsigned integer pair with the
/// given delta and bitmask thresholds.
///
/// The two impls differ only in `from_start_and_offset`: the signed impl
/// routes the offset through the unsigned type before reinterpreting it as
/// signed, so the addition is well-defined bit-wise. Note that `diff` is
/// obtained via `checked_sub` on the signed type, so `offset` always fits in
/// the positive range and `start + offset` cannot overflow.
macro_rules! minmaxthreshold_impls {
    ($($signed:ty, $unsigned:ty => $threshold:literal, $bm_threshold:expr,)+) => {
        $(
        impl MinMaxThreshold for $signed {
            const DELTA_THRESHOLD: usize = $threshold;
            const BITMASK_THRESHOLD: usize = $bm_threshold;

            fn from_start_and_offset(start: Self, offset: usize) -> Self {
                // Cast through the unsigned twin so the usize -> integer
                // narrowing is a plain bit truncation, then reinterpret.
                start + ((offset as $unsigned) as $signed)
            }
        }
        impl MinMaxThreshold for $unsigned {
            const DELTA_THRESHOLD: usize = $threshold;
            const BITMASK_THRESHOLD: usize = $bm_threshold;

            fn from_start_and_offset(start: Self, offset: usize) -> Self {
                start + (offset as $unsigned)
            }
        }
        )+
    };
}
66
67
// Threshold table: `signed, unsigned => DELTA_THRESHOLD, BITMASK_THRESHOLD`.
// Wider types allow a larger always-worth-it cardinality, but the bitmask
// range is capped at u16::MAX for 32/64-bit types to bound the allocation.
minmaxthreshold_impls! {
    i8, u8 => 16, u8::MAX as usize,
    i16, u16 => 256, u16::MAX as usize,
    i32, u32 => 512, u16::MAX as usize,
    i64, u64 => 2048, u16::MAX as usize,
}
73
74
/// Outcome of the fast integer min/max dictionary probe.
enum DictionaryDecision {
    /// Dictionary encoding is estimated not to pay off; skip it entirely.
    NotWorth,
    /// The fast path could not decide; fall back to the generic cast-based path.
    TryAgain,
    /// The fast path produced a ready-to-write dictionary array.
    Found(DictionaryArray<u32>),
}
79
80
/// Fast path: try to build a `DictionaryArray<u32>` for an integer array by
/// using its min/max range instead of a hash-based group-by.
///
/// Strategy: if `max - min` is small enough (`BITMASK_THRESHOLD`), mark every
/// observed value in a `range + 1`-bit mask, derive the exact cardinality from
/// the set bits, and only materialize the dictionary when the cardinality is
/// low enough to make dictionary encoding worthwhile.
///
/// Returns `TryAgain` when the range is too wide, the array is all-null, or
/// the subtraction would overflow — the caller then falls back to the generic
/// path. Returns `NotWorth` when the cardinality is too high to bother.
fn min_max_integer_encode_as_dictionary_optional<'a, E, T>(
    array: &'a dyn Array,
) -> DictionaryDecision
where
    E: std::fmt::Debug,
    T: NativeType
        + MinMaxThreshold
        + std::cmp::Ord
        + TryInto<u32, Error = E>
        + std::ops::Sub<T, Output = T>
        + num_traits::CheckedSub
        + num_traits::cast::AsPrimitive<usize>,
    std::ops::RangeInclusive<T>: Iterator<Item = T>,
    PrimitiveArray<T>: MinMaxKernel<Scalar<'a> = T>,
{
    let min_max = <PrimitiveArray<T> as MinMaxKernel>::min_max_ignore_nan_kernel(
        array.as_any().downcast_ref().unwrap(),
    );

    // No min/max (e.g. fully null array): let the generic path handle it.
    let Some((min, max)) = min_max else {
        return DictionaryDecision::TryAgain;
    };

    debug_assert!(max >= min, "{max} >= {min}");
    // `max - min` can overflow for signed types (e.g. i8: 127 - (-128));
    // in that case the range is certainly too wide for the bitmask anyway.
    let Some(diff) = max.checked_sub(&min) else {
        return DictionaryDecision::TryAgain;
    };

    let diff = diff.as_();

    // Range too wide: the seen-mask would be too large to be worth building.
    if diff > T::BITMASK_THRESHOLD {
        return DictionaryDecision::TryAgain;
    }

    // One bit per possible value in [min, max].
    let mut seen_mask = MutableBitmap::from_len_zeroed(diff + 1);

    let array = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();

    // Mark every observed value; the two branches only differ in how nulls
    // are skipped.
    if array.has_nulls() {
        for v in array.non_null_values_iter() {
            let offset = (v - min).as_();
            debug_assert!(offset <= diff);

            // SAFETY: offset <= diff and the mask has diff + 1 bits.
            unsafe {
                seen_mask.set_unchecked(offset, true);
            }
        }
    } else {
        for v in array.values_iter() {
            let offset = (*v - min).as_();
            debug_assert!(offset <= diff);

            // SAFETY: offset <= diff and the mask has diff + 1 bits.
            unsafe {
                seen_mask.set_unchecked(offset, true);
            }
        }
    }

    // Exact cardinality of the non-null values.
    let cardinality = seen_mask.set_bits();

    let mut is_worth_it = false;

    // Worth it when the cardinality is small in absolute terms, or small
    // relative to the array length (< 75% unique).
    is_worth_it |= cardinality <= T::DELTA_THRESHOLD;
    is_worth_it |= (cardinality as f64) / (array.len() as f64) < 0.75;

    if !is_worth_it {
        return DictionaryDecision::NotWorth;
    }

    let seen_mask = seen_mask.freeze();

    // Dictionary values: every distinct observed value, in ascending order.
    // SAFETY: We just did the calculation for this.
    let indexes = seen_mask
        .true_idx_iter()
        .map(|idx| T::from_start_and_offset(min, idx));
    let indexes = unsafe { TrustMyLength::new(indexes, cardinality) };
    let indexes = indexes.collect_trusted::<Vec<_>>();

    // Reverse lookup: value offset (from min) -> dictionary key.
    // u16 suffices because diff <= BITMASK_THRESHOLD <= u16::MAX.
    let mut lookup = vec![0u16; diff + 1];

    for (i, &idx) in indexes.iter().enumerate() {
        lookup[(idx - min).as_()] = i as u16;
    }

    use ArrowDataType as DT;
    let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), indexes.into(), None);
    let values = Box::new(values);

    // Map every physical value (including null slots) to its dictionary key.
    let keys: Buffer<u32> = array
        .as_any()
        .downcast_ref::<PrimitiveArray<T>>()
        .unwrap()
        .values()
        .iter()
        .map(|v| {
            // @NOTE:
            // Since the values might contain nulls which have a undefined value. We just
            // clamp the values to between the min and max value. This way, they will still
            // be valid dictionary keys.
            let idx = *v.clamp(&min, &max) - min;
            // SAFETY: idx <= diff after clamping, and lookup has diff + 1 slots.
            let value = unsafe { lookup.get_unchecked(idx.as_()) };
            (*value).into()
        })
        .collect();

    // Keep the original validity; null slots carry an arbitrary (but valid) key.
    let keys = PrimitiveArray::new(DT::UInt32, keys, array.validity().cloned());
    DictionaryDecision::Found(
        DictionaryArray::<u32>::try_new(
            ArrowDataType::Dictionary(
                IntegerType::UInt32,
                Box::new(DT::from(T::PRIMITIVE)),
                false, // @TODO: This might be able to be set to true?
            ),
            keys,
            values,
        )
        .unwrap(),
    )
}
199
200
/// Try to dictionary-encode `array` into parquet pages.
///
/// Returns `None` when dictionary encoding is judged not worthwhile (high
/// cardinality, or the cast to a dictionary failed), in which case the caller
/// should use a plain encoding instead. Otherwise returns the pages
/// (`DictPage` followed by the data pages) wrapped in `Some`.
///
/// Order of attempts:
/// 1. empty array -> trivially empty dictionary;
/// 2. integer fast path based on min/max (`min_max_integer_encode_as_dictionary_optional`);
/// 3. generic path: estimate cardinality, then cast to a dictionary type.
pub(crate) fn encode_as_dictionary_optional(
    array: &dyn Array,
    nested: &[Nested],
    type_: PrimitiveType,
    options: WriteOptions,
) -> Option<PolarsResult<DynIter<'static, PolarsResult<Page>>>> {
    if array.is_empty() {
        let array = DictionaryArray::<u32>::new_empty(ArrowDataType::Dictionary(
            IntegerType::UInt32,
            Box::new(array.dtype().clone()),
            false, // @TODO: This might be able to be set to true?
        ));

        return Some(array_to_pages(
            &array,
            type_,
            nested,
            options,
            Encoding::RleDictionary,
        ));
    }

    // Fast path: exact, allocation-light probe for primitive integer arrays.
    use arrow::types::PrimitiveType as PT;
    let fast_dictionary = match array.dtype().to_physical_type() {
        PhysicalType::Primitive(pt) => match pt {
            PT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array),
            PT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array),
            PT::Int32 => min_max_integer_encode_as_dictionary_optional::<_, i32>(array),
            PT::Int64 => min_max_integer_encode_as_dictionary_optional::<_, i64>(array),
            PT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array),
            PT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array),
            PT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array),
            PT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array),
            _ => DictionaryDecision::TryAgain,
        },
        _ => DictionaryDecision::TryAgain,
    };

    match fast_dictionary {
        // Definitive "no": do not fall through to the generic path.
        DictionaryDecision::NotWorth => return None,
        DictionaryDecision::Found(dictionary_array) => {
            return Some(array_to_pages(
                &dictionary_array,
                type_,
                nested,
                options,
                Encoding::RleDictionary,
            ));
        },
        // Inconclusive: continue with the generic path below.
        DictionaryDecision::TryAgain => {},
    }

    let dtype = Box::new(array.dtype().clone());

    let estimated_cardinality = polars_compute::cardinality::estimate_cardinality(array);

    // Mostly-unique data (estimated > 75% distinct): dictionary not worth it.
    // Small arrays (<= 128) always go through, since the estimate is noisy.
    if array.len() > 128 && (estimated_cardinality as f64) / (array.len() as f64) > 0.75 {
        return None;
    }

    // This does the group by.
    let array = polars_compute::cast::cast(
        array,
        &ArrowDataType::Dictionary(IntegerType::UInt32, dtype, false),
        Default::default(),
    )
    .ok()?;

    let array = array
        .as_any()
        .downcast_ref::<DictionaryArray<u32>>()
        .unwrap();

    Some(array_to_pages(
        array,
        type_,
        nested,
        options,
        Encoding::RleDictionary,
    ))
}
281
282
/// Write the definition levels of a flat (non-nested) column into `buffer`.
///
/// Thin wrapper around `utils::write_def_levels`; the levels are derived from
/// `validity` when `is_optional` is set.
fn serialize_def_levels_simple(
    validity: Option<&Bitmap>,
    length: usize,
    is_optional: bool,
    options: WriteOptions,
    buffer: &mut Vec<u8>,
) -> PolarsResult<()> {
    utils::write_def_levels(buffer, is_optional, validity, length, options.version)
}
291
292
/// Hybrid-RLE encode the dictionary keys of `array` into `buffer`.
///
/// Layout written: one byte with the bit width, followed by the encoded keys.
/// Keys whose slot is null (per `validity`) are not written at all — parquet
/// reconstructs their positions from the definition levels.
///
/// The two branches duplicate the tail because the filtered iterator is a
/// different type from the unfiltered one.
fn serialize_keys_values<K: DictionaryKey>(
    array: &DictionaryArray<K>,
    validity: Option<&Bitmap>,
    buffer: &mut Vec<u8>,
) -> PolarsResult<()> {
    let keys = array.keys_values_iter().map(|x| x as u32);
    if let Some(validity) = validity {
        // discard indices whose values are null.
        let keys = keys
            .zip(validity.iter())
            .filter(|&(_key, is_valid)| is_valid)
            .map(|(key, _is_valid)| key);
        // Bit width is derived from the largest surviving key (0 if none).
        let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64);

        let keys = utils::ExactSizedIter::new(keys, array.len() - validity.unset_bits());

        // num_bits as a single byte
        buffer.push(num_bits as u8);

        // followed by the encoded indices.
        Ok(encode::<u32, _, _>(buffer, keys, num_bits)?)
    } else {
        let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64);

        // num_bits as a single byte
        buffer.push(num_bits as u8);

        // followed by the encoded indices.
        Ok(encode::<u32, _, _>(buffer, keys, num_bits)?)
    }
}
323
324
fn serialize_levels(
325
validity: Option<&Bitmap>,
326
length: usize,
327
type_: &PrimitiveType,
328
nested: &[Nested],
329
options: WriteOptions,
330
buffer: &mut Vec<u8>,
331
) -> PolarsResult<(usize, usize)> {
332
if nested.len() == 1 {
333
let is_optional = is_nullable(&type_.field_info);
334
serialize_def_levels_simple(validity, length, is_optional, options, buffer)?;
335
let definition_levels_byte_length = buffer.len();
336
Ok((0, definition_levels_byte_length))
337
} else {
338
nested::write_rep_and_def(options.version, nested, buffer)
339
}
340
}
341
342
fn normalized_validity<K: DictionaryKey>(array: &DictionaryArray<K>) -> Option<Bitmap> {
343
match (array.keys().validity(), array.values().validity()) {
344
(None, None) => None,
345
(keys, None) => keys.cloned(),
346
// The values can have a different length than the keys
347
(_, Some(_values)) => {
348
let iter = (0..array.len()).map(|i| unsafe { !array.is_null_unchecked(i) });
349
MutableBitmap::from_trusted_len_iter(iter).into()
350
},
351
}
352
}
353
354
/// Split the dictionary keys of `array` into data pages.
///
/// `row_slice_ranges` decides the page boundaries based on row count, the
/// estimated byte size of the keys, and `options`; each `(offset, length)`
/// range is sliced out and serialized independently by `serialize_keys_range`.
/// The iterator is `'static`, which is why `array` and `nested` are cloned
/// into the closure.
fn serialize_keys<K: DictionaryKey>(
    array: &DictionaryArray<K>,
    type_: PrimitiveType,
    nested: &[Nested],
    statistics: Option<ParquetStatistics>,
    options: WriteOptions,
) -> DynIter<'static, PolarsResult<Page>> {
    let number_of_rows = nested[0].len();
    let byte_size = estimated_bytes_size(array.keys());

    // Owned copies so the returned iterator does not borrow the arguments.
    let array = array.clone();
    let nested = nested.to_vec();

    let pages =
        row_slice_ranges(number_of_rows, byte_size, options).map(move |(offset, length)| {
            let mut sliced_array = array.clone();
            let mut sliced_nested = nested.clone();
            slice_parquet_array(&mut sliced_array, &mut sliced_nested, offset, length);

            serialize_keys_range(
                &sliced_array,
                &type_,
                &sliced_nested,
                // Note: the same (full-array) statistics are attached to
                // every page slice.
                statistics.clone(),
                options,
            )
        });

    DynIter::new(pages)
}
384
385
/// Serialize one page worth of dictionary keys into a `Page::Data`.
///
/// Buffer layout (order matters): repetition/definition levels first, then
/// the hybrid-RLE encoded keys.
fn serialize_keys_range<K: DictionaryKey>(
    array: &DictionaryArray<K>,
    type_: &PrimitiveType,
    nested: &[Nested],
    statistics: Option<ParquetStatistics>,
    options: WriteOptions,
) -> PolarsResult<Page> {
    let mut buffer = vec![];

    // Parquet only accepts a single validity - we "&" the validities into a single one
    // and ignore keys whose _value_ is null.
    let validity = normalized_validity(array);

    let (repetition_levels_byte_length, definition_levels_byte_length) = serialize_levels(
        validity.as_ref(),
        array.len(),
        type_,
        nested,
        options,
        &mut buffer,
    )?;

    // Keys are appended after the levels in the same buffer.
    serialize_keys_values(array, validity.as_ref(), &mut buffer)?;

    // Flat columns: one value per row. Nested columns: value count comes from
    // the leaf, row count from the outermost nesting level.
    let (num_values, num_rows) = if nested.len() == 1 {
        (array.len(), array.len())
    } else {
        (nested::num_values(nested), nested[0].len())
    };

    utils::build_plain_page(
        buffer,
        num_values,
        num_rows,
        array.null_count(),
        repetition_levels_byte_length,
        definition_levels_byte_length,
        statistics,
        type_.clone(),
        options,
        Encoding::RleDictionary,
    )
    .map(Page::Data)
}
429
430
/// Build the `(DictPage, statistics)` pair for a primitive dictionary-values
/// array, plain-encoding the values as parquet type `$to` (from arrow type
/// `$from`).
///
/// The null count reported in the statistics is taken from the *dictionary
/// array* (`$array`), not from the values, since that is what the reader
/// observes per slot.
macro_rules! dyn_prim {
    ($from:ty, $to:ty, $array:expr, $options:expr, $type_:expr) => {{
        let values = $array.values().as_any().downcast_ref().unwrap();

        // Dictionary values are encoded as required (non-nullable) plain data.
        let buffer =
            primitive_encode_plain::<$from, $to>(values, EncodeNullability::new(false), vec![]);

        let stats: Option<ParquetStatistics> = if !$options.statistics.is_empty() {
            let mut stats = primitive_build_statistics::<$from, $to>(
                values,
                $type_.clone(),
                &$options.statistics,
            );
            stats.null_count = Some($array.null_count() as i64);
            Some(stats.serialize())
        } else {
            None
        };
        (
            DictPage::new(CowBuffer::Owned(buffer), values.len(), false),
            stats,
        )
    }};
}
454
455
/// Write a `DictionaryArray` as parquet pages: one `DictPage` holding the
/// plain-encoded dictionary values, followed by data pages holding the
/// hybrid-RLE encoded keys (see `serialize_keys`).
///
/// Only `PlainDictionary`/`RleDictionary` encodings are supported; any other
/// encoding, or an unsupported values dtype, is an `nyi` error.
pub fn array_to_pages<K: DictionaryKey>(
    array: &DictionaryArray<K>,
    type_: PrimitiveType,
    nested: &[Nested],
    options: WriteOptions,
    encoding: Encoding,
) -> PolarsResult<DynIter<'static, PolarsResult<Page>>> {
    match encoding {
        Encoding::PlainDictionary | Encoding::RleDictionary => {
            // write DictPage
            // Dispatch on the *storage* dtype of the dictionary values; each
            // arm produces the plain-encoded dict page plus optional stats.
            let (dict_page, mut statistics): (_, Option<ParquetStatistics>) = match array
                .values()
                .dtype()
                .to_storage()
            {
                ArrowDataType::Int8 => dyn_prim!(i8, i32, array, options, type_),
                ArrowDataType::Int16 => dyn_prim!(i16, i32, array, options, type_),
                ArrowDataType::Int32 | ArrowDataType::Date32 | ArrowDataType::Time32(_) => {
                    dyn_prim!(i32, i32, array, options, type_)
                },
                ArrowDataType::Int64
                | ArrowDataType::Date64
                | ArrowDataType::Time64(_)
                | ArrowDataType::Timestamp(_, _)
                | ArrowDataType::Duration(_) => dyn_prim!(i64, i64, array, options, type_),
                ArrowDataType::UInt8 => dyn_prim!(u8, i32, array, options, type_),
                ArrowDataType::UInt16 => dyn_prim!(u16, i32, array, options, type_),
                ArrowDataType::UInt32 => dyn_prim!(u32, i32, array, options, type_),
                ArrowDataType::UInt64 => dyn_prim!(u64, i64, array, options, type_),
                ArrowDataType::Float16 => dyn_prim!(pf16, f32, array, options, type_),
                ArrowDataType::Float32 => dyn_prim!(f32, f32, array, options, type_),
                ArrowDataType::Float64 => dyn_prim!(f64, f64, array, options, type_),
                ArrowDataType::LargeUtf8 => {
                    // Utf8 is stored as binary in parquet; cast first.
                    let array = polars_compute::cast::cast(
                        array.values().as_ref(),
                        &ArrowDataType::LargeBinary,
                        Default::default(),
                    )
                    .unwrap();
                    let array = array.as_any().downcast_ref().unwrap();

                    let mut buffer = vec![];
                    binary_encode_plain::<i64>(array, EncodeNullability::Required, &mut buffer);
                    let stats = if options.has_statistics() {
                        Some(binary_build_statistics(
                            array,
                            type_.clone(),
                            &options.statistics,
                        ))
                    } else {
                        None
                    };
                    (
                        DictPage::new(CowBuffer::Owned(buffer), array.len(), false),
                        stats,
                    )
                },
                ArrowDataType::BinaryView => {
                    let array = array
                        .values()
                        .as_any()
                        .downcast_ref::<BinaryViewArray>()
                        .unwrap();
                    let mut buffer = vec![];
                    binview::encode_plain(array, EncodeNullability::Required, &mut buffer);

                    let stats = if options.has_statistics() {
                        Some(binview::build_statistics(
                            array,
                            type_.clone(),
                            &options.statistics,
                        ))
                    } else {
                        None
                    };
                    (
                        DictPage::new(CowBuffer::Owned(buffer), array.len(), false),
                        stats,
                    )
                },
                ArrowDataType::Utf8View => {
                    // Reuse the binview encoder by viewing the utf8 data as binary.
                    let array = array
                        .values()
                        .as_any()
                        .downcast_ref::<Utf8ViewArray>()
                        .unwrap()
                        .to_binview();
                    let mut buffer = vec![];
                    binview::encode_plain(&array, EncodeNullability::Required, &mut buffer);

                    let stats = if options.has_statistics() {
                        Some(binview::build_statistics(
                            &array,
                            type_.clone(),
                            &options.statistics,
                        ))
                    } else {
                        None
                    };
                    (
                        DictPage::new(CowBuffer::Owned(buffer), array.len(), false),
                        stats,
                    )
                },
                ArrowDataType::LargeBinary => {
                    let values = array.values().as_any().downcast_ref().unwrap();

                    let mut buffer = vec![];
                    binary_encode_plain::<i64>(values, EncodeNullability::Required, &mut buffer);
                    let stats = if options.has_statistics() {
                        Some(binary_build_statistics(
                            values,
                            type_.clone(),
                            &options.statistics,
                        ))
                    } else {
                        None
                    };
                    (
                        DictPage::new(CowBuffer::Owned(buffer), values.len(), false),
                        stats,
                    )
                },
                ArrowDataType::FixedSizeBinary(_) => {
                    let mut buffer = vec![];
                    let array = array.values().as_any().downcast_ref().unwrap();
                    fixed_binary_encode_plain(array, EncodeNullability::Required, &mut buffer);
                    let stats = if options.has_statistics() {
                        let stats = fixed_binary_build_statistics(
                            array,
                            type_.clone(),
                            &options.statistics,
                        );
                        Some(stats.serialize())
                    } else {
                        None
                    };
                    (
                        DictPage::new(CowBuffer::Owned(buffer), array.len(), false),
                        stats,
                    )
                },
                other => {
                    polars_bail!(
                        nyi =
                            "Writing dictionary arrays to parquet only support data type {other:?}"
                    )
                },
            };

            // Overwrite the null count with the dictionary array's own count:
            // nullness lives in the keys/validity, not in the values.
            if let Some(stats) = &mut statistics {
                stats.null_count = Some(array.null_count() as i64)
            }

            // write DataPages pointing to DictPage
            let data_pages = serialize_keys(array, type_, nested, statistics, options);

            Ok(DynIter::new(
                std::iter::once(Ok(Page::Dict(dict_page))).chain(data_pages),
            ))
        },
        _ => polars_bail!(nyi = "Dictionary arrays only support dictionary encoding"),
    }
}
619
620