Path: blob/main/crates/polars-parquet/src/arrow/read/deserialize/dictionary_encoded/optional.rs
8482 views
use arrow::bitmap::Bitmap;
use arrow::bitmap::bitmask::BitMask;
use arrow::types::AlignedBytes;

use super::{
    IndexMapping, no_more_bitpacked_values, oob_dict_idx, optional_skip_whole_chunks,
    verify_dict_indices,
};
use crate::parquet::encoding::hybrid_rle::{HybridRleChunk, HybridRleDecoder};
use crate::parquet::error::ParquetResult;

/// Decoding kernel for optional (nullable) dictionary-encoded values.
///
/// Translates the dictionary indices in `values` through `dict` and appends one element
/// per `validity` row to `target`. Null rows still occupy a slot in `target`: they are
/// filled with `B::zeroed()` (leading/trailing nulls) or with a placeholder value
/// (interior nulls); those slots are marked null by `validity` and are presumably masked
/// downstream.
///
/// `num_rows_to_skip` leading rows (counting both valid and null rows) are consumed from
/// the input but not materialized into `target`.
#[inline(never)]
pub fn decode<B: AlignedBytes, D: IndexMapping<Output = B>>(
    mut values: HybridRleDecoder<'_>,
    dict: D,
    mut validity: Bitmap,
    target: &mut Vec<B>,
    mut num_rows_to_skip: usize,
) -> ParquetResult<()> {
    debug_assert!(num_rows_to_skip <= validity.len());

    // Number of rows that will actually be appended, and the target length afterwards.
    let num_rows = validity.len() - num_rows_to_skip;
    let end_length = target.len() + num_rows;

    target.reserve(num_rows);

    // Remove any leading and trailing nulls. This has two benefits:
    // 1. It increases the chance of dispatching to the faster all-valid kernel
    //    (e.g. for sorted data).
    // 2. It reduces the amount of iterations in the main loop and replaces it with
    //    `memset`s.
    let leading_nulls = validity.take_leading_zeros();
    let trailing_nulls = validity.take_trailing_zeros();

    // Special case: all non-null values are skipped; only trailing nulls are emitted.
    if num_rows_to_skip >= leading_nulls + validity.len() {
        target.resize(end_length, B::zeroed());
        return Ok(());
    }

    // Exactly one index is stored per *valid* row.
    values.limit_to(validity.set_bits());

    // Add the leading nulls (minus those that fall inside the skipped prefix).
    if num_rows_to_skip < leading_nulls {
        target.resize(target.len() + leading_nulls - num_rows_to_skip, B::zeroed());
        num_rows_to_skip = 0;
    } else {
        num_rows_to_skip -= leading_nulls;
    }

    if validity.set_bits() == validity.len() {
        // Dispatch to the required kernel if all rows are valid anyway.
        super::required::decode(values, dict, target, num_rows_to_skip)?;
    } else {
        if dict.is_empty() {
            return Err(oob_dict_idx());
        }

        // How many *indices* (i.e. valid rows) fall inside the skipped row prefix.
        let mut num_values_to_skip = 0;
        if num_rows_to_skip > 0 {
            num_values_to_skip = validity.clone().sliced(0, num_rows_to_skip).set_bits();
        }

        let mut validity = BitMask::from_bitmap(&validity);
        // Ring buffer of decoded dictionary indices: 4 parts of 32 values each.
        let mut values_buffer = [0u32; 128];
        let values_buffer = &mut values_buffer;

        // Skip over any whole HybridRleChunks that lie entirely in the skipped prefix.
        optional_skip_whole_chunks(
            &mut values,
            &mut validity,
            &mut num_rows_to_skip,
            &mut num_values_to_skip,
        )?;

        while let Some(chunk) = values.next_chunk()? {
            // After whole-chunk skipping, any remaining skip is strictly inside this chunk.
            debug_assert!(num_values_to_skip < chunk.len() || chunk.len() == 0);

            match chunk {
                HybridRleChunk::Rle(value, length) => {
                    if length == 0 {
                        continue;
                    }

                    // We have `length` repetitions of `value`, but the rows they occupy
                    // may be interspersed with nulls:
                    //
                    // 1. Find how many rows (`num_rows = valid + invalid`) the `length`
                    //    values span, via `nth_set_bit_idx` on the validity mask.
                    // 2. Fill that many rows of the target buffer with `value` (a memset).
                    // 3. Advance the validity mask by the same number of rows.

                    let Some(value) = dict.get(value) else {
                        return Err(oob_dict_idx());
                    };

                    // This chunk owns the rows up to and including its `length`-th set bit
                    // (0-indexed: `length - 1`); nulls *after* that set bit belong to the
                    // next chunk.
                    //
                    // Example: validity = 1,0,0,0,1,0,0,0,1,0,0,1,1,1 and length = 3:
                    // nth_set_bit_idx(2, 0) = 8, so this chunk covers rows 0..=8 and the
                    // nulls at rows 9 and 10 are filled by the following chunk. Using
                    // `nth_set_bit_idx(length, 0)` instead would return 11 and wrongly
                    // claim rows 9 and 10 for this chunk.
                    let num_chunk_rows = validity
                        .nth_set_bit_idx(length - 1, 0)
                        .map_or(validity.len(), |v| v + 1);

                    validity.advance_by(num_chunk_rows);

                    target.resize(target.len() + num_chunk_rows - num_rows_to_skip, value);
                },
                HybridRleChunk::Bitpacked(mut decoder) => {
                    // Rows covered by this decoder: everything before the first set bit
                    // that belongs to the *next* chunk (or all remaining rows).
                    let num_rows_for_decoder = validity
                        .nth_set_bit_idx(decoder.len(), 0)
                        .unwrap_or(validity.len());

                    let mut chunked = decoder.chunked();

                    let mut buffer_part_idx = 0; // next 32-slot part of the ring buffer to fill
                    let mut values_offset = 0; // read cursor into the ring buffer
                    let mut num_buffered: usize = 0; // decoded-but-unconsumed indices

                    let mut decoder_validity;
                    (decoder_validity, validity) = validity.split_at(num_rows_for_decoder);

                    // Skip over any remaining rows/values inside this chunk.
                    if num_rows_to_skip > 0 {
                        decoder_validity.advance_by(num_rows_to_skip);

                        // Whole 32-value sub-chunks can be skipped without decoding.
                        chunked.decoder.skip_chunks(num_values_to_skip / 32);
                        num_values_to_skip %= 32;

                        if num_values_to_skip > 0 {
                            // Decode one sub-chunk and drop its first
                            // `num_values_to_skip` entries.
                            let buffer_part = <&mut [u32; 32]>::try_from(
                                &mut values_buffer[buffer_part_idx * 32..][..32],
                            )
                            .unwrap();
                            let Some(num_added) = chunked.next_into(buffer_part) else {
                                return Err(no_more_bitpacked_values());
                            };

                            debug_assert!(num_values_to_skip <= num_added);
                            verify_dict_indices(buffer_part, dict.len())?;

                            values_offset += num_values_to_skip;
                            num_buffered += num_added - num_values_to_skip;
                            buffer_part_idx += 1;
                        }
                    }

                    // Emit the rows for one validity word `v` of `n` bits: valid rows
                    // consume the next index from the ring buffer, null rows re-read the
                    // current slot (a placeholder that validity marks as null).
                    let mut iter = |v: u64, n: usize| {
                        // Refill the ring buffer until it holds enough indices for every
                        // set bit in `v`.
                        while num_buffered < v.count_ones() as usize {
                            buffer_part_idx %= 4;

                            let buffer_part = <&mut [u32; 32]>::try_from(
                                &mut values_buffer[buffer_part_idx * 32..][..32],
                            )
                            .unwrap();
                            let Some(num_added) = chunked.next_into(buffer_part) else {
                                return Err(no_more_bitpacked_values());
                            };

                            verify_dict_indices(buffer_part, dict.len())?;

                            num_buffered += num_added;

                            buffer_part_idx += 1;
                        }

                        let mut num_read = 0;

                        target.extend((0..n).map(|i| {
                            let idx = values_buffer[(values_offset + num_read) % 128];
                            // Only advance the read cursor on valid rows.
                            num_read += ((v >> i) & 1) as usize;

                            // SAFETY:
                            // 1. `values_buffer` starts out as only zeros, which we know
                            //    is in the dictionary following the original
                            //    `dict.is_empty` check.
                            // 2. Each time we write to `values_buffer`, it is followed by
                            //    a `verify_dict_indices`.
                            unsafe { dict.get_unchecked(idx) }
                        }));

                        values_offset += num_read;
                        values_offset %= 128;
                        num_buffered -= num_read;

                        ParquetResult::Ok(())
                    };

                    // Walk the decoder's validity in 56-bit words, then the remainder.
                    let mut v_iter = decoder_validity.fast_iter_u56();
                    for v in v_iter.by_ref() {
                        iter(v, 56)?;
                    }

                    let (v, vl) = v_iter.remainder();
                    iter(v, vl)?;
                },
            }

            // Skipping is fully consumed by the first non-skipped chunk.
            num_rows_to_skip = 0;
            num_values_to_skip = 0;
        }
    }

    // Add back the trailing nulls.
    debug_assert_eq!(target.len(), end_length - trailing_nulls);
    target.resize(end_length, B::zeroed());

    Ok(())
}

#[cfg(test)]
mod tests {
    use arrow::bitmap::Bitmap;
    use arrow::types::Bytes4Alignment4;

    use super::decode;
    use crate::parquet::encoding::hybrid_rle::{Encoder, HybridRleDecoder};

    /// Position: 0  1  2  3  4  5  6  7  8  9 10 11 12 13
    /// Validity: 1  0  0  0  1  0  0  0  1  0  0  1  1  1
    /// Value:    66          66          66       99 99 99
    ///           |--- chunk0 (3 values) ---||-- chunk1 --|
    ///
    /// Chunk0 owns rows 0-8 (9 rows), chunk1 owns rows 9-13 (5 rows).
    /// Positions 9 and 10 are nulls that belong to chunk1, so they get filled with 99.
    ///
    /// With `nth_set_bit_idx(length = 3, 0)` chunk0 would claim rows 0-10 and the null
    /// positions 9, 10 would wrongly be filled with 66 (the bug); with
    /// `nth_set_bit_idx(length - 1 = 2, 0) + 1` chunk0 claims rows 0-8 and positions
    /// 9, 10 are correctly filled with 99 by chunk1 (the fix).
    #[test]
    fn test_rle_decode_with_sparse_nulls() {
        // Bitmap bits (LSB first within each byte):
        // Byte 0: positions 0-7  = 1,0,0,0,1,0,0,0 = 0b00010001
        // Byte 1: positions 8-13 = 1,0,0,1,1,1     = 0b00111001
        let validity_bytes: Vec<u8> = vec![0b00010001, 0b00111001];
        let validity = Bitmap::try_new(validity_bytes, 14).unwrap();

        let mut encoded = Vec::new();
        u32::run_length_encode(&mut encoded, 3, 0, 1).unwrap(); // 3x dict index 0
        u32::run_length_encode(&mut encoded, 3, 1, 1).unwrap(); // 3x dict index 1

        let dict: &[Bytes4Alignment4] = bytemuck::cast_slice(&[66u32, 99u32]);
        let decoder = HybridRleDecoder::new(&encoded, 1, 6); // 6 total values

        let mut target = Vec::new();
        decode(decoder, dict, validity, &mut target, 0).unwrap();

        assert_eq!(target.len(), 14, "should have 14 rows");

        // Valid positions in chunk0 (dict[0] = 66): positions 0, 4, 8.
        assert_eq!(target[0], dict[0], "position 0 should be dict[0]");
        assert_eq!(target[4], dict[0], "position 4 should be dict[0]");
        assert_eq!(target[8], dict[0], "position 8 should be dict[0]");

        // Valid positions in chunk1 (dict[1] = 99): positions 11, 12, 13.
        assert_eq!(target[11], dict[1], "position 11 should be dict[1]");
        assert_eq!(target[12], dict[1], "position 12 should be dict[1]");
        assert_eq!(target[13], dict[1], "position 13 should be dict[1]");

        // Null positions 9 and 10 belong to chunk1, so they are filled with dict[1];
        // they would be dict[0] if chunk0 incorrectly claimed rows 0-10.
        assert_eq!(target[9], dict[1], "position 9 (null) should be dict[1]");
        assert_eq!(target[10], dict[1], "position 10 (null) should be dict[1]");
    }
}