#![allow(unsafe_op_in_unsafe_fn)]
use std::mem::MaybeUninit;
use arrow::array::{MutableBinaryViewArray, PrimitiveArray, Utf8ViewArray};
use arrow::bitmap::BitmapBuilder;
use arrow::types::NativeType;
use polars_dtype::categorical::{CatNative, CategoricalMapping};
use crate::row::RowEncodingOptions;
/// Returns the number of bytes needed to row-encode one optional string.
///
/// `a` is the byte length of the string, or `None` for a null. A null takes a
/// single byte; a string of length `n` takes `n + 1` bytes (its bytes plus one
/// trailing terminator byte). The encoding options do not influence the size.
#[inline]
pub fn len_from_item(a: Option<usize>, _opt: RowEncodingOptions) -> usize {
    match a {
        // Value bytes followed by one terminator byte.
        Some(value_len) => 1 + value_len,
        // A null is a lone marker byte.
        None => 1,
    }
}
/// Returns how many bytes the encoded value at the front of `row` occupies.
///
/// A row that starts with `opt.null_sentinel()` is a one-byte null. Otherwise
/// the value extends up to and including its terminator byte, which is `0x01`,
/// or `0xFE` when `opt` contains [`RowEncodingOptions::DESCENDING`].
///
/// # Safety
///
/// `row` must be non-empty. If its first byte is not the null sentinel, `row`
/// must contain the terminator byte for the ordering selected by `opt`.
pub unsafe fn len_from_buffer(row: &[u8], opt: RowEncodingOptions) -> usize {
    // SAFETY: the caller guarantees that `row` holds at least one byte.
    let first = *unsafe { row.get_unchecked(0) };
    if first == opt.null_sentinel() {
        return 1;
    }

    let terminator: u8 = if opt.contains(RowEncodingOptions::DESCENDING) {
        0xFE
    } else {
        0x01
    };

    // SAFETY: the caller guarantees that a non-null row contains its terminator.
    let terminator_idx = unsafe {
        row.iter()
            .position(|&byte| byte == terminator)
            .unwrap_unchecked()
    };

    // The terminator itself is part of the encoded value.
    terminator_idx + 1
}
/// Row-encodes a sequence of optional strings into `buffer`.
///
/// The i-th item of `input` is written starting at `offsets[i]`, and that
/// offset is then advanced past the bytes just written:
///
/// - `None` is written as the single byte `opt.null_sentinel()`.
/// - `Some(s)` is written as every byte of `s` plus 2, followed by the
///   terminator `0x01`.
///
/// When `opt` contains [`RowEncodingOptions::DESCENDING`], the value bytes and
/// the terminator are XOR-ed with `0xFF`; the null sentinel is written as-is.
/// Iteration stops as soon as either `offsets` or `input` is exhausted.
///
/// # Safety
///
/// For every `(offset, item)` pair, `buffer` must be large enough to hold the
/// encoded item starting at `offset` (1 byte for a null, `s.len() + 1` bytes
/// for a string).
pub unsafe fn encode_str<'a, I: Iterator<Item = Option<&'a str>>>(
    buffer: &mut [MaybeUninit<u8>],
    input: I,
    opt: RowEncodingOptions,
    offsets: &mut [usize],
) {
    let null_sentinel = opt.null_sentinel();
    // XOR mask: all ones inverts every byte for descending order, zero is a no-op.
    let flip: u8 = if opt.contains(RowEncodingOptions::DESCENDING) {
        0xFF
    } else {
        0x00
    };
    let terminator = flip ^ 0x01;

    for (offset, item) in offsets.iter_mut().zip(input) {
        // SAFETY: the caller guarantees `*offset` lies within `buffer`.
        let dst = unsafe { buffer.get_unchecked_mut(*offset..) };

        let written = if let Some(value) = item {
            let bytes = value.as_bytes();
            for (pos, &byte) in bytes.iter().enumerate() {
                // UTF-8 never contains 0xFE or 0xFF, so `byte + 2` cannot overflow.
                // SAFETY: the caller guarantees room for `bytes.len() + 1` bytes.
                unsafe { dst.get_unchecked_mut(pos) }.write(flip ^ (byte + 2));
            }
            // SAFETY: same guarantee; this is the final byte of the encoded value.
            unsafe { dst.get_unchecked_mut(bytes.len()) }.write(terminator);
            bytes.len() + 1
        } else {
            // SAFETY: the caller guarantees room for the one-byte null marker.
            unsafe { dst.get_unchecked_mut(0) }.write(null_sentinel);
            1
        };

        *offset += written;
    }
}
/// Decodes one string from the front of every row into a [`Utf8ViewArray`].
///
/// Each slice in `rows` is advanced past the bytes that were consumed. A row
/// whose first byte equals `opt.null_sentinel()` decodes to a null (stored as
/// an empty string with its validity bit cleared). Any other row is read up to
/// its terminator (`0x01`, or `0xFE` when `opt` contains
/// [`RowEncodingOptions::DESCENDING`]); each value byte is inverted first when
/// descending and then has 2 subtracted.
///
/// No validity bitmap is built while no null has been seen: rows are decoded
/// without one until the first null sentinel, and only then is the bitmap
/// created and back-filled.
///
/// # Safety
///
/// Every row must be non-empty and must start with either the null sentinel or
/// a complete encoded string including its terminator. The decoded bytes must
/// be valid UTF-8.
pub unsafe fn decode_str(rows: &mut [&[u8]], opt: RowEncodingOptions) -> Utf8ViewArray {
    /// Decodes the value at the front of `row` into `scratch` (replacing its
    /// contents) and advances `row` past the value and its terminator.
    ///
    /// # Safety
    ///
    /// `row` must contain the terminator byte for the selected ordering.
    unsafe fn take_value(row: &mut &[u8], scratch: &mut Vec<u8>, descending: bool) {
        scratch.clear();
        if descending {
            scratch.extend(
                row.iter()
                    .take_while(|&&byte| byte != 0xFE)
                    .map(|&byte| !byte - 2),
            );
        } else {
            scratch.extend(
                row.iter()
                    .take_while(|&&byte| byte != 0x01)
                    .map(|&byte| byte - 2),
            );
        }
        // SAFETY: the terminator follows the `scratch.len()` value bytes, so the
        // start of the remaining slice is at most `row.len()`.
        *row = unsafe { row.get_unchecked(scratch.len() + 1..) };
    }

    let descending = opt.contains(RowEncodingOptions::DESCENDING);
    let null_sentinel = opt.null_sentinel();
    let num_rows = rows.len();

    let mut scratch = Vec::new();
    let mut array = MutableBinaryViewArray::<str>::with_capacity(num_rows);

    // Phase 1: decode without a validity bitmap until the first null shows up.
    for row in rows.iter_mut() {
        // SAFETY: the caller guarantees every row is non-empty.
        let first_byte = *unsafe { row.get_unchecked(0) };
        if first_byte == null_sentinel {
            // SAFETY: the row holds at least the sentinel byte.
            *row = unsafe { row.get_unchecked(1..) };
            break;
        }

        // SAFETY: a non-null row holds a terminated value (caller contract).
        unsafe { take_value(row, &mut scratch, descending) };
        // SAFETY: the caller guarantees the decoded bytes are valid UTF-8.
        array.push_value_ignore_validity(unsafe { std::str::from_utf8_unchecked(&scratch) });
    }

    // Every row was consumed without meeting a null: no bitmap is needed.
    if array.len() == num_rows {
        return array.into();
    }

    // Phase 2: everything decoded so far is valid; the row that ended phase 1
    // is the first null.
    let mut validity = BitmapBuilder::with_capacity(num_rows);
    validity.extend_constant(array.len(), true);
    validity.push(false);
    array.push_value_ignore_validity("");

    let resume_at = array.len();
    for row in rows.iter_mut().skip(resume_at) {
        // SAFETY: the caller guarantees every row is non-empty.
        let first_byte = *unsafe { row.get_unchecked(0) };
        let is_valid = first_byte != null_sentinel;
        validity.push(is_valid);

        if is_valid {
            // SAFETY: a non-null row holds a terminated value (caller contract).
            unsafe { take_value(row, &mut scratch, descending) };
            // SAFETY: the caller guarantees the decoded bytes are valid UTF-8.
            array.push_value_ignore_validity(unsafe { std::str::from_utf8_unchecked(&scratch) });
        } else {
            // SAFETY: the row holds at least the sentinel byte.
            *row = unsafe { row.get_unchecked(1..) };
            array.push_value_ignore_validity("");
        }
    }

    let decoded: Utf8ViewArray = array.into();
    decoded.with_validity(validity.into_opt_validity())
}
/// Decodes one string from the front of every row and maps it to a category.
///
/// Each slice in `rows` is advanced past the bytes that were consumed. Every
/// decoded string is passed to `mapping.insert_cat`, and the resulting category
/// is converted to `T` with `T::from_cat`. A row whose first byte equals
/// `opt.null_sentinel()` becomes a null, stored as `T::zeroed()` with its
/// validity bit cleared.
///
/// Strings are read up to their terminator (`0x01`, or `0xFE` when `opt`
/// contains [`RowEncodingOptions::DESCENDING`]); each value byte is inverted
/// first when descending and then has 2 subtracted.
///
/// No validity bitmap is built while no null has been seen: rows are decoded
/// without one until the first null sentinel, and only then is the bitmap
/// created and back-filled.
///
/// # Panics
///
/// Panics if `mapping.insert_cat` returns an error.
///
/// # Safety
///
/// Every row must be non-empty and must start with either the null sentinel or
/// a complete encoded string including its terminator. The decoded bytes must
/// be valid UTF-8.
pub unsafe fn decode_str_as_cat<T: NativeType + CatNative>(
    rows: &mut [&[u8]],
    opt: RowEncodingOptions,
    mapping: &CategoricalMapping,
) -> PrimitiveArray<T> {
    /// Decodes the value at the front of `row` into `scratch` (replacing its
    /// contents) and advances `row` past the value and its terminator.
    ///
    /// # Safety
    ///
    /// `row` must contain the terminator byte for the selected ordering.
    unsafe fn take_value(row: &mut &[u8], scratch: &mut Vec<u8>, descending: bool) {
        scratch.clear();
        if descending {
            scratch.extend(
                row.iter()
                    .take_while(|&&byte| byte != 0xFE)
                    .map(|&byte| !byte - 2),
            );
        } else {
            scratch.extend(
                row.iter()
                    .take_while(|&&byte| byte != 0x01)
                    .map(|&byte| byte - 2),
            );
        }
        // SAFETY: the terminator follows the `scratch.len()` value bytes, so the
        // start of the remaining slice is at most `row.len()`.
        *row = unsafe { row.get_unchecked(scratch.len() + 1..) };
    }

    // Looks up (or inserts) the decoded string in `mapping` and converts the
    // category to the physical type.
    let to_category = |bytes: &[u8]| -> T {
        // SAFETY: the caller guarantees the decoded bytes are valid UTF-8.
        let s = unsafe { std::str::from_utf8_unchecked(bytes) };
        T::from_cat(mapping.insert_cat(s).unwrap())
    };

    let descending = opt.contains(RowEncodingOptions::DESCENDING);
    let null_sentinel = opt.null_sentinel();
    let num_rows = rows.len();

    let mut scratch = Vec::new();
    let mut out = Vec::<T>::with_capacity(num_rows);

    // Phase 1: decode without a validity bitmap until the first null shows up.
    for row in rows.iter_mut() {
        // SAFETY: the caller guarantees every row is non-empty.
        let first_byte = *unsafe { row.get_unchecked(0) };
        if first_byte == null_sentinel {
            // SAFETY: the row holds at least the sentinel byte.
            *row = unsafe { row.get_unchecked(1..) };
            break;
        }

        // SAFETY: a non-null row holds a terminated value (caller contract).
        unsafe { take_value(row, &mut scratch, descending) };
        out.push(to_category(&scratch));
    }

    // Every row was consumed without meeting a null: no bitmap is needed.
    if out.len() == num_rows {
        return PrimitiveArray::from_vec(out);
    }

    // Phase 2: everything decoded so far is valid; the row that ended phase 1
    // is the first null.
    let mut validity = BitmapBuilder::with_capacity(num_rows);
    validity.extend_constant(out.len(), true);
    validity.push(false);
    out.push(T::zeroed());

    let resume_at = out.len();
    for row in rows.iter_mut().skip(resume_at) {
        // SAFETY: the caller guarantees every row is non-empty.
        let first_byte = *unsafe { row.get_unchecked(0) };
        let is_valid = first_byte != null_sentinel;
        validity.push(is_valid);

        if is_valid {
            // SAFETY: a non-null row holds a terminated value (caller contract).
            unsafe { take_value(row, &mut scratch, descending) };
            out.push(to_category(&scratch));
        } else {
            // SAFETY: the row holds at least the sentinel byte.
            *row = unsafe { row.get_unchecked(1..) };
            out.push(T::zeroed());
        }
    }

    PrimitiveArray::from_vec(out).with_validity(validity.into_opt_validity())
}