Path: blob/main/crates/polars-parquet/src/arrow/write/dictionary.rs
6940 views
use arrow::array::{1Array, BinaryViewArray, DictionaryArray, DictionaryKey, PrimitiveArray, Utf8ViewArray,2};3use arrow::bitmap::{Bitmap, MutableBitmap};4use arrow::buffer::Buffer;5use arrow::datatypes::{ArrowDataType, IntegerType, PhysicalType};6use arrow::legacy::utils::CustomIterTools;7use arrow::trusted_len::TrustMyLength;8use arrow::types::NativeType;9use polars_compute::min_max::MinMaxKernel;10use polars_error::{PolarsResult, polars_bail};1112use super::binary::{13build_statistics as binary_build_statistics, encode_plain as binary_encode_plain,14};15use super::fixed_size_binary::{16build_statistics as fixed_binary_build_statistics, encode_plain as fixed_binary_encode_plain,17};18use super::pages::PrimitiveNested;19use super::primitive::{20build_statistics as primitive_build_statistics, encode_plain as primitive_encode_plain,21};22use super::{EncodeNullability, Nested, WriteOptions, binview, nested};23use crate::arrow::read::schema::is_nullable;24use crate::arrow::write::{slice_nested_leaf, utils};25use crate::parquet::CowBuffer;26use crate::parquet::encoding::Encoding;27use crate::parquet::encoding::hybrid_rle::encode;28use crate::parquet::page::{DictPage, Page};29use crate::parquet::schema::types::PrimitiveType;30use crate::parquet::statistics::ParquetStatistics;31use crate::write::DynIter;3233trait MinMaxThreshold {34const DELTA_THRESHOLD: usize;35const BITMASK_THRESHOLD: usize;3637fn from_start_and_offset(start: Self, offset: usize) -> Self;38}3940macro_rules! minmaxthreshold_impls {41($($signed:ty, $unsigned:ty => $threshold:literal, $bm_threshold:expr,)+) => {42$(43impl MinMaxThreshold for $signed {44const DELTA_THRESHOLD: usize = $threshold;45const BITMASK_THRESHOLD: usize = $bm_threshold;4647fn from_start_and_offset(start: Self, offset: usize) -> Self {48start + ((offset as $unsigned) as $signed)49}50}51impl MinMaxThreshold for $unsigned {52const DELTA_THRESHOLD: usize = $threshold;53const BITMASK_THRESHOLD: usize = $bm_threshold;5455fn from_start_and_offset(start: Self, offset: usize) -> Self {56start + (offset as $unsigned)57}58}59)+60};61}6263minmaxthreshold_impls! {64i8, u8 => 16, u8::MAX as usize,65i16, u16 => 256, u16::MAX as usize,66i32, u32 => 512, u16::MAX as usize,67i64, u64 => 2048, u16::MAX as usize,68}6970enum DictionaryDecision {71NotWorth,72TryAgain,73Found(DictionaryArray<u32>),74}7576fn min_max_integer_encode_as_dictionary_optional<'a, E, T>(77array: &'a dyn Array,78) -> DictionaryDecision79where80E: std::fmt::Debug,81T: NativeType82+ MinMaxThreshold83+ std::cmp::Ord84+ TryInto<u32, Error = E>85+ std::ops::Sub<T, Output = T>86+ num_traits::CheckedSub87+ num_traits::cast::AsPrimitive<usize>,88std::ops::RangeInclusive<T>: Iterator<Item = T>,89PrimitiveArray<T>: MinMaxKernel<Scalar<'a> = T>,90{91let min_max = <PrimitiveArray<T> as MinMaxKernel>::min_max_ignore_nan_kernel(92array.as_any().downcast_ref().unwrap(),93);9495let Some((min, max)) = min_max else {96return DictionaryDecision::TryAgain;97};9899debug_assert!(max >= min, "{max} >= {min}");100let Some(diff) = max.checked_sub(&min) else {101return DictionaryDecision::TryAgain;102};103104let diff = diff.as_();105106if diff > T::BITMASK_THRESHOLD {107return DictionaryDecision::TryAgain;108}109110let mut seen_mask = MutableBitmap::from_len_zeroed(diff + 1);111112let array = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();113114if array.has_nulls() {115for v in array.non_null_values_iter() {116let offset = (v - min).as_();117debug_assert!(offset <= diff);118119unsafe {120seen_mask.set_unchecked(offset, true);121}122}123} else {124for v in array.values_iter() {125let offset = (*v - min).as_();126debug_assert!(offset <= diff);127128unsafe {129seen_mask.set_unchecked(offset, true);130}131}132}133134let cardinality = seen_mask.set_bits();135136let mut is_worth_it = false;137138is_worth_it |= cardinality <= T::DELTA_THRESHOLD;139is_worth_it |= (cardinality as f64) / (array.len() as f64) < 0.75;140141if !is_worth_it {142return DictionaryDecision::NotWorth;143}144145let seen_mask = seen_mask.freeze();146147// SAFETY: We just did the calculation for this.148let indexes = seen_mask149.true_idx_iter()150.map(|idx| T::from_start_and_offset(min, idx));151let indexes = unsafe { TrustMyLength::new(indexes, cardinality) };152let indexes = indexes.collect_trusted::<Vec<_>>();153154let mut lookup = vec![0u16; diff + 1];155156for (i, &idx) in indexes.iter().enumerate() {157lookup[(idx - min).as_()] = i as u16;158}159160use ArrowDataType as DT;161let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), indexes.into(), None);162let values = Box::new(values);163164let keys: Buffer<u32> = array165.as_any()166.downcast_ref::<PrimitiveArray<T>>()167.unwrap()168.values()169.iter()170.map(|v| {171// @NOTE:172// Since the values might contain nulls which have a undefined value. We just173// clamp the values to between the min and max value. This way, they will still174// be valid dictionary keys.175let idx = *v.clamp(&min, &max) - min;176let value = unsafe { lookup.get_unchecked(idx.as_()) };177(*value).into()178})179.collect();180181let keys = PrimitiveArray::new(DT::UInt32, keys, array.validity().cloned());182DictionaryDecision::Found(183DictionaryArray::<u32>::try_new(184ArrowDataType::Dictionary(185IntegerType::UInt32,186Box::new(DT::from(T::PRIMITIVE)),187false, // @TODO: This might be able to be set to true?188),189keys,190values,191)192.unwrap(),193)194}195196pub(crate) fn encode_as_dictionary_optional(197array: &dyn Array,198nested: &[Nested],199type_: PrimitiveType,200options: WriteOptions,201) -> Option<PolarsResult<DynIter<'static, PolarsResult<Page>>>> {202if array.is_empty() {203let array = DictionaryArray::<u32>::new_empty(ArrowDataType::Dictionary(204IntegerType::UInt32,205Box::new(array.dtype().clone()),206false, // @TODO: This might be able to be set to true?207));208209return Some(array_to_pages(210&array,211type_,212nested,213options,214Encoding::RleDictionary,215));216}217218use arrow::types::PrimitiveType as PT;219let fast_dictionary = match array.dtype().to_physical_type() {220PhysicalType::Primitive(pt) => match pt {221PT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array),222PT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array),223PT::Int32 => min_max_integer_encode_as_dictionary_optional::<_, i32>(array),224PT::Int64 => min_max_integer_encode_as_dictionary_optional::<_, i64>(array),225PT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array),226PT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array),227PT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array),228PT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array),229_ => DictionaryDecision::TryAgain,230},231_ => DictionaryDecision::TryAgain,232};233234match fast_dictionary {235DictionaryDecision::NotWorth => return None,236DictionaryDecision::Found(dictionary_array) => {237return Some(array_to_pages(238&dictionary_array,239type_,240nested,241options,242Encoding::RleDictionary,243));244},245DictionaryDecision::TryAgain => {},246}247248let dtype = Box::new(array.dtype().clone());249250let estimated_cardinality = polars_compute::cardinality::estimate_cardinality(array);251252if array.len() > 128 && (estimated_cardinality as f64) / (array.len() as f64) > 0.75 {253return None;254}255256// This does the group by.257let array = polars_compute::cast::cast(258array,259&ArrowDataType::Dictionary(IntegerType::UInt32, dtype, false),260Default::default(),261)262.ok()?;263264let array = array265.as_any()266.downcast_ref::<DictionaryArray<u32>>()267.unwrap();268269Some(array_to_pages(270array,271type_,272nested,273options,274Encoding::RleDictionary,275))276}277278fn serialize_def_levels_simple(279validity: Option<&Bitmap>,280length: usize,281is_optional: bool,282options: WriteOptions,283buffer: &mut Vec<u8>,284) -> PolarsResult<()> {285utils::write_def_levels(buffer, is_optional, validity, length, options.version)286}287288fn serialize_keys_values<K: DictionaryKey>(289array: &DictionaryArray<K>,290validity: Option<&Bitmap>,291buffer: &mut Vec<u8>,292) -> PolarsResult<()> {293let keys = array.keys_values_iter().map(|x| x as u32);294if let Some(validity) = validity {295// discard indices whose values are null.296let keys = keys297.zip(validity.iter())298.filter(|&(_key, is_valid)| is_valid)299.map(|(key, _is_valid)| key);300let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64);301302let keys = utils::ExactSizedIter::new(keys, array.len() - validity.unset_bits());303304// num_bits as a single byte305buffer.push(num_bits as u8);306307// followed by the encoded indices.308Ok(encode::<u32, _, _>(buffer, keys, num_bits)?)309} else {310let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64);311312// num_bits as a single byte313buffer.push(num_bits as u8);314315// followed by the encoded indices.316Ok(encode::<u32, _, _>(buffer, keys, num_bits)?)317}318}319320fn serialize_levels(321validity: Option<&Bitmap>,322length: usize,323type_: &PrimitiveType,324nested: &[Nested],325options: WriteOptions,326buffer: &mut Vec<u8>,327) -> PolarsResult<(usize, usize)> {328if nested.len() == 1 {329let is_optional = is_nullable(&type_.field_info);330serialize_def_levels_simple(validity, length, is_optional, options, buffer)?;331let definition_levels_byte_length = buffer.len();332Ok((0, definition_levels_byte_length))333} else {334nested::write_rep_and_def(options.version, nested, buffer)335}336}337338fn normalized_validity<K: DictionaryKey>(array: &DictionaryArray<K>) -> Option<Bitmap> {339match (array.keys().validity(), array.values().validity()) {340(None, None) => None,341(keys, None) => keys.cloned(),342// The values can have a different length than the keys343(_, Some(_values)) => {344let iter = (0..array.len()).map(|i| unsafe { !array.is_null_unchecked(i) });345MutableBitmap::from_trusted_len_iter(iter).into()346},347}348}349350fn serialize_keys<K: DictionaryKey>(351array: &DictionaryArray<K>,352type_: PrimitiveType,353nested: &[Nested],354statistics: Option<ParquetStatistics>,355options: WriteOptions,356) -> PolarsResult<Page> {357let mut buffer = vec![];358359let (start, len) = slice_nested_leaf(nested);360361let mut nested = nested.to_vec();362let array = array.clone().sliced(start, len);363if let Some(Nested::Primitive(PrimitiveNested { length, .. })) = nested.last_mut() {364*length = len;365} else {366unreachable!("")367}368// Parquet only accepts a single validity - we "&" the validities into a single one369// and ignore keys whose _value_ is null.370// It's important that we slice before normalizing.371let validity = normalized_validity(&array);372373let (repetition_levels_byte_length, definition_levels_byte_length) = serialize_levels(374validity.as_ref(),375array.len(),376&type_,377&nested,378options,379&mut buffer,380)?;381382serialize_keys_values(&array, validity.as_ref(), &mut buffer)?;383384let (num_values, num_rows) = if nested.len() == 1 {385(array.len(), array.len())386} else {387(nested::num_values(&nested), nested[0].len())388};389390utils::build_plain_page(391buffer,392num_values,393num_rows,394array.null_count(),395repetition_levels_byte_length,396definition_levels_byte_length,397statistics,398type_,399options,400Encoding::RleDictionary,401)402.map(Page::Data)403}404405macro_rules! dyn_prim {406($from:ty, $to:ty, $array:expr, $options:expr, $type_:expr) => {{407let values = $array.values().as_any().downcast_ref().unwrap();408409let buffer =410primitive_encode_plain::<$from, $to>(values, EncodeNullability::new(false), vec![]);411412let stats: Option<ParquetStatistics> = if !$options.statistics.is_empty() {413let mut stats = primitive_build_statistics::<$from, $to>(414values,415$type_.clone(),416&$options.statistics,417);418stats.null_count = Some($array.null_count() as i64);419Some(stats.serialize())420} else {421None422};423(424DictPage::new(CowBuffer::Owned(buffer), values.len(), false),425stats,426)427}};428}429430pub fn array_to_pages<K: DictionaryKey>(431array: &DictionaryArray<K>,432type_: PrimitiveType,433nested: &[Nested],434options: WriteOptions,435encoding: Encoding,436) -> PolarsResult<DynIter<'static, PolarsResult<Page>>> {437match encoding {438Encoding::PlainDictionary | Encoding::RleDictionary => {439// write DictPage440let (dict_page, mut statistics): (_, Option<ParquetStatistics>) = match array441.values()442.dtype()443.to_logical_type()444{445ArrowDataType::Int8 => dyn_prim!(i8, i32, array, options, type_),446ArrowDataType::Int16 => dyn_prim!(i16, i32, array, options, type_),447ArrowDataType::Int32 | ArrowDataType::Date32 | ArrowDataType::Time32(_) => {448dyn_prim!(i32, i32, array, options, type_)449},450ArrowDataType::Int64451| ArrowDataType::Date64452| ArrowDataType::Time64(_)453| ArrowDataType::Timestamp(_, _)454| ArrowDataType::Duration(_) => dyn_prim!(i64, i64, array, options, type_),455ArrowDataType::UInt8 => dyn_prim!(u8, i32, array, options, type_),456ArrowDataType::UInt16 => dyn_prim!(u16, i32, array, options, type_),457ArrowDataType::UInt32 => dyn_prim!(u32, i32, array, options, type_),458ArrowDataType::UInt64 => dyn_prim!(u64, i64, array, options, type_),459ArrowDataType::Float32 => dyn_prim!(f32, f32, array, options, type_),460ArrowDataType::Float64 => dyn_prim!(f64, f64, array, options, type_),461ArrowDataType::LargeUtf8 => {462let array = polars_compute::cast::cast(463array.values().as_ref(),464&ArrowDataType::LargeBinary,465Default::default(),466)467.unwrap();468let array = array.as_any().downcast_ref().unwrap();469470let mut buffer = vec![];471binary_encode_plain::<i64>(array, EncodeNullability::Required, &mut buffer);472let stats = if options.has_statistics() {473Some(binary_build_statistics(474array,475type_.clone(),476&options.statistics,477))478} else {479None480};481(482DictPage::new(CowBuffer::Owned(buffer), array.len(), false),483stats,484)485},486ArrowDataType::BinaryView => {487let array = array488.values()489.as_any()490.downcast_ref::<BinaryViewArray>()491.unwrap();492let mut buffer = vec![];493binview::encode_plain(array, EncodeNullability::Required, &mut buffer);494495let stats = if options.has_statistics() {496Some(binview::build_statistics(497array,498type_.clone(),499&options.statistics,500))501} else {502None503};504(505DictPage::new(CowBuffer::Owned(buffer), array.len(), false),506stats,507)508},509ArrowDataType::Utf8View => {510let array = array511.values()512.as_any()513.downcast_ref::<Utf8ViewArray>()514.unwrap()515.to_binview();516let mut buffer = vec![];517binview::encode_plain(&array, EncodeNullability::Required, &mut buffer);518519let stats = if options.has_statistics() {520Some(binview::build_statistics(521&array,522type_.clone(),523&options.statistics,524))525} else {526None527};528(529DictPage::new(CowBuffer::Owned(buffer), array.len(), false),530stats,531)532},533ArrowDataType::LargeBinary => {534let values = array.values().as_any().downcast_ref().unwrap();535536let mut buffer = vec![];537binary_encode_plain::<i64>(values, EncodeNullability::Required, &mut buffer);538let stats = if options.has_statistics() {539Some(binary_build_statistics(540values,541type_.clone(),542&options.statistics,543))544} else {545None546};547(548DictPage::new(CowBuffer::Owned(buffer), values.len(), false),549stats,550)551},552ArrowDataType::FixedSizeBinary(_) => {553let mut buffer = vec![];554let array = array.values().as_any().downcast_ref().unwrap();555fixed_binary_encode_plain(array, EncodeNullability::Required, &mut buffer);556let stats = if options.has_statistics() {557let stats = fixed_binary_build_statistics(558array,559type_.clone(),560&options.statistics,561);562Some(stats.serialize())563} else {564None565};566(567DictPage::new(CowBuffer::Owned(buffer), array.len(), false),568stats,569)570},571other => {572polars_bail!(573nyi =574"Writing dictionary arrays to parquet only support data type {other:?}"575)576},577};578579if let Some(stats) = &mut statistics {580stats.null_count = Some(array.null_count() as i64)581}582583// write DataPage pointing to DictPage584let data_page =585serialize_keys(array, type_, nested, statistics, options)?.unwrap_data();586587Ok(DynIter::new(588[Page::Dict(dict_page), Page::Data(data_page)]589.into_iter()590.map(Ok),591))592},593_ => polars_bail!(nyi = "Dictionary arrays only support dictionary encoding"),594}595}596597598