Path: blob/main/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs
6940 views
use std::marker::PhantomData;12use arrow::array::{DictionaryArray, DictionaryKey, MutableBinaryViewArray, PrimitiveArray};3use arrow::bitmap::{Bitmap, BitmapBuilder};4use arrow::datatypes::ArrowDataType;5use polars_utils::vec::with_cast_mut_vec;67use super::PredicateFilter;8use super::binview::BinViewDecoder;9use super::utils::{self, Decoder, StateTranslation, dict_indices_decoder, freeze_validity};10use crate::parquet::encoding::Encoding;11use crate::parquet::encoding::hybrid_rle::HybridRleDecoder;12use crate::parquet::error::ParquetResult;13use crate::parquet::page::{DataPage, DictPage};14use crate::read::deserialize::dictionary_encoded::IndexMapping;1516impl<'a, T: DictionaryKey + IndexMapping<Output = T::AlignedBytes>>17StateTranslation<'a, CategoricalDecoder<T>> for HybridRleDecoder<'a>18{19type PlainDecoder = HybridRleDecoder<'a>;2021fn new(22_decoder: &CategoricalDecoder<T>,23page: &'a DataPage,24_dict: Option<&'a <CategoricalDecoder<T> as Decoder>::Dict>,25page_validity: Option<&Bitmap>,26) -> ParquetResult<Self> {27if !matches!(28page.encoding(),29Encoding::PlainDictionary | Encoding::RleDictionary30) {31return Err(utils::not_implemented(page));32}3334dict_indices_decoder(page, page_validity.map_or(0, |bm| bm.unset_bits()))35}36fn num_rows(&self) -> usize {37self.len()38}39}4041/// Special decoder for Polars Enum and Categorical's.42///43/// These are marked as special in the Arrow Field Metadata and they have the properly that for a44/// given row group all the values are in the dictionary page and all data pages are dictionary45/// encoded. This makes the job of decoding them extremely simple and fast.46pub struct CategoricalDecoder<T> {47dict_size: usize,48decoder: BinViewDecoder,49key_type: PhantomData<T>,50}5152impl<T> CategoricalDecoder<T> {53pub fn new() -> Self {54Self {55dict_size: usize::MAX,56decoder: BinViewDecoder::new_string(),57key_type: PhantomData,58}59}60}6162impl<T: DictionaryKey + IndexMapping<Output = T::AlignedBytes>> utils::Decoder63for CategoricalDecoder<T>64{65type Translation<'a> = HybridRleDecoder<'a>;66type Dict = <BinViewDecoder as utils::Decoder>::Dict;67type DecodedState = (Vec<T>, BitmapBuilder);68type Output = DictionaryArray<T>;6970fn with_capacity(&self, capacity: usize) -> Self::DecodedState {71(72Vec::<T>::with_capacity(capacity),73BitmapBuilder::with_capacity(capacity),74)75}7677fn has_predicate_specialization(78&self,79state: &utils::State<'_, Self>,80_predicate: &PredicateFilter,81) -> ParquetResult<bool> {82Ok(state.page_validity.is_none())83}8485fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult<Self::Dict> {86let dict = self.decoder.deserialize_dict(page)?;87self.dict_size = dict.len();88Ok(dict)89}9091fn extend_decoded(92&self,93decoded: &mut Self::DecodedState,94additional: &dyn arrow::array::Array,95is_optional: bool,96) -> ParquetResult<()> {97let additional = additional98.as_any()99.downcast_ref::<DictionaryArray<T>>()100.unwrap();101decoded.0.extend(additional.keys().values().iter().copied());102match additional.validity() {103Some(v) => decoded.1.extend_from_bitmap(v),104None if is_optional => decoded.1.extend_constant(additional.len(), true),105None => {},106}107108Ok(())109}110111fn finalize(112&self,113dtype: ArrowDataType,114dict: Option<Self::Dict>,115(values, validity): Self::DecodedState,116) -> ParquetResult<DictionaryArray<T>> {117let validity = freeze_validity(validity);118let dict = dict.unwrap();119let keys = PrimitiveArray::new(T::PRIMITIVE.into(), values.into(), validity);120121let mut view_dict = MutableBinaryViewArray::with_capacity(dict.len());122let (views, buffers, _, _, _) = dict.into_inner();123124for buffer in buffers.iter() {125view_dict.push_buffer(buffer.clone());126}127unsafe { view_dict.views_mut().extend(views.iter()) };128unsafe { view_dict.set_total_bytes_len(views.iter().map(|v| v.length as usize).sum()) };129let view_dict = view_dict.freeze();130131// SAFETY: This was checked during construction of the dictionary132let dict = unsafe { view_dict.to_utf8view_unchecked() }.boxed();133134// SAFETY: This was checked during decoding135Ok(unsafe { DictionaryArray::try_new_unchecked(dtype, keys, dict) }.unwrap())136}137138fn extend_filtered_with_state(139&mut self,140state: utils::State<'_, Self>,141decoded: &mut Self::DecodedState,142pred_true_mask: &mut BitmapBuilder,143filter: Option<super::Filter>,144) -> ParquetResult<()> {145with_cast_mut_vec::<T, T::AlignedBytes, _, _>(&mut decoded.0, |aligned_bytes_vec| {146super::dictionary_encoded::decode_dict_dispatch(147state.translation,148T::try_from(self.dict_size).ok().unwrap(),149state.dict_mask,150state.is_optional,151state.page_validity.as_ref(),152filter,153&mut decoded.1,154aligned_bytes_vec,155pred_true_mask,156)157})158}159}160161162