Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-row/src/fixed/decimal.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
//! Row Encoding for Decimals (physical `i128` values)
3
//!
4
//! This is a fixed-size encoding that takes the maximum number of bits each value can occupy and
5
//! compresses such that the minimum number of bytes is used for each value.
6
7
use std::mem::MaybeUninit;
8
9
use arrow::array::{Array, PrimitiveArray};
10
use arrow::bitmap::BitmapBuilder;
11
use arrow::datatypes::ArrowDataType;
12
use polars_utils::slice::Slice2Uninit;
13
14
use crate::row::RowEncodingOptions;
15
16
// Expands `$body` once for every supported encoded width (1 to 16 bytes) by forwarding to
// `with_arms!`, which dispatches on the runtime value of the `$width` identifier.
// NOTE(review): presumably each arm rebinds `$width` to that arm's literal so the body is
// specialized per width — confirm against the definition of `with_arms!`.
macro_rules! with_constant_num_bytes {
    ($width:ident, $body:block) => {
        with_arms!(
            $width,
            $body,
            (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
        )
    };
}
25
26
pub fn len_from_precision(precision: usize) -> usize {
27
len_from_num_bits(num_bits_from_precision(precision))
28
}
29
30
/// Returns the number of bits needed for the magnitude of any decimal with `precision` digits,
/// i.e. `ceil(log2(10^precision))`. The sign bit is not included.
///
/// # Panics
/// Panics if `precision > 38`.
fn num_bits_from_precision(precision: usize) -> usize {
    assert!(precision <= 38);
    // The largest magnitude representable with `precision` digits is `10^precision - 1`.
    // Its binary length is exactly `ceil(log2(10^precision))`. `10^38` still fits in a `u128`,
    // so this is exact integer arithmetic with no floating-point rounding to reason about.
    let max_magnitude = 10u128.pow(precision as u32) - 1;
    (u128::BITS - max_magnitude.leading_zeros()) as usize
}
36
37
/// Returns the encoded width in bytes for a value needing `num_bits` magnitude bits.
fn len_from_num_bits(num_bits: usize) -> usize {
    // One extra bit marks validity (null vs. valid) and one extra bit carries the sign.
    const VALIDITY_BITS: usize = 1;
    const SIGN_BITS: usize = 1;
    (num_bits + VALIDITY_BITS + SIGN_BITS).div_ceil(u8::BITS as usize)
}
42
43
pub unsafe fn encode(
44
buffer: &mut [MaybeUninit<u8>],
45
input: &PrimitiveArray<i128>,
46
opt: RowEncodingOptions,
47
offsets: &mut [usize],
48
precision: usize,
49
) {
50
if input.null_count() == 0 {
51
unsafe { encode_slice(buffer, input.values(), opt, offsets, precision) }
52
} else {
53
unsafe {
54
encode_iter(
55
buffer,
56
input.iter().map(|v| v.copied()),
57
opt,
58
offsets,
59
precision,
60
)
61
}
62
}
63
}
64
65
pub unsafe fn encode_slice(
66
buffer: &mut [MaybeUninit<u8>],
67
input: &[i128],
68
opt: RowEncodingOptions,
69
offsets: &mut [usize],
70
precision: usize,
71
) {
72
let num_bits = num_bits_from_precision(precision);
73
74
// If the output will not fit in less bytes, just use the normal i128 encoding kernel.
75
if num_bits >= 127 {
76
super::numeric::encode_slice(buffer, input, opt, offsets);
77
return;
78
}
79
80
let num_bytes = len_from_num_bits(num_bits);
81
let mask = (1 << (num_bits + 1)) - 1;
82
let valid_mask = ((!opt.null_sentinel() & 0x80) as i128) << ((num_bytes - 1) * 8);
83
let sign_mask = 1 << num_bits;
84
let invert_mask = if opt.contains(RowEncodingOptions::DESCENDING) {
85
mask
86
} else {
87
0
88
};
89
90
with_constant_num_bytes!(num_bytes, {
91
for (offset, &v) in offsets.iter_mut().zip(input) {
92
let mut v = v;
93
94
v &= mask; // Mask out higher sign extension bits
95
v ^= sign_mask; // Flip sign-bit to maintain order
96
v ^= invert_mask; // Invert for descending
97
v |= valid_mask; // Add valid indicator
98
99
unsafe { buffer.get_unchecked_mut(*offset..*offset + num_bytes) }
100
.copy_from_slice(v.to_be_bytes()[16 - num_bytes..].as_uninit());
101
*offset += num_bytes;
102
}
103
});
104
}
105
106
pub unsafe fn encode_iter(
107
buffer: &mut [MaybeUninit<u8>],
108
input: impl Iterator<Item = Option<i128>>,
109
opt: RowEncodingOptions,
110
offsets: &mut [usize],
111
precision: usize,
112
) {
113
let num_bits = num_bits_from_precision(precision);
114
// If the output will not fit in less bytes, just use the normal i128 encoding kernel.
115
if num_bits >= 127 {
116
super::numeric::encode_iter(buffer, input, opt, offsets);
117
return;
118
}
119
120
let num_bytes = len_from_num_bits(num_bits);
121
let null_value = (opt.null_sentinel() as i128) << ((num_bytes - 1) * 8);
122
let mask = (1 << (num_bits + 1)) - 1;
123
let valid_mask = ((!opt.null_sentinel() & 0x80) as i128) << ((num_bytes - 1) * 8);
124
let sign_mask = 1 << num_bits;
125
let invert_mask = if opt.contains(RowEncodingOptions::DESCENDING) {
126
mask
127
} else {
128
0
129
};
130
131
with_constant_num_bytes!(num_bytes, {
132
for (offset, v) in offsets.iter_mut().zip(input) {
133
match v {
134
None => {
135
unsafe { buffer.get_unchecked_mut(*offset..*offset + num_bytes) }
136
.copy_from_slice(null_value.to_be_bytes()[16 - num_bytes..].as_uninit());
137
},
138
Some(mut v) => {
139
v &= mask; // Mask out higher sign extension bits
140
v ^= sign_mask; // Flip sign-bit to maintain order
141
v ^= invert_mask; // Invert for descending
142
v |= valid_mask; // Add valid indicator
143
144
unsafe { buffer.get_unchecked_mut(*offset..*offset + num_bytes) }
145
.copy_from_slice(v.to_be_bytes()[16 - num_bytes..].as_uninit());
146
},
147
}
148
149
*offset += num_bytes;
150
}
151
});
152
}
153
154
pub unsafe fn decode(
155
rows: &mut [&[u8]],
156
opt: RowEncodingOptions,
157
precision: usize,
158
) -> PrimitiveArray<i128> {
159
let num_bits = num_bits_from_precision(precision);
160
// If the output will not fit in less bytes, just use the normal i128 decoding kernel.
161
if num_bits >= 127 {
162
let (_, values, validity) = super::numeric::decode_primitive(rows, opt).into_inner();
163
return PrimitiveArray::new(ArrowDataType::Int128, values, validity);
164
}
165
166
let mut values = Vec::with_capacity(rows.len());
167
let null_sentinel = opt.null_sentinel();
168
169
let num_bytes = len_from_num_bits(num_bits);
170
let mask = (1 << (num_bits + 1)) - 1;
171
let sign_mask = 1 << num_bits;
172
let invert_mask = if opt.contains(RowEncodingOptions::DESCENDING) {
173
mask
174
} else {
175
0
176
};
177
178
with_constant_num_bytes!(num_bytes, {
179
values.extend(
180
rows.iter_mut()
181
.take_while(|row| *unsafe { row.get_unchecked(0) } != null_sentinel)
182
.map(|row| {
183
let mut value = 0i128;
184
let value_ref: &mut [u8; 16] = bytemuck::cast_mut(&mut value);
185
value_ref[16 - num_bytes..].copy_from_slice(row.get_unchecked(..num_bytes));
186
*row = &row[num_bytes..];
187
188
if cfg!(target_endian = "little") {
189
// Big-Endian -> Little-Endian
190
value = value.swap_bytes();
191
}
192
193
value ^= invert_mask; // Invert for descending
194
value ^= sign_mask; // Flip sign bit to maintain order
195
196
// Sign extend. This also masks out the valid bit.
197
value <<= i128::BITS - num_bits as u32 - 1;
198
value >>= i128::BITS - num_bits as u32 - 1;
199
200
value
201
}),
202
);
203
});
204
205
if values.len() == rows.len() {
206
return PrimitiveArray::new(ArrowDataType::Int128, values.into(), None);
207
}
208
209
let mut validity = BitmapBuilder::with_capacity(rows.len());
210
validity.extend_constant(values.len(), true);
211
212
let start_len = values.len();
213
214
with_constant_num_bytes!(num_bytes, {
215
values.extend(rows[start_len..].iter_mut().map(|row| {
216
validity.push(*unsafe { row.get_unchecked(0) } != null_sentinel);
217
218
let mut value = 0i128;
219
let value_ref: &mut [u8; 16] = bytemuck::cast_mut(&mut value);
220
value_ref[16 - num_bytes..].copy_from_slice(row.get_unchecked(..num_bytes));
221
*row = &row[num_bytes..];
222
223
if cfg!(target_endian = "little") {
224
// Big-Endian -> Little-Endian
225
value = value.swap_bytes();
226
}
227
228
value ^= invert_mask; // Invert for descending
229
value ^= sign_mask; // Flip sign bit to maintain order
230
231
// Sign extend. This also masks out the valid bit.
232
value <<= i128::BITS - num_bits as u32 - 1;
233
value >>= i128::BITS - num_bits as u32 - 1;
234
235
value
236
}));
237
});
238
239
PrimitiveArray::new(
240
ArrowDataType::Int128,
241
values.into(),
242
validity.into_opt_validity(),
243
)
244
}
245
246