Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs
6940 views
1
use std::marker::PhantomData;

use arrow::array::{DictionaryArray, DictionaryKey, MutableBinaryViewArray, PrimitiveArray};
use arrow::bitmap::{Bitmap, BitmapBuilder};
use arrow::datatypes::ArrowDataType;
use polars_utils::vec::with_cast_mut_vec;

use super::PredicateFilter;
use super::binview::BinViewDecoder;
use super::utils::{self, Decoder, StateTranslation, dict_indices_decoder, freeze_validity};
use crate::parquet::encoding::Encoding;
use crate::parquet::encoding::hybrid_rle::HybridRleDecoder;
use crate::parquet::error::{ParquetError, ParquetResult};
use crate::parquet::page::{DataPage, DictPage};
use crate::read::deserialize::dictionary_encoded::IndexMapping;
16
17
impl<'a, T: DictionaryKey + IndexMapping<Output = T::AlignedBytes>>
18
StateTranslation<'a, CategoricalDecoder<T>> for HybridRleDecoder<'a>
19
{
20
type PlainDecoder = HybridRleDecoder<'a>;
21
22
fn new(
23
_decoder: &CategoricalDecoder<T>,
24
page: &'a DataPage,
25
_dict: Option<&'a <CategoricalDecoder<T> as Decoder>::Dict>,
26
page_validity: Option<&Bitmap>,
27
) -> ParquetResult<Self> {
28
if !matches!(
29
page.encoding(),
30
Encoding::PlainDictionary | Encoding::RleDictionary
31
) {
32
return Err(utils::not_implemented(page));
33
}
34
35
dict_indices_decoder(page, page_validity.map_or(0, |bm| bm.unset_bits()))
36
}
37
fn num_rows(&self) -> usize {
38
self.len()
39
}
40
}
41
42
/// Special decoder for Polars Enum and Categorical's.
43
///
44
/// These are marked as special in the Arrow Field Metadata and they have the properly that for a
45
/// given row group all the values are in the dictionary page and all data pages are dictionary
46
/// encoded. This makes the job of decoding them extremely simple and fast.
47
pub struct CategoricalDecoder<T> {
48
dict_size: usize,
49
decoder: BinViewDecoder,
50
key_type: PhantomData<T>,
51
}
52
53
impl<T> CategoricalDecoder<T> {
54
pub fn new() -> Self {
55
Self {
56
dict_size: usize::MAX,
57
decoder: BinViewDecoder::new_string(),
58
key_type: PhantomData,
59
}
60
}
61
}
62
63
impl<T: DictionaryKey + IndexMapping<Output = T::AlignedBytes>> utils::Decoder
64
for CategoricalDecoder<T>
65
{
66
type Translation<'a> = HybridRleDecoder<'a>;
67
type Dict = <BinViewDecoder as utils::Decoder>::Dict;
68
type DecodedState = (Vec<T>, BitmapBuilder);
69
type Output = DictionaryArray<T>;
70
71
fn with_capacity(&self, capacity: usize) -> Self::DecodedState {
72
(
73
Vec::<T>::with_capacity(capacity),
74
BitmapBuilder::with_capacity(capacity),
75
)
76
}
77
78
fn has_predicate_specialization(
79
&self,
80
state: &utils::State<'_, Self>,
81
_predicate: &PredicateFilter,
82
) -> ParquetResult<bool> {
83
Ok(state.page_validity.is_none())
84
}
85
86
fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult<Self::Dict> {
87
let dict = self.decoder.deserialize_dict(page)?;
88
self.dict_size = dict.len();
89
Ok(dict)
90
}
91
92
fn extend_decoded(
93
&self,
94
decoded: &mut Self::DecodedState,
95
additional: &dyn arrow::array::Array,
96
is_optional: bool,
97
) -> ParquetResult<()> {
98
let additional = additional
99
.as_any()
100
.downcast_ref::<DictionaryArray<T>>()
101
.unwrap();
102
decoded.0.extend(additional.keys().values().iter().copied());
103
match additional.validity() {
104
Some(v) => decoded.1.extend_from_bitmap(v),
105
None if is_optional => decoded.1.extend_constant(additional.len(), true),
106
None => {},
107
}
108
109
Ok(())
110
}
111
112
fn finalize(
113
&self,
114
dtype: ArrowDataType,
115
dict: Option<Self::Dict>,
116
(values, validity): Self::DecodedState,
117
) -> ParquetResult<DictionaryArray<T>> {
118
let validity = freeze_validity(validity);
119
let dict = dict.unwrap();
120
let keys = PrimitiveArray::new(T::PRIMITIVE.into(), values.into(), validity);
121
122
let mut view_dict = MutableBinaryViewArray::with_capacity(dict.len());
123
let (views, buffers, _, _, _) = dict.into_inner();
124
125
for buffer in buffers.iter() {
126
view_dict.push_buffer(buffer.clone());
127
}
128
unsafe { view_dict.views_mut().extend(views.iter()) };
129
unsafe { view_dict.set_total_bytes_len(views.iter().map(|v| v.length as usize).sum()) };
130
let view_dict = view_dict.freeze();
131
132
// SAFETY: This was checked during construction of the dictionary
133
let dict = unsafe { view_dict.to_utf8view_unchecked() }.boxed();
134
135
// SAFETY: This was checked during decoding
136
Ok(unsafe { DictionaryArray::try_new_unchecked(dtype, keys, dict) }.unwrap())
137
}
138
139
fn extend_filtered_with_state(
140
&mut self,
141
state: utils::State<'_, Self>,
142
decoded: &mut Self::DecodedState,
143
pred_true_mask: &mut BitmapBuilder,
144
filter: Option<super::Filter>,
145
) -> ParquetResult<()> {
146
with_cast_mut_vec::<T, T::AlignedBytes, _, _>(&mut decoded.0, |aligned_bytes_vec| {
147
super::dictionary_encoded::decode_dict_dispatch(
148
state.translation,
149
T::try_from(self.dict_size).ok().unwrap(),
150
state.dict_mask,
151
state.is_optional,
152
state.page_validity.as_ref(),
153
filter,
154
&mut decoded.1,
155
aligned_bytes_vec,
156
pred_true_mask,
157
)
158
})
159
}
160
}
161
162