// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-ops/src/chunked_array/strings/substring.rs
use arrow::array::View;
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};
use polars_core::prelude::{ChunkFullNull, Int64Chunked, StringChunked, UInt64Chunked};
use polars_error::{PolarsResult, polars_ensure};

/// Returns `true` when `b` is the first byte of a UTF-8 encoded codepoint,
/// i.e. anything other than a continuation byte (`0b10xx_xxxx`).
fn is_utf8_codepoint_start(b: u8) -> bool {
    // The top two bits of a continuation byte are 10. Any other value is a
    // starting byte. We can use signed comparison to test for this in one
    // instruction, as the top bits 11, 00 and 01 are all more positive and thus
    // larger in signed comparison.
    (b as i8) >= (0b1100_0000_u8 as i8)
}
/// Similar to char_to_byte_idx but if `char_idx` would be out-of-bounds the
15
/// number of codepoints in s is returned as an error.
16
pub fn char_to_byte_idx_or_cp_count(s: &str, char_idx: usize) -> Result<usize, usize> {
17
let bytes = s.as_bytes();
18
if char_idx == 0 {
19
return Ok(0);
20
}
21
22
let mut offset = 0;
23
let mut num_chars_seen = 0;
24
25
// Auto-vectorized bulk processing, but skip if index is small.
26
if char_idx >= 16 {
27
while let Some(chunk) = bytes.get(offset..offset + 16) {
28
let chunk_seen: usize = chunk
29
.iter()
30
.map(|b| is_utf8_codepoint_start(*b) as usize)
31
.sum();
32
if num_chars_seen + chunk_seen > char_idx {
33
break;
34
}
35
offset += 16;
36
num_chars_seen += chunk_seen;
37
}
38
}
39
40
while let Some(b) = bytes.get(offset) {
41
num_chars_seen += is_utf8_codepoint_start(*b) as usize;
42
if num_chars_seen > char_idx {
43
return Ok(offset);
44
}
45
offset += 1;
46
}
47
48
debug_assert!(offset == bytes.len());
49
Err(num_chars_seen)
50
}
51
52
/// Given an offset to the start of the `char_idx`th codepoint, returns the
53
/// equivalent offset in bytes.
54
///
55
/// If `char_idx` would be out-of-bounds s.len() is returned.
56
pub fn char_to_byte_idx(s: &str, char_idx: usize) -> usize {
57
if char_idx >= s.len() {
58
// No need to even count.
59
s.len()
60
} else {
61
char_to_byte_idx_or_cp_count(s, char_idx).unwrap_or(s.len())
62
}
63
}
64
65
/// Similar to rev_char_to_byte_idx but if `char_idx` would be out-of-bounds the
66
/// number of codepoints in s is returned as an error.
67
pub fn rev_char_to_byte_idx_or_cp_count(s: &str, rev_char_idx: usize) -> Result<usize, usize> {
68
let bytes = s.as_bytes();
69
if rev_char_idx == 0 {
70
return Ok(bytes.len());
71
}
72
73
let mut offset = s.len();
74
let mut num_chars_seen = 0;
75
76
// Auto-vectorized bulk processing, but skip if index is small.
77
if rev_char_idx >= 16 {
78
while offset >= 16 {
79
let chunk = unsafe { bytes.get_unchecked(offset - 16..offset) };
80
let chunk_seen: usize = chunk
81
.iter()
82
.map(|b| is_utf8_codepoint_start(*b) as usize)
83
.sum();
84
if num_chars_seen + chunk_seen >= rev_char_idx {
85
break;
86
}
87
offset -= 16;
88
num_chars_seen += chunk_seen;
89
}
90
}
91
92
while offset > 0 {
93
offset -= 1;
94
let byte = unsafe { bytes.get_unchecked(offset) };
95
num_chars_seen += is_utf8_codepoint_start(*byte) as usize;
96
if num_chars_seen >= rev_char_idx {
97
return Ok(offset);
98
}
99
}
100
101
debug_assert!(offset == 0);
102
Err(num_chars_seen)
103
}
104
105
/// Counts rev_char_idx code points from *the end* of the string, returning an
106
/// offset in bytes where this codepoint ends.
107
///
108
/// For example, rev_char_to_byte_idx(0, s) returns s.len(), and
109
/// rev_char_to_byte_idx(1, s) returns s.len() - width(last_codepoint_in_s).
110
///
111
/// If rev_char_idx is large enough that we would go out of bounds, 0 is returned.
112
pub fn rev_char_to_byte_idx(s: &str, rev_char_idx: usize) -> usize {
113
if rev_char_idx >= s.len() {
114
// No need to even count.
115
0
116
} else {
117
rev_char_to_byte_idx_or_cp_count(s, rev_char_idx).unwrap_or(0)
118
}
119
}
120
121
fn head_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
122
if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {
123
let end_idx = head_binary_values(str_val, n);
124
Some(unsafe { str_val.get_unchecked(..end_idx) })
125
} else {
126
None
127
}
128
}
129
130
fn head_binary_values(str_val: &str, n: i64) -> usize {
131
if n >= 0 {
132
char_to_byte_idx(str_val, n as usize)
133
} else {
134
rev_char_to_byte_idx(str_val, (-n) as usize)
135
}
136
}
137
138
fn tail_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
139
if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {
140
let start_idx = tail_binary_values(str_val, n);
141
Some(unsafe { str_val.get_unchecked(start_idx..) })
142
} else {
143
None
144
}
145
}
146
147
fn tail_binary_values(str_val: &str, n: i64) -> usize {
148
if n >= 0 {
149
rev_char_to_byte_idx(str_val, n as usize)
150
} else {
151
char_to_byte_idx(str_val, (-n) as usize)
152
}
153
}
154
155
fn substring_ternary_offsets(
156
opt_str_val: Option<&str>,
157
opt_offset: Option<i64>,
158
opt_length: Option<u64>,
159
) -> Option<(usize, usize)> {
160
let str_val = opt_str_val?;
161
let offset = opt_offset?;
162
Some(substring_ternary_offsets_value(
163
str_val,
164
offset,
165
opt_length.unwrap_or(u64::MAX),
166
))
167
}
168
169
pub fn substring_ternary_offsets_value(
170
str_val: &str,
171
offset: i64,
172
mut length: u64,
173
) -> (usize, usize) {
174
// Fast-path: always empty string.
175
if length == 0 || offset >= str_val.len() as i64 {
176
return (0, 0);
177
}
178
179
let start_byte_offset = if offset >= 0 {
180
char_to_byte_idx(str_val, offset as usize)
181
} else {
182
// Fast-path: always empty string.
183
let end_offset_upper_bound = offset
184
.saturating_add(str_val.len() as i64)
185
.saturating_add(length.try_into().unwrap_or(i64::MAX));
186
if end_offset_upper_bound < 0 {
187
return (0, 0);
188
}
189
190
match rev_char_to_byte_idx_or_cp_count(str_val, (-offset) as usize) {
191
Ok(so) => so,
192
Err(n_cp) => {
193
// Our offset was so negative it is before the start of our string.
194
// This means our length must be reduced, assuming it is finite.
195
length = length.saturating_sub((-offset) as u64 - n_cp as u64);
196
0
197
},
198
}
199
};
200
201
let stop_byte_offset = char_to_byte_idx(&str_val[start_byte_offset..], length as usize);
202
(start_byte_offset, start_byte_offset + stop_byte_offset)
203
}
204
205
fn substring_ternary(
206
opt_str_val: Option<&str>,
207
opt_offset: Option<i64>,
208
opt_length: Option<u64>,
209
) -> Option<&str> {
210
let (start, end) = substring_ternary_offsets(opt_str_val, opt_offset, opt_length)?;
211
unsafe { opt_str_val.map(|str_val| str_val.get_unchecked(start..end)) }
212
}
213
214
pub fn update_view(mut view: View, start: usize, end: usize, val: &str) -> View {
215
let length = (end - start) as u32;
216
view.length = length;
217
218
// SAFETY: we just compute the start /end.
219
let subval = unsafe { val.get_unchecked(start..end).as_bytes() };
220
221
if length <= 12 {
222
View::new_inline(subval)
223
} else {
224
view.offset += start as u32;
225
view.length = length;
226
view.prefix = u32::from_le_bytes(subval[0..4].try_into().unwrap());
227
view
228
}
229
}
230
231
pub(super) fn substring(
232
ca: &StringChunked,
233
offset: &Int64Chunked,
234
length: &UInt64Chunked,
235
) -> StringChunked {
236
match (ca.len(), offset.len(), length.len()) {
237
(1, 1, _) => {
238
let str_val = ca.get(0);
239
let offset = offset.get(0);
240
unary_elementwise(length, |length| substring_ternary(str_val, offset, length))
241
.with_name(ca.name().clone())
242
},
243
(_, 1, 1) => {
244
let offset = offset.get(0);
245
let length = length.get(0).unwrap_or(u64::MAX);
246
247
let Some(offset) = offset else {
248
return StringChunked::full_null(ca.name().clone(), ca.len());
249
};
250
251
unsafe {
252
ca.apply_views(|view, val| {
253
let (start, end) = substring_ternary_offsets_value(val, offset, length);
254
update_view(view, start, end, val)
255
})
256
}
257
},
258
(1, _, 1) => {
259
let str_val = ca.get(0);
260
let length = length.get(0);
261
unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length))
262
.with_name(ca.name().clone())
263
},
264
(1, len_b, len_c) if len_b == len_c => {
265
let str_val = ca.get(0);
266
binary_elementwise(offset, length, |offset, length| {
267
substring_ternary(str_val, offset, length)
268
})
269
},
270
(len_a, 1, len_c) if len_a == len_c => {
271
fn infer<F: for<'a> FnMut(Option<&'a str>, Option<u64>) -> Option<&'a str>>(f: F) -> F where
272
{
273
f
274
}
275
let offset = offset.get(0);
276
binary_elementwise(
277
ca,
278
length,
279
infer(|str_val, length| substring_ternary(str_val, offset, length)),
280
)
281
},
282
(len_a, len_b, 1) if len_a == len_b => {
283
fn infer<F: for<'a> FnMut(Option<&'a str>, Option<i64>) -> Option<&'a str>>(f: F) -> F where
284
{
285
f
286
}
287
let length = length.get(0);
288
binary_elementwise(
289
ca,
290
offset,
291
infer(|str_val, offset| substring_ternary(str_val, offset, length)),
292
)
293
},
294
_ => ternary_elementwise(ca, offset, length, substring_ternary),
295
}
296
}
297
298
pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult<StringChunked> {
299
match (ca.len(), n.len()) {
300
(len, 1) => {
301
let n = n.get(0);
302
let Some(n) = n else {
303
return Ok(StringChunked::full_null(ca.name().clone(), len));
304
};
305
306
Ok(unsafe {
307
ca.apply_views(|view, val| {
308
let end = head_binary_values(val, n);
309
update_view(view, 0, end, val)
310
})
311
})
312
},
313
// TODO! below should also work on only views
314
(1, _) => {
315
let str_val = ca.get(0);
316
Ok(unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name().clone()))
317
},
318
(a, b) => {
319
polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.head' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);
320
Ok(binary_elementwise(ca, n, head_binary))
321
},
322
}
323
}
324
325
pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult<StringChunked> {
326
Ok(match (ca.len(), n.len()) {
327
(len, 1) => {
328
let n = n.get(0);
329
let Some(n) = n else {
330
return Ok(StringChunked::full_null(ca.name().clone(), len));
331
};
332
unsafe {
333
ca.apply_views(|view, val| {
334
let start = tail_binary_values(val, n);
335
update_view(view, start, val.len(), val)
336
})
337
}
338
},
339
// TODO! below should also work on only views
340
(1, _) => {
341
let str_val = ca.get(0);
342
unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name().clone())
343
},
344
(a, b) => {
345
polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.tail' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);
346
binary_elementwise(ca, n, tail_binary)
347
},
348
})
349
}
350
351