Path: blob/main/crates/polars-compute/src/cast/utf8_to.rs
8446 views
use arrow::array::*;1use arrow::datatypes::ArrowDataType;2use arrow::offset::Offset;3use arrow::types::NativeType;4use polars_buffer::Buffer;5use polars_error::PolarsResult;6use polars_utils::vec::PushUnchecked;78pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z";910pub(super) fn utf8_to_dictionary_dyn<O: Offset, K: DictionaryKey>(11from: &dyn Array,12) -> PolarsResult<Box<dyn Array>> {13let values = from.as_any().downcast_ref().unwrap();14utf8_to_dictionary::<O, K>(values).map(|x| Box::new(x) as Box<dyn Array>)15}1617/// Cast [`Utf8Array`] to [`DictionaryArray`], also known as packing.18/// # Errors19/// This function errors if the maximum key is smaller than the number of distinct elements20/// in the array.21pub fn utf8_to_dictionary<O: Offset, K: DictionaryKey>(22from: &Utf8Array<O>,23) -> PolarsResult<DictionaryArray<K>> {24let mut array = MutableDictionaryArray::<K, MutableUtf8Array<O>>::empty_with_value_dtype(25from.dtype().clone(),26);27array.reserve(from.len());28array.try_extend(from.iter())?;2930Ok(array.into())31}3233/// Conversion of utf834pub fn utf8_to_large_utf8(from: &Utf8Array<i32>) -> Utf8Array<i64> {35let dtype = Utf8Array::<i64>::default_dtype();36let validity = from.validity().cloned();37let values = from.values().clone();3839let offsets = from.offsets().into();40// SAFETY: sound because `values` fulfills the same invariants as `from.values()`41unsafe { Utf8Array::<i64>::new_unchecked(dtype, offsets, values, validity) }42}4344/// Conversion of utf845pub fn utf8_large_to_utf8(from: &Utf8Array<i64>) -> PolarsResult<Utf8Array<i32>> {46let dtype = Utf8Array::<i32>::default_dtype();47let validity = from.validity().cloned();48let values = from.values().clone();49let offsets = from.offsets().try_into()?;5051// SAFETY: sound because `values` fulfills the same invariants as `from.values()`52Ok(unsafe { Utf8Array::<i32>::new_unchecked(dtype, offsets, values, validity) })53}5455/// Conversion to binary56pub fn utf8_to_binary<O: Offset>(from: &Utf8Array<O>, to_dtype: ArrowDataType) -> BinaryArray<O> {57// SAFETY: erasure of an invariant is always safe58BinaryArray::<O>::new(59to_dtype,60from.offsets().clone(),61from.values().clone(),62from.validity().cloned(),63)64}6566// Different types to test the overflow path.67#[cfg(not(test))]68type OffsetType = u32;6970// To trigger overflow71#[cfg(test)]72type OffsetType = i8;7374// If we don't do this the GC of binview will trigger. As we will split up buffers into multiple75// chunks so that we don't overflow the offset u32.76fn truncate_buffer(buf: &Buffer<u8>) -> Buffer<u8> {77// * 2, as it must be able to hold u32::MAX offset + u32::MAX len.78let len = std::cmp::min(buf.len(), ((OffsetType::MAX as u64) * 2) as usize);79buf.clone().sliced(..len)80}8182pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {83// Ensure we didn't accidentally set wrong type84#[cfg(not(debug_assertions))]85let _ = std::mem::transmute::<OffsetType, u32>;8687let mut views = Vec::with_capacity(arr.len());88let mut uses_buffer = false;8990let mut base_buffer = arr.values().clone();91// Offset into the buffer92let mut base_ptr = base_buffer.as_ptr() as usize;9394// Offset into the binview buffers95let mut buffer_idx = 0_u32;9697// Binview buffers98// Note that the buffer may look far further than u32::MAX, but as we don't clone data99let mut buffers = vec![truncate_buffer(&base_buffer)];100101for bytes in arr.values_iter() {102let len: u32 = bytes103.len()104.try_into()105.expect("max string/binary length exceeded");106107let mut payload = [0; 16];108payload[0..4].copy_from_slice(&len.to_le_bytes());109110if len <= 12 {111payload[4..4 + bytes.len()].copy_from_slice(bytes);112} else {113uses_buffer = true;114115// Copy the parts we know are correct.116unsafe { payload[4..8].copy_from_slice(bytes.get_unchecked(0..4)) };117payload[0..4].copy_from_slice(&len.to_le_bytes());118119let current_bytes_ptr = bytes.as_ptr() as usize;120let offset = current_bytes_ptr - base_ptr;121122// Here we check the overflow of the buffer offset.123if let Ok(offset) = OffsetType::try_from(offset) {124#[allow(clippy::unnecessary_cast)]125let offset = offset as u32;126payload[12..16].copy_from_slice(&offset.to_le_bytes());127payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());128} else {129let len = base_buffer.len() - offset;130131// Set new buffer132base_buffer = base_buffer.clone().sliced(offset..offset + len);133base_ptr = base_buffer.as_ptr() as usize;134135// And add the (truncated) one to the buffers136buffers.push(truncate_buffer(&base_buffer));137buffer_idx = buffer_idx.checked_add(1).expect("max buffers exceeded");138139let offset = 0u32;140payload[12..16].copy_from_slice(&offset.to_le_bytes());141payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());142}143}144145let value = View::from_le_bytes(payload);146unsafe { views.push_unchecked(value) };147}148let buffers = if uses_buffer {149Buffer::from(buffers)150} else {151Buffer::new()152};153unsafe {154BinaryViewArray::new_unchecked_unknown_md(155ArrowDataType::BinaryView,156views.into(),157buffers,158arr.validity().cloned(),159None,160)161}162}163164pub fn utf8_to_utf8view<O: Offset>(arr: &Utf8Array<O>) -> Utf8ViewArray {165unsafe { binary_to_binview(&arr.to_binary()).to_utf8view_unchecked() }166}167168#[cfg(test)]169mod test {170use super::*;171172#[test]173fn overflowing_utf8_to_binview() {174let values = [175"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (offset)176"123", // inline177"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74178"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (new buffer)179"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74180"234", // inline181"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (new buffer)182"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74183"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 0 (new buffer)184"lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", // 74185"324", // inline186];187let array = Utf8Array::<i64>::from_slice(values);188189let out = utf8_to_utf8view(&array);190// Ensure we hit the multiple buffers part.191assert_eq!(out.data_buffers().len(), 4);192// Ensure we created a valid binview193let out = out.values_iter().collect::<Vec<_>>();194assert_eq!(out, values);195}196}197198199