Path: blob/main/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs
8512 views
use std::marker::PhantomData;12use arrow::array::{DictionaryArray, DictionaryKey, MutableBinaryViewArray, PrimitiveArray};3use arrow::bitmap::{Bitmap, BitmapBuilder};4use arrow::datatypes::ArrowDataType;5use polars_utils::vec::with_cast_mut_vec;67use super::binview::BinViewDecoder;8use super::utils::{self, Decoder, StateTranslation, dict_indices_decoder, freeze_validity};9use crate::parquet::encoding::Encoding;10use crate::parquet::encoding::hybrid_rle::HybridRleDecoder;11use crate::parquet::error::ParquetResult;12use crate::parquet::page::{DataPage, DictPage};13use crate::read::ParquetError;14use crate::read::deserialize::dictionary_encoded::IndexMapping;15use crate::read::expr::SpecializedParquetColumnExpr;1617impl<'a, T: DictionaryKey + IndexMapping<Output = T::AlignedBytes>>18StateTranslation<'a, CategoricalDecoder<T>> for HybridRleDecoder<'a>19{20type PlainDecoder = HybridRleDecoder<'a>;2122fn new(23_decoder: &CategoricalDecoder<T>,24page: &'a DataPage,25_dict: Option<&'a <CategoricalDecoder<T> as Decoder>::Dict>,26page_validity: Option<&Bitmap>,27) -> ParquetResult<Self> {28if !matches!(29page.encoding(),30Encoding::PlainDictionary | Encoding::RleDictionary31) {32return Err(utils::not_implemented(page));33}3435dict_indices_decoder(page, page_validity.map_or(0, |bm| bm.unset_bits()))36}37fn num_rows(&self) -> usize {38self.len()39}40}4142/// Special decoder for Polars Enum and Categorical's.43///44/// These are marked as special in the Arrow Field Metadata and they have the properly that for a45/// given row group all the values are in the dictionary page and all data pages are dictionary46/// encoded. 
/// Special decoder for Polars Enum and Categorical's.
///
/// These are marked as special in the Arrow Field Metadata and they have the property that for a
/// given row group all the values are in the dictionary page and all data pages are dictionary
/// encoded. This makes the job of decoding them extremely simple and fast.
pub struct CategoricalDecoder<T> {
    // Number of entries in the dictionary page. `usize::MAX` is a sentinel
    // until `deserialize_dict` has run.
    dict_size: usize,
    // Decoder used to materialize the dictionary values themselves (strings).
    decoder: BinViewDecoder,
    // Zero-sized marker tying this decoder to its dictionary key type `T`.
    key_type: PhantomData<T>,
}

impl<T> CategoricalDecoder<T> {
    pub fn new() -> Self {
        Self {
            // Sentinel value; overwritten with the real size once the
            // dictionary page is deserialized.
            dict_size: usize::MAX,
            decoder: BinViewDecoder::new_string(),
            key_type: PhantomData,
        }
    }
}

impl<T: DictionaryKey + IndexMapping<Output = T::AlignedBytes>> utils::Decoder
    for CategoricalDecoder<T>
{
    type Translation<'a> = HybridRleDecoder<'a>;
    type Dict = <BinViewDecoder as utils::Decoder>::Dict;
    // Decoded state: raw dictionary keys plus the validity being built.
    type DecodedState = (Vec<T>, BitmapBuilder);
    type Output = DictionaryArray<T>;

    fn with_capacity(&self, capacity: usize) -> Self::DecodedState {
        (
            Vec::<T>::with_capacity(capacity),
            BitmapBuilder::with_capacity(capacity),
        )
    }

    /// Evaluate a pushed-down predicate directly on the dictionary keys.
    ///
    /// Returns `Ok(false)` when the fast path cannot be used (page has
    /// validity), leaving the caller to fall back to row-wise evaluation.
    fn evaluate_predicate(
        &mut self,
        state: &utils::State<'_, Self>,
        _predicate: Option<&SpecializedParquetColumnExpr>,
        pred_true_mask: &mut BitmapBuilder,
        dict_mask: Option<&Bitmap>,
    ) -> ParquetResult<bool> {
        if state.page_validity.is_some() {
            // @Performance: implement validity aware
            return Ok(false);
        }

        // NOTE(review): `unwrap` assumes the caller always supplies a
        // dict_mask on this path — confirm against the call site.
        let dict_mask = dict_mask.unwrap();
        super::dictionary_encoded::predicate::decode(
            state.translation.clone(),
            dict_mask,
            pred_true_mask,
        )?;

        Ok(true)
    }

    /// Decode the dictionary page and record its size for later key bounds.
    fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult<Self::Dict> {
        let dict = self.decoder.deserialize_dict(page)?;
        self.dict_size = dict.len();
        Ok(dict)
    }

    /// Append an already-decoded array's keys and validity to the state.
    ///
    /// `additional` must be a `DictionaryArray<T>`; the downcast panics
    /// otherwise.
    fn extend_decoded(
        &self,
        decoded: &mut Self::DecodedState,
        additional: &dyn arrow::array::Array,
        is_optional: bool,
    ) -> ParquetResult<()> {
        let additional = additional
            .as_any()
            .downcast_ref::<DictionaryArray<T>>()
            .unwrap();
        decoded.0.extend(additional.keys().values().iter().copied());
        match additional.validity() {
            Some(v) => decoded.1.extend_from_bitmap(v),
            // Optional column without a mask: every appended row is valid.
            None if is_optional => decoded.1.extend_constant(additional.len(), true),
            None => {},
        }

        Ok(())
    }

    /// Assemble the final `DictionaryArray` from decoded keys and the
    /// dictionary values, rebuilding the values as a Utf8View array.
    fn finalize(
        &self,
        dtype: ArrowDataType,
        dict: Option<Self::Dict>,
        (values, validity): Self::DecodedState,
    ) -> ParquetResult<DictionaryArray<T>> {
        let validity = freeze_validity(validity);
        let dict = dict.unwrap();
        let keys = PrimitiveArray::new(T::PRIMITIVE.into(), values.into(), validity);

        // Rebuild the dictionary as a view array: reuse the original data
        // buffers and copy the views over without re-validating them.
        let mut view_dict = MutableBinaryViewArray::with_capacity(dict.len());
        let (views, buffers, _, _, _) = dict.into_inner();

        for buffer in buffers.iter() {
            view_dict.push_buffer(buffer.clone());
        }
        unsafe { view_dict.views_mut().extend(views.iter()) };
        unsafe { view_dict.set_total_bytes_len(views.iter().map(|v| v.length as usize).sum()) };
        let view_dict = view_dict.freeze();

        // SAFETY: This was checked during construction of the dictionary
        let dict = unsafe { view_dict.to_utf8view_unchecked() }.boxed();

        // SAFETY: This was checked during decoding
        Ok(unsafe { DictionaryArray::try_new_unchecked(dtype, keys, dict) }.unwrap())
    }

    fn extend_filtered_with_state(
        &mut self,
        state: utils::State<'_, Self>,
        decoded: &mut Self::DecodedState,
        filter: Option<super::Filter>,
        _chunks: &mut Vec<Self::Output>,
    ) -> ParquetResult<()> {
        // Decode directly into the key vector viewed as aligned bytes.
        with_cast_mut_vec::<T, T::AlignedBytes, _, _>(&mut decoded.0, |aligned_bytes_vec| {
            super::dictionary_encoded::decode_dict_dispatch(
                state.translation,
                // `.ok().unwrap()` assumes `dict_size` fits in `T`, i.e.
                // `deserialize_dict` ran and the dictionary is small enough
                // for the key type — presumably guaranteed upstream; the
                // `usize::MAX` sentinel would panic here.
                T::try_from(self.dict_size).ok().unwrap(),
                state.is_optional,
                state.page_validity.as_ref(),
                filter,
                &mut decoded.1,
                aligned_bytes_vec,
            )
        })
    }

    /// Constant extension is not supported for categoricals; reached only
    /// via pushed-down equality filters.
    fn extend_constant(
        &mut self,
        _decoded: &mut Self::DecodedState,
        _length: usize,
        _value: &crate::read::expr::ParquetScalar,
    ) -> ParquetResult<()> {
        Err(ParquetError::not_supported(
            "categorical with pushed-down equality filter",
        ))
    }
}