Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-row/src/variable/utf8.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
//! Row encoding for UTF-8 strings
3
//!
4
//! This encoding relies on the fact that valid UTF-8 never contains the bytes 0xF5 - 0xFF.
//! To make this work with the row encoding, we add 2 to each byte. This cannot overflow, and it
//! leaves the two lowest byte values (0x00 and 0x01) unused. For descending order every byte is
//! additionally bit-inverted, which leaves the two highest values (0xFE and 0xFF) unused instead.
//! The values 0x00 and 0xFF are reserved for the null sentinel. The values 0x01 (ascending) and
//! 0xFE (descending) are reserved as the sequence terminator byte.
9
//!
10
//! This allows the string row encoding to have a constant 1 byte overhead.
11
use std::mem::MaybeUninit;
12
13
use arrow::array::{MutableBinaryViewArray, PrimitiveArray, Utf8ViewArray};
14
use arrow::bitmap::BitmapBuilder;
15
use arrow::types::NativeType;
16
use polars_dtype::categorical::{CatNative, CategoricalMapping};
17
18
use crate::row::RowEncodingOptions;
19
20
#[inline]
21
pub fn len_from_item(a: Option<usize>, _opt: RowEncodingOptions) -> usize {
22
// Length = 1 i.f.f. str is null
23
// Length = len(str) + 1 i.f.f. str is non-null
24
1 + a.unwrap_or_default()
25
}
26
27
pub unsafe fn len_from_buffer(row: &[u8], opt: RowEncodingOptions) -> usize {
28
// null
29
if *row.get_unchecked(0) == opt.null_sentinel() {
30
return 1;
31
}
32
33
let end = if opt.contains(RowEncodingOptions::DESCENDING) {
34
unsafe { row.iter().position(|&b| b == 0xFE).unwrap_unchecked() }
35
} else {
36
unsafe { row.iter().position(|&b| b == 0x01).unwrap_unchecked() }
37
};
38
39
end + 1
40
}
41
42
pub unsafe fn encode_str<'a, I: Iterator<Item = Option<&'a str>>>(
43
buffer: &mut [MaybeUninit<u8>],
44
input: I,
45
opt: RowEncodingOptions,
46
offsets: &mut [usize],
47
) {
48
let null_sentinel = opt.null_sentinel();
49
let t = if opt.contains(RowEncodingOptions::DESCENDING) {
50
0xFF
51
} else {
52
0x00
53
};
54
55
for (offset, opt_value) in offsets.iter_mut().zip(input) {
56
let dst = buffer.get_unchecked_mut(*offset..);
57
58
match opt_value {
59
None => {
60
*unsafe { dst.get_unchecked_mut(0) } = MaybeUninit::new(null_sentinel);
61
*offset += 1;
62
},
63
Some(s) => {
64
for (i, &b) in s.as_bytes().iter().enumerate() {
65
*unsafe { dst.get_unchecked_mut(i) } = MaybeUninit::new(t ^ (b + 2));
66
}
67
*unsafe { dst.get_unchecked_mut(s.len()) } = MaybeUninit::new(t ^ 0x01);
68
*offset += 1 + s.len();
69
},
70
}
71
}
72
}
73
74
pub unsafe fn decode_str(rows: &mut [&[u8]], opt: RowEncodingOptions) -> Utf8ViewArray {
75
let null_sentinel = opt.null_sentinel();
76
let descending = opt.contains(RowEncodingOptions::DESCENDING);
77
78
let num_rows = rows.len();
79
let mut array = MutableBinaryViewArray::<str>::with_capacity(rows.len());
80
81
let mut scratch = Vec::new();
82
for row in rows.iter_mut() {
83
let sentinel = *unsafe { row.get_unchecked(0) };
84
if sentinel == null_sentinel {
85
*row = unsafe { row.get_unchecked(1..) };
86
break;
87
}
88
89
scratch.clear();
90
if descending {
91
scratch.extend(row.iter().take_while(|&b| *b != 0xFE).map(|&v| !v - 2));
92
} else {
93
scratch.extend(row.iter().take_while(|&b| *b != 0x01).map(|&v| v - 2));
94
}
95
96
*row = row.get_unchecked(1 + scratch.len()..);
97
array.push_value_ignore_validity(unsafe { std::str::from_utf8_unchecked(&scratch) });
98
}
99
100
if array.len() == num_rows {
101
return array.into();
102
}
103
104
let mut validity = BitmapBuilder::with_capacity(num_rows);
105
validity.extend_constant(array.len(), true);
106
validity.push(false);
107
array.push_value_ignore_validity("");
108
109
for row in rows[array.len()..].iter_mut() {
110
let sentinel = *unsafe { row.get_unchecked(0) };
111
validity.push(sentinel != null_sentinel);
112
if sentinel == null_sentinel {
113
*row = unsafe { row.get_unchecked(1..) };
114
array.push_value_ignore_validity("");
115
continue;
116
}
117
118
scratch.clear();
119
if descending {
120
scratch.extend(row.iter().take_while(|&b| *b != 0xFE).map(|&v| !v - 2));
121
} else {
122
scratch.extend(row.iter().take_while(|&b| *b != 0x01).map(|&v| v - 2));
123
}
124
125
*row = row.get_unchecked(1 + scratch.len()..);
126
array.push_value_ignore_validity(unsafe { std::str::from_utf8_unchecked(&scratch) });
127
}
128
129
let out: Utf8ViewArray = array.into();
130
out.with_validity(validity.into_opt_validity())
131
}
132
133
/// The same as decode_str but inserts it into the given mapping, translating
134
/// it to physical type T.
135
pub unsafe fn decode_str_as_cat<T: NativeType + CatNative>(
136
rows: &mut [&[u8]],
137
opt: RowEncodingOptions,
138
mapping: &CategoricalMapping,
139
) -> PrimitiveArray<T> {
140
let null_sentinel = opt.null_sentinel();
141
let descending = opt.contains(RowEncodingOptions::DESCENDING);
142
143
let num_rows = rows.len();
144
let mut out = Vec::<T>::with_capacity(rows.len());
145
146
let mut scratch = Vec::new();
147
for row in rows.iter_mut() {
148
let sentinel = *unsafe { row.get_unchecked(0) };
149
if sentinel == null_sentinel {
150
*row = unsafe { row.get_unchecked(1..) };
151
break;
152
}
153
154
scratch.clear();
155
if descending {
156
scratch.extend(row.iter().take_while(|&b| *b != 0xFE).map(|&v| !v - 2));
157
} else {
158
scratch.extend(row.iter().take_while(|&b| *b != 0x01).map(|&v| v - 2));
159
}
160
161
*row = row.get_unchecked(1 + scratch.len()..);
162
let s = unsafe { std::str::from_utf8_unchecked(&scratch) };
163
out.push(T::from_cat(mapping.insert_cat(s).unwrap()));
164
}
165
166
if out.len() == num_rows {
167
return PrimitiveArray::from_vec(out);
168
}
169
170
let mut validity = BitmapBuilder::with_capacity(num_rows);
171
validity.extend_constant(out.len(), true);
172
validity.push(false);
173
out.push(T::zeroed());
174
175
for row in rows[out.len()..].iter_mut() {
176
let sentinel = *unsafe { row.get_unchecked(0) };
177
validity.push(sentinel != null_sentinel);
178
if sentinel == null_sentinel {
179
*row = unsafe { row.get_unchecked(1..) };
180
out.push(T::zeroed());
181
continue;
182
}
183
184
scratch.clear();
185
if descending {
186
scratch.extend(row.iter().take_while(|&b| *b != 0xFE).map(|&v| !v - 2));
187
} else {
188
scratch.extend(row.iter().take_while(|&b| *b != 0x01).map(|&v| v - 2));
189
}
190
191
*row = row.get_unchecked(1 + scratch.len()..);
192
let s = unsafe { std::str::from_utf8_unchecked(&scratch) };
193
out.push(T::from_cat(mapping.insert_cat(s).unwrap()));
194
}
195
196
PrimitiveArray::from_vec(out).with_validity(validity.into_opt_validity())
197
}
198
199