#![allow(unsafe_op_in_unsafe_fn)]
use std::mem::MaybeUninit;
use arrow::array::{Array, PrimitiveArray};
use arrow::bitmap::BitmapBuilder;
use arrow::datatypes::ArrowDataType;
use polars_utils::slice::Slice2Uninit;
use crate::row::RowEncodingOptions;
/// Forwards `$num_bytes` and `$block` to `with_arms!` together with the
/// literal list `1..=16`, i.e. every byte width an `i128` payload can occupy.
///
/// NOTE(review): `with_arms!` is defined outside this view; presumably it
/// dispatches on the runtime value of `$num_bytes` so that `$block` is
/// instantiated once per listed constant width — confirm against its definition.
macro_rules! with_constant_num_bytes {
($num_bytes:ident, $block:block) => {
with_arms!(
$num_bytes,
$block,
(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
)
};
}
/// Returns the number of bytes used to encode a decimal of the given `precision`.
///
/// The precision is first converted to a bit count and that bit count is then
/// converted to a byte length.
///
/// # Panics
///
/// Panics if `precision` is greater than 38.
pub fn len_from_precision(precision: usize) -> usize {
    let num_bits = num_bits_from_precision(precision);
    len_from_num_bits(num_bits)
}
/// Returns `ceil(precision * log2(10))`, computed in `f32`.
///
/// # Panics
///
/// Panics if `precision` is greater than 38.
fn num_bits_from_precision(precision: usize) -> usize {
    assert!(precision <= 38);
    let bits_per_digit = 10.0f32.log2();
    let exact_bits = (precision as f32) * bits_per_digit;
    exact_bits.ceil() as usize
}
/// Returns how many whole bytes are needed to hold `num_bits + 2` bits.
///
/// The encoders in this file use the two extra bits for the sign bit and for
/// the top bit of the leading byte that distinguishes values from nulls.
fn len_from_num_bits(num_bits: usize) -> usize {
    let total_bits = num_bits + 2;
    total_bits.div_ceil(8)
}
/// Encodes a decimal array into `buffer`, one value per entry of `offsets`.
///
/// Arrays without nulls are encoded straight from their value slice via
/// [`encode_slice`]; otherwise the array is walked as an iterator of optional
/// values via [`encode_iter`].
///
/// # Safety
///
/// The caller must uphold the safety requirements of [`encode_slice`] and
/// [`encode_iter`]: the byte range written at each offset must lie within
/// `buffer`.
pub unsafe fn encode(
    buffer: &mut [MaybeUninit<u8>],
    input: &PrimitiveArray<i128>,
    opt: RowEncodingOptions,
    offsets: &mut [usize],
    precision: usize,
) {
    if input.null_count() != 0 {
        let optional_values = input.iter().map(|maybe_value| maybe_value.copied());
        unsafe { encode_iter(buffer, optional_values, opt, offsets, precision) };
        return;
    }

    unsafe { encode_slice(buffer, input.values(), opt, offsets, precision) }
}
/// Encodes a slice of non-null decimal values into `buffer`.
///
/// Pairs of `(offset, value)` are visited in order. When fewer than 127 bits
/// are required for `precision`, each value is written big-endian into
/// `buffer[*offset..*offset + num_bytes]` and the offset is advanced by
/// `num_bytes`. When 127 or more bits are required the work is delegated to
/// `super::numeric::encode_slice`.
///
/// # Safety
///
/// For every visited pair, `*offset..*offset + num_bytes` must lie within
/// `buffer`; the writes are not bounds-checked.
pub unsafe fn encode_slice(
    buffer: &mut [MaybeUninit<u8>],
    input: &[i128],
    opt: RowEncodingOptions,
    offsets: &mut [usize],
    precision: usize,
) {
    let num_bits = num_bits_from_precision(precision);
    if num_bits >= 127 {
        // Too wide for the compact layout: use the generic numeric encoder.
        super::numeric::encode_slice(buffer, input, opt, offsets);
        return;
    }

    let num_bytes = len_from_num_bits(num_bits);
    let top_byte_shift = (num_bytes - 1) * 8;

    // Keeps the low `num_bits + 1` bits of the value (magnitude plus sign bit).
    let mask: i128 = (1 << (num_bits + 1)) - 1;
    let sign_mask: i128 = 1 << num_bits;
    let invert_mask: i128 = if opt.contains(RowEncodingOptions::DESCENDING) {
        mask
    } else {
        0
    };
    // Sign flip and optional descending inversion are both XORs, so they fold
    // into a single constant.
    let flip = sign_mask ^ invert_mask;
    // Sets the top bit of the leading byte only when the null sentinel has
    // that bit clear.
    let valid_mask = ((!opt.null_sentinel() & 0x80) as i128) << top_byte_shift;

    with_constant_num_bytes!(num_bytes, {
        for (offset, value) in offsets.iter_mut().zip(input.iter()) {
            let encoded = ((*value & mask) ^ flip) | valid_mask;
            let be_bytes = encoded.to_be_bytes();
            let start = *offset;
            let end = start + num_bytes;
            let dst = unsafe { buffer.get_unchecked_mut(start..end) };
            dst.copy_from_slice(be_bytes[16 - num_bytes..].as_uninit());
            *offset = end;
        }
    });
}
/// Encodes an iterator of optional decimal values into `buffer`.
///
/// Pairs of `(offset, item)` are visited in order. When fewer than 127 bits
/// are required for `precision`, each item occupies `num_bytes` bytes at
/// `buffer[*offset..]`: `None` is written as the null sentinel in the leading
/// byte followed by zero bytes, and `Some` is written as the masked,
/// sign-flipped (and, for descending order, inverted) value in big-endian
/// order. The offset is then advanced by `num_bytes`. When 127 or more bits
/// are required the work is delegated to `super::numeric::encode_iter`.
///
/// # Safety
///
/// For every visited pair, `*offset..*offset + num_bytes` must lie within
/// `buffer`; the writes are not bounds-checked.
pub unsafe fn encode_iter(
    buffer: &mut [MaybeUninit<u8>],
    input: impl Iterator<Item = Option<i128>>,
    opt: RowEncodingOptions,
    offsets: &mut [usize],
    precision: usize,
) {
    let num_bits = num_bits_from_precision(precision);
    if num_bits >= 127 {
        // Too wide for the compact layout: use the generic numeric encoder.
        super::numeric::encode_iter(buffer, input, opt, offsets);
        return;
    }

    let num_bytes = len_from_num_bits(num_bits);
    let top_byte_shift = (num_bytes - 1) * 8;

    // Keeps the low `num_bits + 1` bits of the value (magnitude plus sign bit).
    let mask: i128 = (1 << (num_bits + 1)) - 1;
    let sign_mask: i128 = 1 << num_bits;
    let invert_mask: i128 = if opt.contains(RowEncodingOptions::DESCENDING) {
        mask
    } else {
        0
    };
    // Sign flip and optional descending inversion folded into one XOR constant.
    let flip = sign_mask ^ invert_mask;
    let valid_mask = ((!opt.null_sentinel() & 0x80) as i128) << top_byte_shift;
    // Null sentinel placed in the leading byte; all remaining bytes are zero.
    let null_value = (opt.null_sentinel() as i128) << top_byte_shift;

    with_constant_num_bytes!(num_bytes, {
        for (offset, item) in offsets.iter_mut().zip(input) {
            let encoded = match item {
                Some(value) => ((value & mask) ^ flip) | valid_mask,
                None => null_value,
            };
            let be_bytes = encoded.to_be_bytes();
            let start = *offset;
            let end = start + num_bytes;
            let dst = unsafe { buffer.get_unchecked_mut(start..end) };
            dst.copy_from_slice(be_bytes[16 - num_bytes..].as_uninit());
            *offset = end;
        }
    });
}
/// Decodes one decimal value from the front of every row.
///
/// When 127 or more bits are required for `precision`, decoding is delegated
/// to `super::numeric::decode_primitive`. Otherwise each row yields
/// `num_bytes` big-endian bytes, which are un-inverted (for descending order),
/// sign-flipped back and sign-extended to a full `i128`. Every row slice is
/// advanced past the bytes that were consumed.
///
/// A row whose first byte equals the null sentinel is recorded as invalid in
/// the validity bitmap; its value slot still holds whatever its bytes decode
/// to. If no row starts with the null sentinel, no validity bitmap is attached.
///
/// # Safety
///
/// Every row must contain at least `num_bytes` bytes (and therefore at least
/// one byte); the reads from the front of each row are not bounds-checked.
pub unsafe fn decode(
    rows: &mut [&[u8]],
    opt: RowEncodingOptions,
    precision: usize,
) -> PrimitiveArray<i128> {
    let num_bits = num_bits_from_precision(precision);
    if num_bits >= 127 {
        let (_, values, validity) = super::numeric::decode_primitive(rows, opt).into_inner();
        return PrimitiveArray::new(ArrowDataType::Int128, values, validity);
    }

    let num_bytes = len_from_num_bits(num_bits);
    let null_sentinel = opt.null_sentinel();

    let mask: i128 = (1 << (num_bits + 1)) - 1;
    let sign_mask: i128 = 1 << num_bits;
    let invert_mask: i128 = if opt.contains(RowEncodingOptions::DESCENDING) {
        mask
    } else {
        0
    };
    // Undoing the descending inversion and the sign flip are both XORs, so
    // they fold into a single constant.
    let flip = invert_mask ^ sign_mask;
    // Shifting left then arithmetically right by this amount drops every bit
    // above the sign bit and sign-extends the value.
    let sign_extend_shift = i128::BITS - num_bits as u32 - 1;

    let mut values: Vec<i128> = Vec::with_capacity(rows.len());

    // First pass: decode the leading run of non-null rows without tracking
    // validity.
    with_constant_num_bytes!(num_bytes, {
        let leading_valid = rows
            .iter_mut()
            .take_while(|row| unsafe { *row.get_unchecked(0) } != null_sentinel)
            .map(|row| {
                let mut be_bytes = [0u8; 16];
                be_bytes[16 - num_bytes..]
                    .copy_from_slice(unsafe { row.get_unchecked(..num_bytes) });
                *row = &row[num_bytes..];
                let raw = i128::from_be_bytes(be_bytes) ^ flip;
                (raw << sign_extend_shift) >> sign_extend_shift
            });
        values.extend(leading_valid);
    });

    if values.len() == rows.len() {
        return PrimitiveArray::new(ArrowDataType::Int128, values.into(), None);
    }

    // Second pass: starting at the first null row, decode the remaining rows
    // while recording validity for each of them.
    let start_len = values.len();
    let mut validity = BitmapBuilder::with_capacity(rows.len());
    validity.extend_constant(start_len, true);

    with_constant_num_bytes!(num_bytes, {
        let remaining = rows[start_len..].iter_mut().map(|row| {
            let is_valid = unsafe { *row.get_unchecked(0) } != null_sentinel;
            validity.push(is_valid);

            let mut be_bytes = [0u8; 16];
            be_bytes[16 - num_bytes..].copy_from_slice(unsafe { row.get_unchecked(..num_bytes) });
            *row = &row[num_bytes..];
            let raw = i128::from_be_bytes(be_bytes) ^ flip;
            (raw << sign_extend_shift) >> sign_extend_shift
        });
        values.extend(remaining);
    });

    PrimitiveArray::new(
        ArrowDataType::Int128,
        values.into(),
        validity.into_opt_validity(),
    )
}