Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/strings/substring.rs
6939 views
1
use std::cmp::Ordering;
2
3
use arrow::array::View;
4
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};
5
use polars_core::prelude::{ChunkFullNull, Int64Chunked, StringChunked, UInt64Chunked};
6
use polars_error::{PolarsResult, polars_ensure};
7
8
fn head_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
9
if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {
10
let end_idx = head_binary_values(str_val, n);
11
Some(unsafe { str_val.get_unchecked(..end_idx) })
12
} else {
13
None
14
}
15
}
16
17
/// Computes the byte index at which the "head" of `str_val` ends.
///
/// * `n == 0` -> `0` (empty head).
/// * `n > 0`  -> byte index just past the first `n` codepoints, or
///   `str_val.len()` when the string has at most `n` codepoints.
/// * `n < 0`  -> byte index that excludes the last `|n|` codepoints, or `0`
///   when the string has at most `|n|` codepoints.
///
/// The returned index is always a char boundary of `str_val`.
fn head_binary_values(str_val: &str, n: i64) -> usize {
    let byte_len = str_val.len();
    // `unsigned_abs` cannot overflow, unlike `-n`, which does for `i64::MIN`.
    let count = n.unsigned_abs();

    match n.cmp(&0) {
        Ordering::Equal => 0,
        Ordering::Greater => {
            // The byte length is an upper bound on the number of codepoints, so
            // a count that large always selects the whole string. Comparing in
            // `u64` also guarantees the `as usize` cast below is lossless.
            if count >= byte_len as u64 {
                return byte_len;
            }
            // End after the nth codepoint.
            str_val
                .char_indices()
                .nth(count as usize)
                .map_or(byte_len, |(idx, _)| idx)
        },
        Ordering::Less => {
            // Dropping at least as many codepoints as there are bytes leaves
            // nothing behind.
            if count >= byte_len as u64 {
                return 0;
            }
            // End after the nth codepoint from the end.
            str_val
                .char_indices()
                .rev()
                .nth((count - 1) as usize)
                .map_or(0, |(idx, _)| idx)
        },
    }
}
42
43
fn tail_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
44
if let (Some(str_val), Some(n)) = (opt_str_val, opt_n) {
45
let start_idx = tail_binary_values(str_val, n);
46
Some(unsafe { str_val.get_unchecked(start_idx..) })
47
} else {
48
None
49
}
50
}
51
52
/// Computes the byte index at which the "tail" of `str_val` starts.
///
/// * `n == 0` -> `str_val.len()` (empty tail).
/// * `n > 0`  -> byte index of the `n`-th codepoint counted from the end, or
///   `0` when the string has at most `n` codepoints.
/// * `n < 0`  -> byte index just past the first `|n|` codepoints, or
///   `str_val.len()` when the string has at most `|n|` codepoints.
///
/// The returned index is always a char boundary of `str_val`.
fn tail_binary_values(str_val: &str, n: i64) -> usize {
    // `max_len` is guaranteed to be at least the total number of characters.
    let max_len = str_val.len();
    // `unsigned_abs` cannot overflow, unlike `-n`, which does for `i64::MIN`.
    let count = n.unsigned_abs();

    match n.cmp(&0) {
        Ordering::Equal => max_len,
        Ordering::Greater => {
            // Comparing in `u64` keeps the `as usize` cast below lossless.
            if count >= max_len as u64 {
                return 0;
            }
            // Start from nth codepoint from the end.
            str_val
                .char_indices()
                .rev()
                .nth((count - 1) as usize)
                .map_or(0, |(idx, _)| idx)
        },
        Ordering::Less => {
            // Skipping at least as many codepoints as there are bytes leaves
            // an empty tail.
            if count >= max_len as u64 {
                return max_len;
            }
            // Start after the nth codepoint.
            str_val
                .char_indices()
                .nth(count as usize)
                .map_or(max_len, |(idx, _)| idx)
        },
    }
}
80
81
fn substring_ternary_offsets(
82
opt_str_val: Option<&str>,
83
opt_offset: Option<i64>,
84
opt_length: Option<u64>,
85
) -> Option<(usize, usize)> {
86
let str_val = opt_str_val?;
87
let offset = opt_offset?;
88
Some(substring_ternary_offsets_value(
89
str_val,
90
offset,
91
opt_length.unwrap_or(u64::MAX),
92
))
93
}
94
95
/// Computes the `(start, end)` byte range of the substring of `str_val` that
/// begins at codepoint `offset` and spans at most `length` codepoints.
///
/// * A non-negative `offset` counts codepoints from the start of the string.
/// * A negative `offset` counts codepoints from the end. If it points before
///   the start of the string, the range starts at byte 0 and `length` is
///   reduced by the number of codepoints that fell before the start.
/// * `length == 0`, or an `offset` at or past the byte length, gives `(0, 0)`.
///
/// Both returned indices are char boundaries of `str_val` with `start <= end`.
pub fn substring_ternary_offsets_value(str_val: &str, offset: i64, length: u64) -> (usize, usize) {
    // Fast-path: always empty string.
    if length == 0 || offset >= str_val.len() as i64 {
        return (0, 0);
    }

    let (start_byte_offset, length_reduction) = if offset >= 0 {
        // `offset` is below `str_val.len()` here, so the cast is lossless.
        let start = str_val
            .char_indices()
            .map(|(o, _)| o)
            .nth(offset as usize)
            .unwrap_or(str_val.len());
        (start, 0u64)
    } else {
        // If `offset` is negative, it counts from the end of the string.
        // `unsigned_abs` avoids the overflow `-offset` hits for `i64::MIN`.
        let from_end = offset.unsigned_abs();
        let back_idx = (from_end - 1).min(usize::MAX as u64) as usize;

        // Count how many codepoints the backwards walk visits, so we know the
        // total number of codepoints if the walk runs off the front.
        let mut chars_seen: u64 = 0;
        let found = str_val
            .char_indices()
            .map(|(o, _)| o)
            .inspect(|_| chars_seen += 1)
            .nth_back(back_idx);

        match found {
            Some(off) => (off, 0u64),
            // If we didn't find our char that means our offset was so negative
            // it is before the start of our string. This means our length must
            // be reduced, assuming it is finite. `chars_seen < from_end` holds
            // here, so the subtraction cannot underflow.
            None => (0, from_end - chars_seen),
        }
    };

    // Saturate rather than truncate when `usize` is narrower than `u64`.
    let char_budget = length
        .saturating_sub(length_reduction)
        .min(usize::MAX as u64) as usize;

    let rest = &str_val[start_byte_offset..];
    let stop_byte_offset = rest
        .char_indices()
        .map(|(o, _)| o)
        .nth(char_budget)
        .unwrap_or(rest.len());
    (start_byte_offset, start_byte_offset + stop_byte_offset)
}
130
131
fn substring_ternary(
132
opt_str_val: Option<&str>,
133
opt_offset: Option<i64>,
134
opt_length: Option<u64>,
135
) -> Option<&str> {
136
let (start, end) = substring_ternary_offsets(opt_str_val, opt_offset, opt_length)?;
137
unsafe { opt_str_val.map(|str_val| str_val.get_unchecked(start..end)) }
138
}
139
140
pub fn update_view(mut view: View, start: usize, end: usize, val: &str) -> View {
141
let length = (end - start) as u32;
142
view.length = length;
143
144
// SAFETY: we just compute the start /end.
145
let subval = unsafe { val.get_unchecked(start..end).as_bytes() };
146
147
if length <= 12 {
148
View::new_inline(subval)
149
} else {
150
view.offset += start as u32;
151
view.length = length;
152
view.prefix = u32::from_le_bytes(subval[0..4].try_into().unwrap());
153
view
154
}
155
}
156
157
pub(super) fn substring(
158
ca: &StringChunked,
159
offset: &Int64Chunked,
160
length: &UInt64Chunked,
161
) -> StringChunked {
162
match (ca.len(), offset.len(), length.len()) {
163
(1, 1, _) => {
164
let str_val = ca.get(0);
165
let offset = offset.get(0);
166
unary_elementwise(length, |length| substring_ternary(str_val, offset, length))
167
.with_name(ca.name().clone())
168
},
169
(_, 1, 1) => {
170
let offset = offset.get(0);
171
let length = length.get(0).unwrap_or(u64::MAX);
172
173
let Some(offset) = offset else {
174
return StringChunked::full_null(ca.name().clone(), ca.len());
175
};
176
177
unsafe {
178
ca.apply_views(|view, val| {
179
let (start, end) = substring_ternary_offsets_value(val, offset, length);
180
update_view(view, start, end, val)
181
})
182
}
183
},
184
(1, _, 1) => {
185
let str_val = ca.get(0);
186
let length = length.get(0);
187
unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length))
188
.with_name(ca.name().clone())
189
},
190
(1, len_b, len_c) if len_b == len_c => {
191
let str_val = ca.get(0);
192
binary_elementwise(offset, length, |offset, length| {
193
substring_ternary(str_val, offset, length)
194
})
195
},
196
(len_a, 1, len_c) if len_a == len_c => {
197
fn infer<F: for<'a> FnMut(Option<&'a str>, Option<u64>) -> Option<&'a str>>(f: F) -> F where
198
{
199
f
200
}
201
let offset = offset.get(0);
202
binary_elementwise(
203
ca,
204
length,
205
infer(|str_val, length| substring_ternary(str_val, offset, length)),
206
)
207
},
208
(len_a, len_b, 1) if len_a == len_b => {
209
fn infer<F: for<'a> FnMut(Option<&'a str>, Option<i64>) -> Option<&'a str>>(f: F) -> F where
210
{
211
f
212
}
213
let length = length.get(0);
214
binary_elementwise(
215
ca,
216
offset,
217
infer(|str_val, offset| substring_ternary(str_val, offset, length)),
218
)
219
},
220
_ => ternary_elementwise(ca, offset, length, substring_ternary),
221
}
222
}
223
224
pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult<StringChunked> {
225
match (ca.len(), n.len()) {
226
(len, 1) => {
227
let n = n.get(0);
228
let Some(n) = n else {
229
return Ok(StringChunked::full_null(ca.name().clone(), len));
230
};
231
232
Ok(unsafe {
233
ca.apply_views(|view, val| {
234
let end = head_binary_values(val, n);
235
update_view(view, 0, end, val)
236
})
237
})
238
},
239
// TODO! below should also work on only views
240
(1, _) => {
241
let str_val = ca.get(0);
242
Ok(unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name().clone()))
243
},
244
(a, b) => {
245
polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.head' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);
246
Ok(binary_elementwise(ca, n, head_binary))
247
},
248
}
249
}
250
251
pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult<StringChunked> {
252
Ok(match (ca.len(), n.len()) {
253
(len, 1) => {
254
let n = n.get(0);
255
let Some(n) = n else {
256
return Ok(StringChunked::full_null(ca.name().clone(), len));
257
};
258
unsafe {
259
ca.apply_views(|view, val| {
260
let start = tail_binary_values(val, n);
261
update_view(view, start, val.len(), val)
262
})
263
}
264
},
265
// TODO! below should also work on only views
266
(1, _) => {
267
let str_val = ca.get(0);
268
unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name().clone())
269
},
270
(a, b) => {
271
polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.tail' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);
272
binary_elementwise(ca, n, tail_binary)
273
},
274
})
275
}
276
277