Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-parquet/src/arrow/read/deserialize/categorical.rs
8512 views
1
use std::marker::PhantomData;
2
3
use arrow::array::{DictionaryArray, DictionaryKey, MutableBinaryViewArray, PrimitiveArray};
4
use arrow::bitmap::{Bitmap, BitmapBuilder};
5
use arrow::datatypes::ArrowDataType;
6
use polars_utils::vec::with_cast_mut_vec;
7
8
use super::binview::BinViewDecoder;
9
use super::utils::{self, Decoder, StateTranslation, dict_indices_decoder, freeze_validity};
10
use crate::parquet::encoding::Encoding;
11
use crate::parquet::encoding::hybrid_rle::HybridRleDecoder;
12
use crate::parquet::error::ParquetResult;
13
use crate::parquet::page::{DataPage, DictPage};
14
use crate::read::ParquetError;
15
use crate::read::deserialize::dictionary_encoded::IndexMapping;
16
use crate::read::expr::SpecializedParquetColumnExpr;
17
18
impl<'a, T: DictionaryKey + IndexMapping<Output = T::AlignedBytes>>
19
StateTranslation<'a, CategoricalDecoder<T>> for HybridRleDecoder<'a>
20
{
21
type PlainDecoder = HybridRleDecoder<'a>;
22
23
fn new(
24
_decoder: &CategoricalDecoder<T>,
25
page: &'a DataPage,
26
_dict: Option<&'a <CategoricalDecoder<T> as Decoder>::Dict>,
27
page_validity: Option<&Bitmap>,
28
) -> ParquetResult<Self> {
29
if !matches!(
30
page.encoding(),
31
Encoding::PlainDictionary | Encoding::RleDictionary
32
) {
33
return Err(utils::not_implemented(page));
34
}
35
36
dict_indices_decoder(page, page_validity.map_or(0, |bm| bm.unset_bits()))
37
}
38
fn num_rows(&self) -> usize {
39
self.len()
40
}
41
}
42
43
/// Special decoder for Polars Enums and Categoricals.
///
/// These are marked as special in the Arrow Field Metadata and they have the property that for a
/// given row group all the values are in the dictionary page and all data pages are dictionary
/// encoded. This makes the job of decoding them extremely simple and fast.
pub struct CategoricalDecoder<T> {
    // Number of entries in the dictionary. Initialized to `usize::MAX` as a
    // sentinel by `new` and set to the real size in `deserialize_dict`.
    dict_size: usize,
    // Inner decoder used to deserialize the dictionary page values as
    // binary-view (string) data.
    decoder: BinViewDecoder,
    // Zero-sized marker tying this decoder to the dictionary key type `T`.
    key_type: PhantomData<T>,
}
53
54
impl<T> CategoricalDecoder<T> {
55
pub fn new() -> Self {
56
Self {
57
dict_size: usize::MAX,
58
decoder: BinViewDecoder::new_string(),
59
key_type: PhantomData,
60
}
61
}
62
}
63
64
impl<T: DictionaryKey + IndexMapping<Output = T::AlignedBytes>> utils::Decoder
    for CategoricalDecoder<T>
{
    type Translation<'a> = HybridRleDecoder<'a>;
    // The dictionary page is decoded with the binary-view decoder, so reuse
    // its dictionary representation.
    type Dict = <BinViewDecoder as utils::Decoder>::Dict;
    // Decoded state: the dictionary keys plus a validity builder.
    type DecodedState = (Vec<T>, BitmapBuilder);
    type Output = DictionaryArray<T>;

    /// Pre-allocate key and validity storage for `capacity` rows.
    fn with_capacity(&self, capacity: usize) -> Self::DecodedState {
        (
            Vec::<T>::with_capacity(capacity),
            BitmapBuilder::with_capacity(capacity),
        )
    }

    /// Evaluate a pushed-down predicate directly on the dictionary indices.
    ///
    /// Returns `Ok(false)` when the predicate was not handled here (pages with
    /// a validity bitmap fall back to the generic path), `Ok(true)` when
    /// `pred_true_mask` was filled by decoding the index run against
    /// `dict_mask` (a per-dictionary-entry mask of which values satisfy the
    /// predicate).
    fn evaluate_predicate(
        &mut self,
        state: &utils::State<'_, Self>,
        _predicate: Option<&SpecializedParquetColumnExpr>,
        pred_true_mask: &mut BitmapBuilder,
        dict_mask: Option<&Bitmap>,
    ) -> ParquetResult<bool> {
        if state.page_validity.is_some() {
            // @Performance: implement validity aware
            return Ok(false);
        }

        // NOTE(review): a `dict_mask` is assumed to always be provided on this
        // path — confirm against the caller.
        let dict_mask = dict_mask.unwrap();
        super::dictionary_encoded::predicate::decode(
            state.translation.clone(),
            dict_mask,
            pred_true_mask,
        )?;

        Ok(true)
    }

    /// Decode the dictionary page via the binary-view decoder and record its
    /// size for later index-bound mapping in `extend_filtered_with_state`.
    fn deserialize_dict(&mut self, page: DictPage) -> ParquetResult<Self::Dict> {
        let dict = self.decoder.deserialize_dict(page)?;
        self.dict_size = dict.len();
        Ok(dict)
    }

    /// Append an already-decoded `DictionaryArray<T>` chunk to `decoded`.
    ///
    /// Keys are copied verbatim; the validity is extended from the chunk's
    /// bitmap, or with all-true bits when the chunk has no bitmap but the
    /// column is optional.
    fn extend_decoded(
        &self,
        decoded: &mut Self::DecodedState,
        additional: &dyn arrow::array::Array,
        is_optional: bool,
    ) -> ParquetResult<()> {
        // Panics if `additional` is not a DictionaryArray<T>; callers on this
        // path are expected to only hand us our own output type.
        let additional = additional
            .as_any()
            .downcast_ref::<DictionaryArray<T>>()
            .unwrap();
        decoded.0.extend(additional.keys().values().iter().copied());
        match additional.validity() {
            Some(v) => decoded.1.extend_from_bitmap(v),
            None if is_optional => decoded.1.extend_constant(additional.len(), true),
            None => {},
        }

        Ok(())
    }

    /// Assemble the final `DictionaryArray<T>` from the decoded keys, the
    /// validity, and the binary-view dictionary.
    ///
    /// The dictionary's views and buffers are moved into a fresh
    /// `MutableBinaryViewArray` rather than re-validated, since they were
    /// already checked when the dictionary was deserialized.
    fn finalize(
        &self,
        dtype: ArrowDataType,
        dict: Option<Self::Dict>,
        (values, validity): Self::DecodedState,
    ) -> ParquetResult<DictionaryArray<T>> {
        let validity = freeze_validity(validity);
        // A dictionary page must have been seen before finalizing.
        let dict = dict.unwrap();
        let keys = PrimitiveArray::new(T::PRIMITIVE.into(), values.into(), validity);

        let mut view_dict = MutableBinaryViewArray::with_capacity(dict.len());
        let (views, buffers, _, _, _) = dict.into_inner();

        // Re-attach the original data buffers so the copied views stay valid.
        for buffer in buffers.iter() {
            view_dict.push_buffer(buffer.clone());
        }
        unsafe { view_dict.views_mut().extend(views.iter()) };
        unsafe { view_dict.set_total_bytes_len(views.iter().map(|v| v.length as usize).sum()) };
        let view_dict = view_dict.freeze();

        // SAFETY: This was checked during construction of the dictionary
        let dict = unsafe { view_dict.to_utf8view_unchecked() }.boxed();

        // SAFETY: This was checked during decoding
        Ok(unsafe { DictionaryArray::try_new_unchecked(dtype, keys, dict) }.unwrap())
    }

    /// Decode (optionally filtered) dictionary indices from the page state
    /// into `decoded`.
    ///
    /// The key `Vec<T>` is temporarily reinterpreted as its aligned-byte
    /// representation so the shared dictionary-index decode kernel can write
    /// into it directly.
    fn extend_filtered_with_state(
        &mut self,
        state: utils::State<'_, Self>,
        decoded: &mut Self::DecodedState,
        filter: Option<super::Filter>,
        _chunks: &mut Vec<Self::Output>,
    ) -> ParquetResult<()> {
        with_cast_mut_vec::<T, T::AlignedBytes, _, _>(&mut decoded.0, |aligned_bytes_vec| {
            super::dictionary_encoded::decode_dict_dispatch(
                state.translation,
                // `dict_size` was set by `deserialize_dict`; the unwrap
                // asserts it fits in the key type `T`.
                T::try_from(self.dict_size).ok().unwrap(),
                state.is_optional,
                state.page_validity.as_ref(),
                filter,
                &mut decoded.1,
                aligned_bytes_vec,
            )
        })
    }

    /// Extending with a constant scalar (from a pushed-down equality filter)
    /// is not supported for categorical columns.
    fn extend_constant(
        &mut self,
        _decoded: &mut Self::DecodedState,
        _length: usize,
        _value: &crate::read::expr::ParquetScalar,
    ) -> ParquetResult<()> {
        Err(ParquetError::not_supported(
            "categorical with pushed-down equality filter",
        ))
    }
}
185
186