Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-row/src/fixed/numeric.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::fmt::Debug;
3
use std::mem::MaybeUninit;
4
5
use arrow::array::{Array, PrimitiveArray};
6
use arrow::bitmap::Bitmap;
7
use arrow::datatypes::ArrowDataType;
8
use arrow::types::NativeType;
9
use polars_utils::slice::*;
10
use polars_utils::total_ord::{canonical_f32, canonical_f64};
11
12
use crate::row::RowEncodingOptions;
13
/// Builds a fixed-size value out of a borrowed byte slice.
pub(crate) trait FromSlice {
    /// Converts `slice` into `Self`.
    fn from_slice(slice: &[u8]) -> Self;
}

impl<const N: usize> FromSlice for [u8; N] {
    /// Copies `slice` into an `N`-byte array.
    ///
    /// # Panics
    ///
    /// Panics if `slice.len() != N`.
    #[inline]
    fn from_slice(slice: &[u8]) -> Self {
        <[u8; N]>::try_from(slice).unwrap()
    }
}
23
24
/// Fixed-width byte encoding of a value.
///
/// [`encode`](Self::encode) turns a value into `Self::Encoded` and
/// [`decode`](Self::decode) turns those bytes back into a value.
pub trait FixedLengthEncoding: Copy + Debug {
    /// Bytes occupied by one encoded cell: one leading validity byte followed
    /// by the bytes of `Self::Encoded`.
    const ENCODED_LEN: usize = 1 + size_of::<Self::Encoded>();

    /// Byte container produced by [`encode`](Self::encode).
    type Encoded: Sized + Copy + AsRef<[u8]> + AsMut<[u8]>;

    /// Converts `self` into its encoded bytes.
    fn encode(self) -> Self::Encoded;

    /// Rebuilds a value from bytes produced by [`encode`](Self::encode).
    fn decode(encoded: Self::Encoded) -> Self;

    /// Decodes bytes that were stored bit-inverted: every byte is
    /// complemented first and the result is handed to
    /// [`decode`](Self::decode).
    fn decode_reverse(mut encoded: Self::Encoded) -> Self {
        encoded
            .as_mut()
            .iter_mut()
            .for_each(|byte| *byte = !*byte);
        Self::decode(encoded)
    }
}
43
44
// encode as big endian
45
macro_rules! encode_unsigned {
46
($n:expr, $t:ty) => {
47
impl FixedLengthEncoding for $t {
48
type Encoded = [u8; $n];
49
50
fn encode(self) -> [u8; $n] {
51
self.to_be_bytes()
52
}
53
54
fn decode(encoded: Self::Encoded) -> Self {
55
Self::from_be_bytes(encoded)
56
}
57
}
58
};
59
}
60
61
encode_unsigned!(1, u8);
62
encode_unsigned!(2, u16);
63
encode_unsigned!(4, u32);
64
encode_unsigned!(8, u64);
65
66
// toggle the sign bit and then encode as big indian
67
macro_rules! encode_signed {
68
($n:expr, $t:ty) => {
69
impl FixedLengthEncoding for $t {
70
type Encoded = [u8; $n];
71
72
fn encode(self) -> [u8; $n] {
73
#[cfg(target_endian = "big")]
74
{
75
todo!()
76
}
77
78
let mut b = self.to_be_bytes();
79
// Toggle top "sign" bit to ensure consistent sort order
80
b[0] ^= 0x80;
81
b
82
}
83
84
fn decode(mut encoded: Self::Encoded) -> Self {
85
// Toggle top "sign" bit
86
encoded[0] ^= 0x80;
87
Self::from_be_bytes(encoded)
88
}
89
}
90
};
91
}
92
93
encode_signed!(1, i8);
94
encode_signed!(2, i16);
95
encode_signed!(4, i32);
96
encode_signed!(8, i64);
97
encode_signed!(16, i128);
98
99
impl FixedLengthEncoding for f32 {
100
type Encoded = [u8; 4];
101
102
fn encode(self) -> [u8; 4] {
103
// https://github.com/rust-lang/rust/blob/9c20b2a8cc7588decb6de25ac6a7912dcef24d65/library/core/src/num/f32.rs#L1176-L1260
104
let s = canonical_f32(self).to_bits() as i32;
105
let val = s ^ (((s >> 31) as u32) >> 1) as i32;
106
val.encode()
107
}
108
109
fn decode(encoded: Self::Encoded) -> Self {
110
let bits = i32::decode(encoded);
111
let val = bits ^ (((bits >> 31) as u32) >> 1) as i32;
112
Self::from_bits(val as u32)
113
}
114
}
115
116
impl FixedLengthEncoding for f64 {
117
type Encoded = [u8; 8];
118
119
fn encode(self) -> [u8; 8] {
120
// https://github.com/rust-lang/rust/blob/9c20b2a8cc7588decb6de25ac6a7912dcef24d65/library/core/src/num/f32.rs#L1176-L1260
121
let s = canonical_f64(self).to_bits() as i64;
122
let val = s ^ (((s >> 63) as u64) >> 1) as i64;
123
val.encode()
124
}
125
126
fn decode(encoded: Self::Encoded) -> Self {
127
let bits = i64::decode(encoded);
128
let val = bits ^ (((bits >> 63) as u64) >> 1) as i64;
129
Self::from_bits(val as u64)
130
}
131
}
132
133
pub unsafe fn encode<T: NativeType + FixedLengthEncoding>(
134
buffer: &mut [MaybeUninit<u8>],
135
arr: &PrimitiveArray<T>,
136
opt: RowEncodingOptions,
137
offsets: &mut [usize],
138
) {
139
if arr.null_count() == 0 {
140
crate::fixed::numeric::encode_slice(buffer, arr.values().as_slice(), opt, offsets)
141
} else {
142
crate::fixed::numeric::encode_iter(
143
buffer,
144
arr.into_iter().map(|v| v.copied()),
145
opt,
146
offsets,
147
)
148
}
149
}
150
151
#[inline]
152
unsafe fn encode_value<T: FixedLengthEncoding>(
153
value: &T,
154
offset: &mut usize,
155
descending: bool,
156
buf: &mut [MaybeUninit<u8>],
157
) {
158
let end_offset = *offset + T::ENCODED_LEN;
159
let dst = unsafe { buf.get_unchecked_mut(*offset..end_offset) };
160
// set valid
161
dst[0] = MaybeUninit::new(1);
162
let mut encoded = value.encode();
163
164
// invert bits to reverse order
165
if descending {
166
for v in encoded.as_mut() {
167
*v = !*v
168
}
169
}
170
171
dst[1..].copy_from_slice(encoded.as_ref().as_uninit());
172
*offset = end_offset;
173
}
174
175
unsafe fn encode_opt_value<T: FixedLengthEncoding>(
176
opt_value: Option<T>,
177
offset: &mut usize,
178
opt: RowEncodingOptions,
179
buffer: &mut [MaybeUninit<u8>],
180
) {
181
let descending = opt.contains(RowEncodingOptions::DESCENDING);
182
if let Some(value) = opt_value {
183
encode_value(&value, offset, descending, buffer);
184
} else {
185
unsafe { *buffer.get_unchecked_mut(*offset) = MaybeUninit::new(opt.null_sentinel()) };
186
let end_offset = *offset + T::ENCODED_LEN;
187
188
// initialize remaining bytes
189
let remainder = unsafe { buffer.get_unchecked_mut(*offset + 1..end_offset) };
190
remainder.fill(MaybeUninit::new(0));
191
192
*offset = end_offset;
193
}
194
}
195
196
pub(crate) unsafe fn encode_slice<T: FixedLengthEncoding>(
197
buffer: &mut [MaybeUninit<u8>],
198
input: &[T],
199
opt: RowEncodingOptions,
200
row_starts: &mut [usize],
201
) {
202
let descending = opt.contains(RowEncodingOptions::DESCENDING);
203
for (offset, value) in row_starts.iter_mut().zip(input) {
204
encode_value(value, offset, descending, buffer);
205
}
206
}
207
208
pub(crate) unsafe fn encode_iter<I: Iterator<Item = Option<T>>, T: FixedLengthEncoding>(
209
buffer: &mut [MaybeUninit<u8>],
210
input: I,
211
opt: RowEncodingOptions,
212
row_starts: &mut [usize],
213
) {
214
for (offset, opt_value) in row_starts.iter_mut().zip(input) {
215
encode_opt_value(opt_value, offset, opt, buffer);
216
}
217
}
218
219
pub(crate) unsafe fn decode_primitive<T: NativeType + FixedLengthEncoding>(
220
rows: &mut [&[u8]],
221
opt: RowEncodingOptions,
222
) -> PrimitiveArray<T>
223
where
224
T::Encoded: FromSlice,
225
{
226
let dtype: ArrowDataType = T::PRIMITIVE.into();
227
let mut has_nulls = false;
228
let descending = opt.contains(RowEncodingOptions::DESCENDING);
229
let null_sentinel = opt.null_sentinel();
230
231
let values = rows
232
.iter()
233
.map(|row| {
234
has_nulls |= *row.get_unchecked(0) == null_sentinel;
235
// skip null sentinel
236
let start = 1;
237
let end = start + T::ENCODED_LEN - 1;
238
let slice = row.get_unchecked(start..end);
239
let bytes = T::Encoded::from_slice(slice);
240
241
if descending {
242
T::decode_reverse(bytes)
243
} else {
244
T::decode(bytes)
245
}
246
})
247
.collect::<Vec<_>>();
248
249
let validity = if has_nulls {
250
let null_sentinel = opt.null_sentinel();
251
Some(decode_nulls(rows, null_sentinel))
252
} else {
253
None
254
};
255
256
// validity byte and data length
257
let increment_len = T::ENCODED_LEN;
258
259
increment_row_counter(rows, increment_len);
260
PrimitiveArray::new(dtype, values.into(), validity)
261
}
262
263
/// Drops the first `fixed_size` bytes from every row in place.
///
/// # Safety
///
/// Each row must contain at least `fixed_size` bytes; the sub-slicing is not
/// bounds checked.
unsafe fn increment_row_counter(rows: &mut [&[u8]], fixed_size: usize) {
    rows.iter_mut().for_each(|row| {
        // SAFETY: the caller guarantees `fixed_size <= row.len()`.
        *row = unsafe { row.get_unchecked(fixed_size..) };
    });
}
268
269
/// Builds a validity bitmap from the first byte of every row: a bit is set
/// when that byte differs from `null_sentinel`.
///
/// # Safety
///
/// Every row must contain at least one byte; the first byte is read without
/// a bounds check.
pub(super) unsafe fn decode_nulls(rows: &[&[u8]], null_sentinel: u8) -> Bitmap {
    let is_valid = |row: &&[u8]| {
        // SAFETY: the caller guarantees the row is non-empty.
        let marker = unsafe { *row.get_unchecked(0) };
        marker != null_sentinel
    };
    rows.iter().map(is_valid).collect()
}
274
275