Path: blob/main/crates/polars-compute/src/cast/utf8_to.rs
6939 views
use std::sync::Arc;12use arrow::array::*;3use arrow::buffer::Buffer;4use arrow::datatypes::ArrowDataType;5use arrow::offset::Offset;6use arrow::types::NativeType;7use polars_error::PolarsResult;8use polars_utils::vec::PushUnchecked;910pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z";1112pub(super) fn utf8_to_dictionary_dyn<O: Offset, K: DictionaryKey>(13from: &dyn Array,14) -> PolarsResult<Box<dyn Array>> {15let values = from.as_any().downcast_ref().unwrap();16utf8_to_dictionary::<O, K>(values).map(|x| Box::new(x) as Box<dyn Array>)17}1819/// Cast [`Utf8Array`] to [`DictionaryArray`], also known as packing.20/// # Errors21/// This function errors if the maximum key is smaller than the number of distinct elements22/// in the array.23pub fn utf8_to_dictionary<O: Offset, K: DictionaryKey>(24from: &Utf8Array<O>,25) -> PolarsResult<DictionaryArray<K>> {26let mut array = MutableDictionaryArray::<K, MutableUtf8Array<O>>::new();27array.reserve(from.len());28array.try_extend(from.iter())?;2930Ok(array.into())31}3233/// Conversion of utf834pub fn utf8_to_large_utf8(from: &Utf8Array<i32>) -> Utf8Array<i64> {35let dtype = Utf8Array::<i64>::default_dtype();36let validity = from.validity().cloned();37let values = from.values().clone();3839let offsets = from.offsets().into();40// SAFETY: sound because `values` fulfills the same invariants as `from.values()`41unsafe { Utf8Array::<i64>::new_unchecked(dtype, offsets, values, validity) }42}4344/// Conversion of utf845pub fn utf8_large_to_utf8(from: &Utf8Array<i64>) -> PolarsResult<Utf8Array<i32>> {46let dtype = Utf8Array::<i32>::default_dtype();47let validity = from.validity().cloned();48let values = from.values().clone();49let offsets = from.offsets().try_into()?;5051// SAFETY: sound because `values` fulfills the same invariants as `from.values()`52Ok(unsafe { Utf8Array::<i32>::new_unchecked(dtype, offsets, values, validity) })53}5455/// Conversion to binary56pub fn utf8_to_binary<O: Offset>(from: &Utf8Array<O>, to_dtype: ArrowDataType) -> BinaryArray<O> {57// SAFETY: erasure of an invariant is always safe58BinaryArray::<O>::new(59to_dtype,60from.offsets().clone(),61from.values().clone(),62from.validity().cloned(),63)64}6566// Different types to test the overflow path.67#[cfg(not(test))]68type OffsetType = u32;6970// To trigger overflow71#[cfg(test)]72type OffsetType = i8;7374// If we don't do this the GC of binview will trigger. As we will split up buffers into multiple75// chunks so that we don't overflow the offset u32.76fn truncate_buffer(buf: &Buffer<u8>) -> Buffer<u8> {77// * 2, as it must be able to hold u32::MAX offset + u32::MAX len.78buf.clone().sliced(790,80std::cmp::min(buf.len(), ((OffsetType::MAX as u64) * 2) as usize),81)82}8384pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {85// Ensure we didn't accidentally set wrong type86#[cfg(not(debug_assertions))]87let _ = std::mem::transmute::<OffsetType, u32>;8889let mut views = Vec::with_capacity(arr.len());90let mut uses_buffer = false;9192let mut base_buffer = arr.values().clone();93// Offset into the buffer94let mut base_ptr = base_buffer.as_ptr() as usize;9596// Offset into the binview buffers97let mut buffer_idx = 0_u32;9899// Binview buffers100// Note that the buffer may look far further than u32::MAX, but as we don't clone data101let mut buffers = vec![truncate_buffer(&base_buffer)];102103for bytes in arr.values_iter() {104let len: u32 = bytes105.len()106.try_into()107.expect("max string/binary length exceeded");108109let mut payload = [0; 16];110payload[0..4].copy_from_slice(&len.to_le_bytes());111112if len <= 12 {113payload[4..4 + bytes.len()].copy_from_slice(bytes);114} else {115uses_buffer = true;116117// Copy the parts we know are correct.118unsafe { payload[4..8].copy_from_slice(bytes.get_unchecked(0..4)) };119payload[0..4].copy_from_slice(&len.to_le_bytes());120121let current_bytes_ptr = bytes.as_ptr() as usize;122let offset = current_bytes_ptr - base_ptr;123124// Here we check the overflow of the buffer offset.125if let Ok(offset) = OffsetType::try_from(offset) {126#[allow(clippy::unnecessary_cast)]127let offset = offset as u32;128payload[12..16].copy_from_slice(&offset.to_le_bytes());129payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());130} else {131let len = base_buffer.len() - offset;132133// Set new buffer134base_buffer = base_buffer.clone().sliced(offset, len);135base_ptr = base_buffer.as_ptr() as usize;136137// And add the (truncated) one to the buffers138buffers.push(truncate_buffer(&base_buffer));139buffer_idx = buffer_idx.checked_add(1).expect("max buffers exceeded");140141let offset = 0u32;142payload[12..16].copy_from_slice(&offset.to_le_bytes());143payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());144}145}146147let value = View::from_le_bytes(payload);148unsafe { views.push_unchecked(value) };149}150let buffers = if uses_buffer {151Arc::from(buffers)152} else {153Arc::from([])154};155unsafe {156BinaryViewArray::new_unchecked_unknown_md(157ArrowDataType::BinaryView,158views.into(),159buffers,160arr.validity().cloned(),161None,162)163}164}165166pub fn utf8_to_utf8view<O: Offset>(arr: &Utf8Array<O>) -> Utf8ViewArray {167unsafe { binary_to_binview(&arr.to_binary()).to_utf8view_unchecked() }168}169170#[cfg(test)]171mod test {172use super::*;173174#[test]175fn overflowing_utf8_to_binview() {176let values = [177"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (offset)178"123", // inline179"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74180"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (new buffer)181"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74182"234", // inline183"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (new buffer)184"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74185"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (new buffer)186"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74187"324", // inline188];189let array = Utf8Array::<i64>::from_slice(values);190191let out = utf8_to_utf8view(&array);192// Ensure we hit the multiple buffers part.193assert_eq!(out.data_buffers().len(), 4);194// Ensure we created a valid binview195let out = out.values_iter().collect::<Vec<_>>();196assert_eq!(out, values);197}198}199200201