Path: blob/main/crates/polars-parquet/src/arrow/write/dictionary.rs
8475 views
use arrow::array::{1Array, BinaryViewArray, DictionaryArray, DictionaryKey, PrimitiveArray, Utf8ViewArray,2};3use arrow::bitmap::{Bitmap, MutableBitmap};4use arrow::compute::aggregate::estimated_bytes_size;5use arrow::datatypes::{ArrowDataType, IntegerType, PhysicalType};6use arrow::legacy::utils::CustomIterTools;7use arrow::trusted_len::TrustMyLength;8use arrow::types::NativeType;9use polars_buffer::Buffer;10use polars_compute::min_max::MinMaxKernel;11use polars_error::{PolarsResult, polars_bail};12use polars_utils::float16::pf16;1314use super::binary::{15build_statistics as binary_build_statistics, encode_plain as binary_encode_plain,16};17use super::fixed_size_binary::{18build_statistics as fixed_binary_build_statistics, encode_plain as fixed_binary_encode_plain,19};20use super::primitive::{21build_statistics as primitive_build_statistics, encode_plain as primitive_encode_plain,22};23use super::{24EncodeNullability, Nested, WriteOptions, binview, nested, row_slice_ranges, slice_parquet_array,25};26use crate::arrow::read::schema::is_nullable;27use crate::arrow::write::utils;28use crate::parquet::CowBuffer;29use crate::parquet::encoding::Encoding;30use crate::parquet::encoding::hybrid_rle::encode;31use crate::parquet::page::{DictPage, Page};32use crate::parquet::schema::types::PrimitiveType;33use crate::parquet::statistics::ParquetStatistics;34use crate::write::DynIter;3536trait MinMaxThreshold {37const DELTA_THRESHOLD: usize;38const BITMASK_THRESHOLD: usize;3940fn from_start_and_offset(start: Self, offset: usize) -> Self;41}4243macro_rules! minmaxthreshold_impls {44($($signed:ty, $unsigned:ty => $threshold:literal, $bm_threshold:expr,)+) => {45$(46impl MinMaxThreshold for $signed {47const DELTA_THRESHOLD: usize = $threshold;48const BITMASK_THRESHOLD: usize = $bm_threshold;4950fn from_start_and_offset(start: Self, offset: usize) -> Self {51start + ((offset as $unsigned) as $signed)52}53}54impl MinMaxThreshold for $unsigned {55const DELTA_THRESHOLD: usize = $threshold;56const BITMASK_THRESHOLD: usize = $bm_threshold;5758fn from_start_and_offset(start: Self, offset: usize) -> Self {59start + (offset as $unsigned)60}61}62)+63};64}6566minmaxthreshold_impls! {67i8, u8 => 16, u8::MAX as usize,68i16, u16 => 256, u16::MAX as usize,69i32, u32 => 512, u16::MAX as usize,70i64, u64 => 2048, u16::MAX as usize,71}7273enum DictionaryDecision {74NotWorth,75TryAgain,76Found(DictionaryArray<u32>),77}7879fn min_max_integer_encode_as_dictionary_optional<'a, E, T>(80array: &'a dyn Array,81) -> DictionaryDecision82where83E: std::fmt::Debug,84T: NativeType85+ MinMaxThreshold86+ std::cmp::Ord87+ TryInto<u32, Error = E>88+ std::ops::Sub<T, Output = T>89+ num_traits::CheckedSub90+ num_traits::cast::AsPrimitive<usize>,91std::ops::RangeInclusive<T>: Iterator<Item = T>,92PrimitiveArray<T>: MinMaxKernel<Scalar<'a> = T>,93{94let min_max = <PrimitiveArray<T> as MinMaxKernel>::min_max_ignore_nan_kernel(95array.as_any().downcast_ref().unwrap(),96);9798let Some((min, max)) = min_max else {99return DictionaryDecision::TryAgain;100};101102debug_assert!(max >= min, "{max} >= {min}");103let Some(diff) = max.checked_sub(&min) else {104return DictionaryDecision::TryAgain;105};106107let diff = diff.as_();108109if diff > T::BITMASK_THRESHOLD {110return DictionaryDecision::TryAgain;111}112113let mut seen_mask = MutableBitmap::from_len_zeroed(diff + 1);114115let array = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();116117if array.has_nulls() {118for v in array.non_null_values_iter() {119let offset = (v - min).as_();120debug_assert!(offset <= diff);121122unsafe {123seen_mask.set_unchecked(offset, true);124}125}126} else {127for v in array.values_iter() {128let offset = (*v - min).as_();129debug_assert!(offset <= diff);130131unsafe {132seen_mask.set_unchecked(offset, true);133}134}135}136137let cardinality = seen_mask.set_bits();138139let mut is_worth_it = false;140141is_worth_it |= cardinality <= T::DELTA_THRESHOLD;142is_worth_it |= (cardinality as f64) / (array.len() as f64) < 0.75;143144if !is_worth_it {145return DictionaryDecision::NotWorth;146}147148let seen_mask = seen_mask.freeze();149150// SAFETY: We just did the calculation for this.151let indexes = seen_mask152.true_idx_iter()153.map(|idx| T::from_start_and_offset(min, idx));154let indexes = unsafe { TrustMyLength::new(indexes, cardinality) };155let indexes = indexes.collect_trusted::<Vec<_>>();156157let mut lookup = vec![0u16; diff + 1];158159for (i, &idx) in indexes.iter().enumerate() {160lookup[(idx - min).as_()] = i as u16;161}162163use ArrowDataType as DT;164let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), indexes.into(), None);165let values = Box::new(values);166167let keys: Buffer<u32> = array168.as_any()169.downcast_ref::<PrimitiveArray<T>>()170.unwrap()171.values()172.iter()173.map(|v| {174// @NOTE:175// Since the values might contain nulls which have a undefined value. We just176// clamp the values to between the min and max value. This way, they will still177// be valid dictionary keys.178let idx = *v.clamp(&min, &max) - min;179let value = unsafe { lookup.get_unchecked(idx.as_()) };180(*value).into()181})182.collect();183184let keys = PrimitiveArray::new(DT::UInt32, keys, array.validity().cloned());185DictionaryDecision::Found(186DictionaryArray::<u32>::try_new(187ArrowDataType::Dictionary(188IntegerType::UInt32,189Box::new(DT::from(T::PRIMITIVE)),190false, // @TODO: This might be able to be set to true?191),192keys,193values,194)195.unwrap(),196)197}198199pub(crate) fn encode_as_dictionary_optional(200array: &dyn Array,201nested: &[Nested],202type_: PrimitiveType,203options: WriteOptions,204) -> Option<PolarsResult<DynIter<'static, PolarsResult<Page>>>> {205if array.is_empty() {206let array = DictionaryArray::<u32>::new_empty(ArrowDataType::Dictionary(207IntegerType::UInt32,208Box::new(array.dtype().clone()),209false, // @TODO: This might be able to be set to true?210));211212return Some(array_to_pages(213&array,214type_,215nested,216options,217Encoding::RleDictionary,218));219}220221use arrow::types::PrimitiveType as PT;222let fast_dictionary = match array.dtype().to_physical_type() {223PhysicalType::Primitive(pt) => match pt {224PT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array),225PT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array),226PT::Int32 => min_max_integer_encode_as_dictionary_optional::<_, i32>(array),227PT::Int64 => min_max_integer_encode_as_dictionary_optional::<_, i64>(array),228PT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array),229PT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array),230PT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array),231PT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array),232_ => DictionaryDecision::TryAgain,233},234_ => DictionaryDecision::TryAgain,235};236237match fast_dictionary {238DictionaryDecision::NotWorth => return None,239DictionaryDecision::Found(dictionary_array) => {240return Some(array_to_pages(241&dictionary_array,242type_,243nested,244options,245Encoding::RleDictionary,246));247},248DictionaryDecision::TryAgain => {},249}250251let dtype = Box::new(array.dtype().clone());252253let estimated_cardinality = polars_compute::cardinality::estimate_cardinality(array);254255if array.len() > 128 && (estimated_cardinality as f64) / (array.len() as f64) > 0.75 {256return None;257}258259// This does the group by.260let array = polars_compute::cast::cast(261array,262&ArrowDataType::Dictionary(IntegerType::UInt32, dtype, false),263Default::default(),264)265.ok()?;266267let array = array268.as_any()269.downcast_ref::<DictionaryArray<u32>>()270.unwrap();271272Some(array_to_pages(273array,274type_,275nested,276options,277Encoding::RleDictionary,278))279}280281fn serialize_def_levels_simple(282validity: Option<&Bitmap>,283length: usize,284is_optional: bool,285options: WriteOptions,286buffer: &mut Vec<u8>,287) -> PolarsResult<()> {288utils::write_def_levels(buffer, is_optional, validity, length, options.version)289}290291fn serialize_keys_values<K: DictionaryKey>(292array: &DictionaryArray<K>,293validity: Option<&Bitmap>,294buffer: &mut Vec<u8>,295) -> PolarsResult<()> {296let keys = array.keys_values_iter().map(|x| x as u32);297if let Some(validity) = validity {298// discard indices whose values are null.299let keys = keys300.zip(validity.iter())301.filter(|&(_key, is_valid)| is_valid)302.map(|(key, _is_valid)| key);303let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64);304305let keys = utils::ExactSizedIter::new(keys, array.len() - validity.unset_bits());306307// num_bits as a single byte308buffer.push(num_bits as u8);309310// followed by the encoded indices.311Ok(encode::<u32, _, _>(buffer, keys, num_bits)?)312} else {313let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64);314315// num_bits as a single byte316buffer.push(num_bits as u8);317318// followed by the encoded indices.319Ok(encode::<u32, _, _>(buffer, keys, num_bits)?)320}321}322323fn serialize_levels(324validity: Option<&Bitmap>,325length: usize,326type_: &PrimitiveType,327nested: &[Nested],328options: WriteOptions,329buffer: &mut Vec<u8>,330) -> PolarsResult<(usize, usize)> {331if nested.len() == 1 {332let is_optional = is_nullable(&type_.field_info);333serialize_def_levels_simple(validity, length, is_optional, options, buffer)?;334let definition_levels_byte_length = buffer.len();335Ok((0, definition_levels_byte_length))336} else {337nested::write_rep_and_def(options.version, nested, buffer)338}339}340341fn normalized_validity<K: DictionaryKey>(array: &DictionaryArray<K>) -> Option<Bitmap> {342match (array.keys().validity(), array.values().validity()) {343(None, None) => None,344(keys, None) => keys.cloned(),345// The values can have a different length than the keys346(_, Some(_values)) => {347let iter = (0..array.len()).map(|i| unsafe { !array.is_null_unchecked(i) });348MutableBitmap::from_trusted_len_iter(iter).into()349},350}351}352353fn serialize_keys<K: DictionaryKey>(354array: &DictionaryArray<K>,355type_: PrimitiveType,356nested: &[Nested],357statistics: Option<ParquetStatistics>,358options: WriteOptions,359) -> DynIter<'static, PolarsResult<Page>> {360let number_of_rows = nested[0].len();361let byte_size = estimated_bytes_size(array.keys());362363let array = array.clone();364let nested = nested.to_vec();365366let pages =367row_slice_ranges(number_of_rows, byte_size, options).map(move |(offset, length)| {368let mut sliced_array = array.clone();369let mut sliced_nested = nested.clone();370slice_parquet_array(&mut sliced_array, &mut sliced_nested, offset, length);371372serialize_keys_range(373&sliced_array,374&type_,375&sliced_nested,376statistics.clone(),377options,378)379});380381DynIter::new(pages)382}383384fn serialize_keys_range<K: DictionaryKey>(385array: &DictionaryArray<K>,386type_: &PrimitiveType,387nested: &[Nested],388statistics: Option<ParquetStatistics>,389options: WriteOptions,390) -> PolarsResult<Page> {391let mut buffer = vec![];392393// Parquet only accepts a single validity - we "&" the validities into a single one394// and ignore keys whose _value_ is null.395let validity = normalized_validity(array);396397let (repetition_levels_byte_length, definition_levels_byte_length) = serialize_levels(398validity.as_ref(),399array.len(),400type_,401nested,402options,403&mut buffer,404)?;405406serialize_keys_values(array, validity.as_ref(), &mut buffer)?;407408let (num_values, num_rows) = if nested.len() == 1 {409(array.len(), array.len())410} else {411(nested::num_values(nested), nested[0].len())412};413414utils::build_plain_page(415buffer,416num_values,417num_rows,418array.null_count(),419repetition_levels_byte_length,420definition_levels_byte_length,421statistics,422type_.clone(),423options,424Encoding::RleDictionary,425)426.map(Page::Data)427}428429macro_rules! dyn_prim {430($from:ty, $to:ty, $array:expr, $options:expr, $type_:expr) => {{431let values = $array.values().as_any().downcast_ref().unwrap();432433let buffer =434primitive_encode_plain::<$from, $to>(values, EncodeNullability::new(false), vec![]);435436let stats: Option<ParquetStatistics> = if !$options.statistics.is_empty() {437let mut stats = primitive_build_statistics::<$from, $to>(438values,439$type_.clone(),440&$options.statistics,441);442stats.null_count = Some($array.null_count() as i64);443Some(stats.serialize())444} else {445None446};447(448DictPage::new(CowBuffer::Owned(buffer), values.len(), false),449stats,450)451}};452}453454pub fn array_to_pages<K: DictionaryKey>(455array: &DictionaryArray<K>,456type_: PrimitiveType,457nested: &[Nested],458options: WriteOptions,459encoding: Encoding,460) -> PolarsResult<DynIter<'static, PolarsResult<Page>>> {461match encoding {462Encoding::PlainDictionary | Encoding::RleDictionary => {463// write DictPage464let (dict_page, mut statistics): (_, Option<ParquetStatistics>) = match array465.values()466.dtype()467.to_storage()468{469ArrowDataType::Int8 => dyn_prim!(i8, i32, array, options, type_),470ArrowDataType::Int16 => dyn_prim!(i16, i32, array, options, type_),471ArrowDataType::Int32 | ArrowDataType::Date32 | ArrowDataType::Time32(_) => {472dyn_prim!(i32, i32, array, options, type_)473},474ArrowDataType::Int64475| ArrowDataType::Date64476| ArrowDataType::Time64(_)477| ArrowDataType::Timestamp(_, _)478| ArrowDataType::Duration(_) => dyn_prim!(i64, i64, array, options, type_),479ArrowDataType::UInt8 => dyn_prim!(u8, i32, array, options, type_),480ArrowDataType::UInt16 => dyn_prim!(u16, i32, array, options, type_),481ArrowDataType::UInt32 => dyn_prim!(u32, i32, array, options, type_),482ArrowDataType::UInt64 => dyn_prim!(u64, i64, array, options, type_),483ArrowDataType::Float16 => dyn_prim!(pf16, f32, array, options, type_),484ArrowDataType::Float32 => dyn_prim!(f32, f32, array, options, type_),485ArrowDataType::Float64 => dyn_prim!(f64, f64, array, options, type_),486ArrowDataType::LargeUtf8 => {487let array = polars_compute::cast::cast(488array.values().as_ref(),489&ArrowDataType::LargeBinary,490Default::default(),491)492.unwrap();493let array = array.as_any().downcast_ref().unwrap();494495let mut buffer = vec![];496binary_encode_plain::<i64>(array, EncodeNullability::Required, &mut buffer);497let stats = if options.has_statistics() {498Some(binary_build_statistics(499array,500type_.clone(),501&options.statistics,502))503} else {504None505};506(507DictPage::new(CowBuffer::Owned(buffer), array.len(), false),508stats,509)510},511ArrowDataType::BinaryView => {512let array = array513.values()514.as_any()515.downcast_ref::<BinaryViewArray>()516.unwrap();517let mut buffer = vec![];518binview::encode_plain(array, EncodeNullability::Required, &mut buffer);519520let stats = if options.has_statistics() {521Some(binview::build_statistics(522array,523type_.clone(),524&options.statistics,525))526} else {527None528};529(530DictPage::new(CowBuffer::Owned(buffer), array.len(), false),531stats,532)533},534ArrowDataType::Utf8View => {535let array = array536.values()537.as_any()538.downcast_ref::<Utf8ViewArray>()539.unwrap()540.to_binview();541let mut buffer = vec![];542binview::encode_plain(&array, EncodeNullability::Required, &mut buffer);543544let stats = if options.has_statistics() {545Some(binview::build_statistics(546&array,547type_.clone(),548&options.statistics,549))550} else {551None552};553(554DictPage::new(CowBuffer::Owned(buffer), array.len(), false),555stats,556)557},558ArrowDataType::LargeBinary => {559let values = array.values().as_any().downcast_ref().unwrap();560561let mut buffer = vec![];562binary_encode_plain::<i64>(values, EncodeNullability::Required, &mut buffer);563let stats = if options.has_statistics() {564Some(binary_build_statistics(565values,566type_.clone(),567&options.statistics,568))569} else {570None571};572(573DictPage::new(CowBuffer::Owned(buffer), values.len(), false),574stats,575)576},577ArrowDataType::FixedSizeBinary(_) => {578let mut buffer = vec![];579let array = array.values().as_any().downcast_ref().unwrap();580fixed_binary_encode_plain(array, EncodeNullability::Required, &mut buffer);581let stats = if options.has_statistics() {582let stats = fixed_binary_build_statistics(583array,584type_.clone(),585&options.statistics,586);587Some(stats.serialize())588} else {589None590};591(592DictPage::new(CowBuffer::Owned(buffer), array.len(), false),593stats,594)595},596other => {597polars_bail!(598nyi =599"Writing dictionary arrays to parquet only support data type {other:?}"600)601},602};603604if let Some(stats) = &mut statistics {605stats.null_count = Some(array.null_count() as i64)606}607608// write DataPages pointing to DictPage609let data_pages = serialize_keys(array, type_, nested, statistics, options);610611Ok(DynIter::new(612std::iter::once(Ok(Page::Dict(dict_page))).chain(data_pages),613))614},615_ => polars_bail!(nyi = "Dictionary arrays only support dictionary encoding"),616}617}618619620