Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/cast/utf8_to.rs
6939 views
1
use std::sync::Arc;
2
3
use arrow::array::*;
4
use arrow::buffer::Buffer;
5
use arrow::datatypes::ArrowDataType;
6
use arrow::offset::Offset;
7
use arrow::types::NativeType;
8
use polars_error::PolarsResult;
9
use polars_utils::vec::PushUnchecked;
10
11
/// Format string describing RFC 3339 timestamps (date, `T`, time with fractional seconds,
/// and a colon-separated UTC offset), written with `%`-specifiers.
/// NOTE(review): `%.f` / `%:z` look like `chrono` syntax — confirm against the consumer.
/// Not referenced in this part of the file; `pub(super)` exposes it to the parent module.
pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z";
12
13
pub(super) fn utf8_to_dictionary_dyn<O: Offset, K: DictionaryKey>(
14
from: &dyn Array,
15
) -> PolarsResult<Box<dyn Array>> {
16
let values = from.as_any().downcast_ref().unwrap();
17
utf8_to_dictionary::<O, K>(values).map(|x| Box::new(x) as Box<dyn Array>)
18
}
19
20
/// Cast [`Utf8Array`] to [`DictionaryArray`], also known as packing.
21
/// # Errors
22
/// This function errors if the maximum key is smaller than the number of distinct elements
23
/// in the array.
24
pub fn utf8_to_dictionary<O: Offset, K: DictionaryKey>(
25
from: &Utf8Array<O>,
26
) -> PolarsResult<DictionaryArray<K>> {
27
let mut array = MutableDictionaryArray::<K, MutableUtf8Array<O>>::new();
28
array.reserve(from.len());
29
array.try_extend(from.iter())?;
30
31
Ok(array.into())
32
}
33
34
/// Conversion of utf8
35
pub fn utf8_to_large_utf8(from: &Utf8Array<i32>) -> Utf8Array<i64> {
36
let dtype = Utf8Array::<i64>::default_dtype();
37
let validity = from.validity().cloned();
38
let values = from.values().clone();
39
40
let offsets = from.offsets().into();
41
// SAFETY: sound because `values` fulfills the same invariants as `from.values()`
42
unsafe { Utf8Array::<i64>::new_unchecked(dtype, offsets, values, validity) }
43
}
44
45
/// Conversion of utf8
46
pub fn utf8_large_to_utf8(from: &Utf8Array<i64>) -> PolarsResult<Utf8Array<i32>> {
47
let dtype = Utf8Array::<i32>::default_dtype();
48
let validity = from.validity().cloned();
49
let values = from.values().clone();
50
let offsets = from.offsets().try_into()?;
51
52
// SAFETY: sound because `values` fulfills the same invariants as `from.values()`
53
Ok(unsafe { Utf8Array::<i32>::new_unchecked(dtype, offsets, values, validity) })
54
}
55
56
/// Conversion to binary
57
pub fn utf8_to_binary<O: Offset>(from: &Utf8Array<O>, to_dtype: ArrowDataType) -> BinaryArray<O> {
58
// SAFETY: erasure of an invariant is always safe
59
BinaryArray::<O>::new(
60
to_dtype,
61
from.offsets().clone(),
62
from.values().clone(),
63
from.validity().cloned(),
64
)
65
}
66
67
// Different types to test the overflow path.
68
#[cfg(not(test))]
69
type OffsetType = u32;
70
71
// To trigger overflow
72
#[cfg(test)]
73
type OffsetType = i8;
74
75
// If we don't do this the GC of binview will trigger. As we will split up buffers into multiple
76
// chunks so that we don't overflow the offset u32.
77
fn truncate_buffer(buf: &Buffer<u8>) -> Buffer<u8> {
78
// * 2, as it must be able to hold u32::MAX offset + u32::MAX len.
79
buf.clone().sliced(
80
0,
81
std::cmp::min(buf.len(), ((OffsetType::MAX as u64) * 2) as usize),
82
)
83
}
84
85
pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {
86
// Ensure we didn't accidentally set wrong type
87
#[cfg(not(debug_assertions))]
88
let _ = std::mem::transmute::<OffsetType, u32>;
89
90
let mut views = Vec::with_capacity(arr.len());
91
let mut uses_buffer = false;
92
93
let mut base_buffer = arr.values().clone();
94
// Offset into the buffer
95
let mut base_ptr = base_buffer.as_ptr() as usize;
96
97
// Offset into the binview buffers
98
let mut buffer_idx = 0_u32;
99
100
// Binview buffers
101
// Note that the buffer may look far further than u32::MAX, but as we don't clone data
102
let mut buffers = vec![truncate_buffer(&base_buffer)];
103
104
for bytes in arr.values_iter() {
105
let len: u32 = bytes
106
.len()
107
.try_into()
108
.expect("max string/binary length exceeded");
109
110
let mut payload = [0; 16];
111
payload[0..4].copy_from_slice(&len.to_le_bytes());
112
113
if len <= 12 {
114
payload[4..4 + bytes.len()].copy_from_slice(bytes);
115
} else {
116
uses_buffer = true;
117
118
// Copy the parts we know are correct.
119
unsafe { payload[4..8].copy_from_slice(bytes.get_unchecked(0..4)) };
120
payload[0..4].copy_from_slice(&len.to_le_bytes());
121
122
let current_bytes_ptr = bytes.as_ptr() as usize;
123
let offset = current_bytes_ptr - base_ptr;
124
125
// Here we check the overflow of the buffer offset.
126
if let Ok(offset) = OffsetType::try_from(offset) {
127
#[allow(clippy::unnecessary_cast)]
128
let offset = offset as u32;
129
payload[12..16].copy_from_slice(&offset.to_le_bytes());
130
payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());
131
} else {
132
let len = base_buffer.len() - offset;
133
134
// Set new buffer
135
base_buffer = base_buffer.clone().sliced(offset, len);
136
base_ptr = base_buffer.as_ptr() as usize;
137
138
// And add the (truncated) one to the buffers
139
buffers.push(truncate_buffer(&base_buffer));
140
buffer_idx = buffer_idx.checked_add(1).expect("max buffers exceeded");
141
142
let offset = 0u32;
143
payload[12..16].copy_from_slice(&offset.to_le_bytes());
144
payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());
145
}
146
}
147
148
let value = View::from_le_bytes(payload);
149
unsafe { views.push_unchecked(value) };
150
}
151
let buffers = if uses_buffer {
152
Arc::from(buffers)
153
} else {
154
Arc::from([])
155
};
156
unsafe {
157
BinaryViewArray::new_unchecked_unknown_md(
158
ArrowDataType::BinaryView,
159
views.into(),
160
buffers,
161
arr.validity().cloned(),
162
None,
163
)
164
}
165
}
166
167
/// Converts a [`Utf8Array`] into a [`Utf8ViewArray`].
///
/// The array is viewed as binary, converted with [`binary_to_binview`], and the result is
/// reinterpreted as UTF-8 without re-validation.
pub fn utf8_to_utf8view<O: Offset>(arr: &Utf8Array<O>) -> Utf8ViewArray {
    let as_binary = arr.to_binary();
    let binview = binary_to_binview(&as_binary);
    // SAFETY: every view covers exactly the bytes of one value of `arr`, which is a valid
    // `Utf8Array`, so all values are valid UTF-8.
    unsafe { binview.to_utf8view_unchecked() }
}
170
171
#[cfg(test)]
mod test {
    use super::*;

    /// 74 bytes long, so it is stored out of line (not inlined) in a view.
    const LONG: &str =
        "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf";

    #[test]
    fn overflowing_utf8_to_binview() {
        // With the test-only `i8` offset type, a long value more than 127 bytes past the
        // start of the current data buffer forces a new one. The comments give each value's
        // offset relative to its data buffer.
        let values = [
            LONG,  // 0 (buffer 0)
            "123", // inline
            LONG,  // 77
            LONG,  // 0 (new buffer 1)
            LONG,  // 74
            "234", // inline
            LONG,  // 0 (new buffer 2)
            LONG,  // 74
            LONG,  // 0 (new buffer 3)
            LONG,  // 74
            "324", // inline
        ];
        let array = Utf8Array::<i64>::from_slice(values);

        let out = utf8_to_utf8view(&array);
        // Ensure we hit the multiple buffers part.
        assert_eq!(out.data_buffers().len(), 4);
        // Ensure we created a valid binview.
        let roundtrip: Vec<&str> = out.values_iter().collect();
        assert_eq!(roundtrip, values);
    }
}
200
201