Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-row/src/decode.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::bitmap::{Bitmap, BitmapBuilder};
3
use arrow::buffer::Buffer;
4
use arrow::datatypes::ArrowDataType;
5
use arrow::offset::OffsetsBuffer;
6
use arrow::types::NativeType;
7
use polars_dtype::categorical::CatNative;
8
9
use self::encode::fixed_size;
10
use self::row::{RowEncodingCategoricalContext, RowEncodingOptions};
11
use self::variable::utf8::decode_str;
12
use super::*;
13
use crate::fixed::numeric::{FixedLengthEncoding, FromSlice};
14
use crate::fixed::{boolean, decimal, numeric};
15
use crate::variable::{binary, no_order, utf8};
16
17
/// Decode `rows` into a arrow format
18
/// # Safety
19
/// This will not do any bound checks. Caller must ensure the `rows` are valid
20
/// encodings.
21
pub unsafe fn decode_rows_from_binary<'a>(
22
arr: &'a BinaryArray<i64>,
23
opts: &[RowEncodingOptions],
24
dicts: &[Option<RowEncodingContext>],
25
dtypes: &[ArrowDataType],
26
rows: &mut Vec<&'a [u8]>,
27
) -> Vec<ArrayRef> {
28
assert_eq!(arr.null_count(), 0);
29
rows.clear();
30
rows.extend(arr.values_iter());
31
decode_rows(rows, opts, dicts, dtypes)
32
}
33
34
/// Decode `rows` into a arrow format
35
/// # Safety
36
/// This will not do any bound checks. Caller must ensure the `rows` are valid
37
/// encodings.
38
pub unsafe fn decode_rows(
39
// the rows will be updated while the data is decoded
40
rows: &mut [&[u8]],
41
opts: &[RowEncodingOptions],
42
dicts: &[Option<RowEncodingContext>],
43
dtypes: &[ArrowDataType],
44
) -> Vec<ArrayRef> {
45
assert_eq!(opts.len(), dtypes.len());
46
assert_eq!(dicts.len(), dtypes.len());
47
48
dtypes
49
.iter()
50
.zip(opts)
51
.zip(dicts)
52
.map(|((dtype, opt), dict)| decode(rows, *opt, dict.as_ref(), dtype))
53
.collect()
54
}
55
56
unsafe fn decode_validity(rows: &mut [&[u8]], opt: RowEncodingOptions) -> Option<Bitmap> {
57
// 2 loop system to avoid the overhead of allocating the bitmap if all the elements are valid.
58
59
let null_sentinel = opt.null_sentinel();
60
let first_null = (0..rows.len()).find(|&i| {
61
let v;
62
(v, rows[i]) = rows[i].split_at_unchecked(1);
63
v[0] == null_sentinel
64
});
65
66
// No nulls just return None
67
let first_null = first_null?;
68
69
let mut bm = BitmapBuilder::new();
70
bm.reserve(rows.len());
71
bm.extend_constant(first_null, true);
72
bm.push(false);
73
bm.extend_trusted_len_iter(rows[first_null + 1..].iter_mut().map(|row| {
74
let v;
75
(v, *row) = row.split_at_unchecked(1);
76
v[0] != null_sentinel
77
}));
78
bm.into_opt_validity()
79
}
80
81
// We inline this in an attempt to avoid the dispatch cost.
82
#[inline(always)]
83
fn dtype_and_data_to_encoded_item_len(
84
dtype: &ArrowDataType,
85
data: &[u8],
86
opt: RowEncodingOptions,
87
dict: Option<&RowEncodingContext>,
88
) -> usize {
89
// Fast path: if the size is fixed, we can just divide.
90
if let Some(size) = fixed_size(dtype, opt, dict) {
91
return size;
92
}
93
94
use ArrowDataType as D;
95
match dtype {
96
D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View
97
if opt.contains(RowEncodingOptions::NO_ORDER) =>
98
unsafe { no_order::len_from_buffer(data, opt) },
99
D::Binary | D::LargeBinary | D::BinaryView => unsafe {
100
binary::encoded_item_len(data, opt)
101
},
102
D::Utf8 | D::LargeUtf8 | D::Utf8View => unsafe { utf8::len_from_buffer(data, opt) },
103
104
D::List(list_field) | D::LargeList(list_field) => {
105
let mut data = data;
106
let mut item_len = 0;
107
108
let list_continuation_token = opt.list_continuation_token();
109
110
while data[0] == list_continuation_token {
111
data = &data[1..];
112
let len = dtype_and_data_to_encoded_item_len(list_field.dtype(), data, opt, dict);
113
data = &data[len..];
114
item_len += 1 + len;
115
}
116
1 + item_len
117
},
118
119
D::FixedSizeBinary(_) => todo!(),
120
D::FixedSizeList(fsl_field, width) => {
121
let mut data = &data[1..];
122
let mut item_len = 1; // validity byte
123
124
for _ in 0..*width {
125
let len = dtype_and_data_to_encoded_item_len(
126
fsl_field.dtype(),
127
data,
128
opt.into_nested(),
129
dict,
130
);
131
data = &data[len..];
132
item_len += len;
133
}
134
item_len
135
},
136
D::Struct(struct_fields) => {
137
let mut data = &data[1..];
138
let mut item_len = 1; // validity byte
139
140
for struct_field in struct_fields {
141
let len = dtype_and_data_to_encoded_item_len(
142
struct_field.dtype(),
143
data,
144
opt.into_nested(),
145
dict,
146
);
147
data = &data[len..];
148
item_len += len;
149
}
150
item_len
151
},
152
153
D::Union(_) => todo!(),
154
D::Map(_, _) => todo!(),
155
D::Decimal32(_, _) => todo!(),
156
D::Decimal64(_, _) => todo!(),
157
D::Decimal256(_, _) => todo!(),
158
D::Extension(_) => todo!(),
159
D::Unknown => todo!(),
160
161
_ => unreachable!(),
162
}
163
}
164
165
fn rows_for_fixed_size_list<'a>(
166
dtype: &ArrowDataType,
167
opt: RowEncodingOptions,
168
dict: Option<&RowEncodingContext>,
169
width: usize,
170
rows: &mut [&'a [u8]],
171
nested_rows: &mut Vec<&'a [u8]>,
172
) {
173
nested_rows.clear();
174
nested_rows.reserve(rows.len() * width);
175
176
// Fast path: if the size is fixed, we can just divide.
177
if let Some(size) = fixed_size(dtype, opt, dict) {
178
for row in rows.iter_mut() {
179
for i in 0..width {
180
nested_rows.push(&row[(i * size)..][..size]);
181
}
182
*row = &row[size * width..];
183
}
184
return;
185
}
186
187
// @TODO: This is quite slow since we need to dispatch for possibly every nested type
188
for row in rows.iter_mut() {
189
for _ in 0..width {
190
let length = dtype_and_data_to_encoded_item_len(dtype, row, opt.into_nested(), dict);
191
let v;
192
(v, *row) = row.split_at(length);
193
nested_rows.push(v);
194
}
195
}
196
}
197
198
unsafe fn decode_cat<T: NativeType + FixedLengthEncoding + CatNative>(
199
rows: &mut [&[u8]],
200
opt: RowEncodingOptions,
201
ctx: &RowEncodingCategoricalContext,
202
) -> PrimitiveArray<T>
203
where
204
T::Encoded: FromSlice,
205
{
206
if ctx.is_enum || !opt.is_ordered() {
207
numeric::decode_primitive::<T>(rows, opt)
208
} else {
209
variable::utf8::decode_str_as_cat::<T>(rows, opt, &ctx.mapping)
210
}
211
}
212
213
unsafe fn decode(
214
rows: &mut [&[u8]],
215
opt: RowEncodingOptions,
216
dict: Option<&RowEncodingContext>,
217
dtype: &ArrowDataType,
218
) -> ArrayRef {
219
use ArrowDataType as D;
220
221
if let Some(RowEncodingContext::Categorical(ctx)) = dict {
222
match dtype {
223
D::UInt8 => return decode_cat::<u8>(rows, opt, ctx).to_boxed(),
224
D::UInt16 => return decode_cat::<u16>(rows, opt, ctx).to_boxed(),
225
D::UInt32 => return decode_cat::<u32>(rows, opt, ctx).to_boxed(),
226
D::FixedSizeList(..) | D::List(_) | D::LargeList(_) => {
227
// Nested type, handled below.
228
},
229
_ => unreachable!(),
230
};
231
}
232
233
match dtype {
234
D::Null => NullArray::new(D::Null, rows.len()).to_boxed(),
235
D::Boolean => boolean::decode_bool(rows, opt).to_boxed(),
236
D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View
237
if opt.contains(RowEncodingOptions::NO_ORDER) =>
238
{
239
let array = no_order::decode_variable_no_order(rows, opt);
240
241
if matches!(dtype, D::Utf8 | D::LargeUtf8 | D::Utf8View) {
242
unsafe { array.to_utf8view_unchecked() }.to_boxed()
243
} else {
244
array.to_boxed()
245
}
246
},
247
D::Binary | D::LargeBinary | D::BinaryView => binary::decode_binview(rows, opt).to_boxed(),
248
D::Utf8 | D::LargeUtf8 | D::Utf8View => decode_str(rows, opt).boxed(),
249
250
D::Struct(fields) => {
251
let validity = decode_validity(rows, opt);
252
253
let values = match dict {
254
None => fields
255
.iter()
256
.map(|struct_fld| decode(rows, opt.into_nested(), None, struct_fld.dtype()))
257
.collect(),
258
Some(RowEncodingContext::Struct(dicts)) => fields
259
.iter()
260
.zip(dicts)
261
.map(|(struct_fld, dict)| {
262
decode(rows, opt.into_nested(), dict.as_ref(), struct_fld.dtype())
263
})
264
.collect(),
265
_ => unreachable!(),
266
};
267
StructArray::new(dtype.clone(), rows.len(), values, validity).to_boxed()
268
},
269
D::FixedSizeList(fsl_field, width) => {
270
let validity = decode_validity(rows, opt);
271
272
// @TODO: we could consider making this into a scratchpad
273
let mut nested_rows = Vec::new();
274
rows_for_fixed_size_list(
275
fsl_field.dtype(),
276
opt.into_nested(),
277
dict,
278
*width,
279
rows,
280
&mut nested_rows,
281
);
282
283
let values = decode(&mut nested_rows, opt.into_nested(), dict, fsl_field.dtype());
284
285
FixedSizeListArray::new(dtype.clone(), rows.len(), values, validity).to_boxed()
286
},
287
D::List(list_field) | D::LargeList(list_field) => {
288
let mut validity = BitmapBuilder::new();
289
290
// @TODO: we could consider making this into a scratchpad
291
let num_rows = rows.len();
292
let mut nested_rows = Vec::new();
293
let mut offsets = Vec::with_capacity(rows.len() + 1);
294
offsets.push(0);
295
296
let list_null_sentinel = opt.list_null_sentinel();
297
let list_continuation_token = opt.list_continuation_token();
298
let list_termination_token = opt.list_termination_token();
299
300
// @TODO: make a specialized loop for fixed size list_field.dtype()
301
for (i, row) in rows.iter_mut().enumerate() {
302
while row[0] == list_continuation_token {
303
*row = &row[1..];
304
let len = dtype_and_data_to_encoded_item_len(
305
list_field.dtype(),
306
row,
307
opt.into_nested(),
308
dict,
309
);
310
nested_rows.push(&row[..len]);
311
*row = &row[len..];
312
}
313
314
offsets.push(nested_rows.len() as i64);
315
316
// @TODO: Might be better to make this a 2-loop system.
317
if row[0] == list_null_sentinel {
318
*row = &row[1..];
319
validity.reserve(num_rows);
320
validity.extend_constant(i - validity.len(), true);
321
validity.push(false);
322
continue;
323
}
324
325
assert_eq!(row[0], list_termination_token);
326
*row = &row[1..];
327
}
328
329
let validity = if validity.is_empty() {
330
None
331
} else {
332
validity.extend_constant(num_rows - validity.len(), true);
333
validity.into_opt_validity()
334
};
335
assert_eq!(offsets.len(), rows.len() + 1);
336
337
let values = decode(
338
&mut nested_rows,
339
opt.into_nested(),
340
dict,
341
list_field.dtype(),
342
);
343
344
ListArray::<i64>::new(
345
dtype.clone(),
346
unsafe { OffsetsBuffer::new_unchecked(Buffer::from(offsets)) },
347
values,
348
validity,
349
)
350
.to_boxed()
351
},
352
353
dt => {
354
if matches!(dt, D::Int128) {
355
if let Some(dict) = dict {
356
return match dict {
357
RowEncodingContext::Decimal(precision) => {
358
decimal::decode(rows, opt, *precision).to_boxed()
359
},
360
_ => unreachable!(),
361
};
362
}
363
}
364
365
with_match_arrow_primitive_type!(dt, |$T| {
366
numeric::decode_primitive::<$T>(rows, opt).to_boxed()
367
})
368
},
369
}
370
}
371
372