Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs
8482 views
1
use arrow::bitmap::Bitmap;
use arrow::bitmap::bitmask::BitMask;
use arrow::types::AlignedBytes;

use super::{
    IndexMapping, no_more_bitpacked_values, oob_dict_idx, optional_skip_whole_chunks,
    verify_dict_indices,
};
use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder};
use crate::parquet::error::ParquetResult;

/// Decoding kernel for optional dictionary encoded.
13
#[inline(never)]
14
pub fn decode<B: AlignedBytes, D: IndexMapping<Output = B>>(
15
mut values: HybridRleDecoder<'_>,
16
dict: D,
17
mut validity: Bitmap,
18
target: &mut Vec<B>,
19
mut num_rows_to_skip: usize,
20
) -> ParquetResult<()> {
21
debug_assert!(num_rows_to_skip <= validity.len());
22
23
let num_rows = validity.len() - num_rows_to_skip;
24
let end_length = target.len() + num_rows;
25
26
target.reserve(num_rows);
27
28
// Remove any leading and trailing nulls. This has two benefits:
29
// 1. It increases the chance of dispatching to the faster kernel (e.g. for sorted data)
30
// 2. It reduces the amount of iterations in the main loop and replaces it with `memset`s
31
let leading_nulls = validity.take_leading_zeros();
32
let trailing_nulls = validity.take_trailing_zeros();
33
34
// Special case: all values are skipped, just add the trailing null.
35
if num_rows_to_skip >= leading_nulls + validity.len() {
36
target.resize(end_length, B::zeroed());
37
return Ok(());
38
}
39
40
values.limit_to(validity.set_bits());
41
42
// Add the leading nulls
43
if num_rows_to_skip < leading_nulls {
44
target.resize(target.len() + leading_nulls - num_rows_to_skip, B::zeroed());
45
num_rows_to_skip = 0;
46
} else {
47
num_rows_to_skip -= leading_nulls;
48
}
49
50
if validity.set_bits() == validity.len() {
51
// Dispatch to the required kernel if all rows are valid anyway.
52
super::required::decode(values, dict, target, num_rows_to_skip)?;
53
} else {
54
if dict.is_empty() {
55
return Err(oob_dict_idx());
56
}
57
58
let mut num_values_to_skip = 0;
59
if num_rows_to_skip > 0 {
60
num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits();
61
}
62
63
let mut validity = BitMask::from_bitmap(&validity);
64
let mut values_buffer = [0u32; 128];
65
let values_buffer = &mut values_buffer;
66
67
// Skip over any whole HybridRleChunks
68
optional_skip_whole_chunks(
69
&mut values,
70
&mut validity,
71
&mut num_rows_to_skip,
72
&mut num_values_to_skip,
73
)?;
74
75
while let Some(chunk) = values.next_chunk()? {
76
debug_assert!(num_values_to_skip < chunk.len() || chunk.len() == 0);
77
78
match chunk {
79
HybridRleChunk::Rle(value, length) => {
80
if length == 0 {
81
continue;
82
}
83
84
// If we know that we have `length` times `value` that we can append, but there
85
// might be nulls in between those values.
86
//
87
// 1. See how many `num_rows = valid + invalid` values `length` would entail.
88
// This is done with `nth_set_bit_idx` on the validity mask.
89
// 2. Fill `num_rows` values into the target buffer.
90
// 3. Advance the validity mask by `num_rows` values.
91
92
let Some(value) = dict.get(value) else {
93
return Err(oob_dict_idx());
94
};
95
96
// We have `length` values but they may span more rows due to interspersed nulls.
97
//
98
// Example: validity = [1,0,0,0,1,0,0,0,1,0,0, 1, 1, 1] and length = 3
99
// positions: 0 1 2 3 4 5 6 7 8 9 10 11 12 13
100
// indices: 0 1 2 3 4 5 (of set bits)
101
//
102
// First RLE chunk owns 3 values at positions 0, 4, 8 (sparse, many nulls).
103
// Second RLE chunk owns values at positions 11, 12, 13.
104
//
105
// Correct: nth_set_bit_idx(length-1 = 2, 0) = 8 → rows = 8+1 = 9 (rows 0-8) ✓
106
// Bug: nth_set_bit_idx(length = 3, 0) = 11 → rows = 11 (rows 0-10!)
107
//
108
// The bug claims rows 9, 10 which are nulls belonging to the second chunk.
109
let num_chunk_rows = validity
110
.nth_set_bit_idx(length - 1, 0)
111
.map_or(validity.len(), |v| v + 1);
112
113
validity.advance_by(num_chunk_rows);
114
115
target.resize(target.len() + num_chunk_rows - num_rows_to_skip, value);
116
},
117
HybridRleChunk::Bitpacked(mut decoder) => {
118
let num_rows_for_decoder = validity
119
.nth_set_bit_idx(decoder.len(), 0)
120
.unwrap_or(validity.len());
121
122
let mut chunked = decoder.chunked();
123
124
let mut buffer_part_idx = 0;
125
let mut values_offset = 0;
126
let mut num_buffered: usize = 0;
127
128
let mut decoder_validity;
129
(decoder_validity, validity) = validity.split_at(num_rows_for_decoder);
130
131
// Skip over any remaining values.
132
if num_rows_to_skip > 0 {
133
decoder_validity.advance_by(num_rows_to_skip);
134
135
chunked.decoder.skip_chunks(num_values_to_skip / 32);
136
num_values_to_skip %= 32;
137
138
if num_values_to_skip > 0 {
139
let buffer_part = <&mut [u32; 32]>::try_from(
140
&mut values_buffer[buffer_part_idx * 32..][..32],
141
)
142
.unwrap();
143
let Some(num_added) = chunked.next_into(buffer_part) else {
144
return Err(no_more_bitpacked_values());
145
};
146
147
debug_assert!(num_values_to_skip <= num_added);
148
verify_dict_indices(buffer_part, dict.len())?;
149
150
values_offset += num_values_to_skip;
151
num_buffered += num_added - num_values_to_skip;
152
buffer_part_idx += 1;
153
}
154
}
155
156
let mut iter = |v: u64, n: usize| {
157
while num_buffered < v.count_ones() as usize {
158
buffer_part_idx %= 4;
159
160
let buffer_part = <&mut [u32; 32]>::try_from(
161
&mut values_buffer[buffer_part_idx * 32..][..32],
162
)
163
.unwrap();
164
let Some(num_added) = chunked.next_into(buffer_part) else {
165
return Err(no_more_bitpacked_values());
166
};
167
168
verify_dict_indices(buffer_part, dict.len())?;
169
170
num_buffered += num_added;
171
172
buffer_part_idx += 1;
173
}
174
175
let mut num_read = 0;
176
177
target.extend((0..n).map(|i| {
178
let idx = values_buffer[(values_offset + num_read) % 128];
179
num_read += ((v >> i) & 1) as usize;
180
181
// SAFETY:
182
// 1. `values_buffer` starts out as only zeros, which we know is in the
183
// dictionary following the original `dict.is_empty` check.
184
// 2. Each time we write to `values_buffer`, it is followed by a
185
// `verify_dict_indices`.
186
unsafe { dict.get_unchecked(idx) }
187
}));
188
189
values_offset += num_read;
190
values_offset %= 128;
191
num_buffered -= num_read;
192
193
ParquetResult::Ok(())
194
};
195
196
let mut v_iter = decoder_validity.fast_iter_u56();
197
for v in v_iter.by_ref() {
198
iter(v, 56)?;
199
}
200
201
let (v, vl) = v_iter.remainder();
202
iter(v, vl)?;
203
},
204
}
205
206
num_rows_to_skip = 0;
207
num_values_to_skip = 0;
208
}
209
}
210
211
// Add back the trailing nulls
212
debug_assert_eq!(target.len(), end_length - trailing_nulls);
213
target.resize(end_length, B::zeroed());
214
215
Ok(())
216
}
217
218
#[cfg(test)]
mod tests {
    use arrow::bitmap::Bitmap;
    use arrow::types::Bytes4Alignment4;

    use super::decode;
    use crate::parquet::encoding::hybrid_rle::{Encoder, HybridRleDecoder};

    /// Two RLE runs of 3 values each over a sparse validity mask:
    ///
    /// ```text
    /// Position: 0  1  2  3  4  5  6  7  8  9 10 11 12 13
    /// Validity: 1  0  0  0  1  0  0  0  1  0  0  1  1  1
    /// Value:    66          66          66       99 99 99
    /// ```
    ///
    /// The first run owns rows 0-8, i.e. through its last valid value at
    /// position 8. The nulls at positions 9 and 10 belong to the second run
    /// and must be filled with its value (99).
    ///
    /// Regression check: computing the run's row span with
    /// `nth_set_bit_idx(length, 0)` instead of
    /// `nth_set_bit_idx(length - 1, 0) + 1` would make the first run claim
    /// rows 0-10 and fill positions 9 and 10 with 66.
    #[test]
    fn test_rle_decode_with_sparse_nulls() {
        // Bitmap bits (LSB first within each byte):
        // Byte 0: positions 0-7  = 1,0,0,0,1,0,0,0 = 0b00010001
        // Byte 1: positions 8-13 = 1,0,0,1,1,1     = 0b00111001
        let validity_bytes: Vec<u8> = vec![0b00010001, 0b00111001];
        let validity = Bitmap::try_new(validity_bytes, 14).unwrap();

        let mut encoded = Vec::new();
        u32::run_length_encode(&mut encoded, 3, 0, 1).unwrap(); // 3x dict index 0
        u32::run_length_encode(&mut encoded, 3, 1, 1).unwrap(); // 3x dict index 1

        let dict: &[Bytes4Alignment4] = bytemuck::cast_slice(&[66u32, 99u32]);
        let decoder = HybridRleDecoder::new(&encoded, 1, 6); // 6 total values

        let mut target = Vec::new();
        decode(decoder, dict, validity, &mut target, 0).unwrap();

        assert_eq!(target.len(), 14, "should have 14 rows");

        // Valid positions of the first run (dict[0] = 66).
        for pos in [0, 4, 8] {
            assert_eq!(target[pos], dict[0], "position {pos} should be dict[0]");
        }

        // Valid positions of the second run (dict[1] = 99).
        for pos in [11, 12, 13] {
            assert_eq!(target[pos], dict[1], "position {pos} should be dict[1]");
        }

        // Null positions 9 and 10 belong to the second run, so they should be
        // filled with dict[1].
        assert_eq!(target[9], dict[1], "position 9 (null) should be dict[1]");
        assert_eq!(target[10], dict[1], "position 10 (null) should be dict[1]");
    }
}