Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/strings/find_many.rs
8420 views
1
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
2
use arrow::array::Utf8ViewArray;
3
use polars_core::prelude::arity::unary_elementwise;
4
use polars_core::prelude::*;
5
use polars_core::utils::align_chunks_binary;
6
7
fn build_ac(
8
patterns: &StringChunked,
9
ascii_case_insensitive: bool,
10
leftmost: bool,
11
) -> PolarsResult<AhoCorasick> {
12
AhoCorasickBuilder::new()
13
.match_kind(if leftmost {
14
MatchKind::LeftmostFirst
15
} else {
16
MatchKind::Standard
17
})
18
.ascii_case_insensitive(ascii_case_insensitive)
19
.build(patterns.downcast_iter().flatten().flatten())
20
.map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e))
21
}
22
23
fn build_ac_arr(
24
patterns: &Utf8ViewArray,
25
ascii_case_insensitive: bool,
26
leftmost: bool,
27
) -> PolarsResult<AhoCorasick> {
28
AhoCorasickBuilder::new()
29
.match_kind(if leftmost {
30
MatchKind::LeftmostFirst
31
} else {
32
MatchKind::Standard
33
})
34
.ascii_case_insensitive(ascii_case_insensitive)
35
.build(patterns.into_iter().flatten())
36
.map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e))
37
}
38
39
pub fn contains_any(
40
ca: &StringChunked,
41
patterns: &ListChunked,
42
ascii_case_insensitive: bool,
43
) -> PolarsResult<BooleanChunked> {
44
polars_ensure!(
45
ca.len() == patterns.len() || ca.len() == 1 || patterns.len() == 1,
46
length_mismatch = "str.contains_any",
47
ca.len(),
48
patterns.len()
49
);
50
polars_ensure!(
51
patterns.len() == 1,
52
nyi = "`str.contains_any` with a pattern per row"
53
);
54
55
if patterns.has_nulls() {
56
return Ok(BooleanChunked::full_null(ca.name().clone(), ca.len()));
57
}
58
59
let patterns = patterns.explode(ExplodeOptions {
60
empty_as_null: false,
61
keep_nulls: true,
62
})?;
63
let patterns = patterns.str()?;
64
let ac = build_ac(patterns, ascii_case_insensitive, false)?;
65
66
Ok(unary_elementwise(ca, |opt_val| {
67
opt_val.map(|val| ac.find(val).is_some())
68
}))
69
}
70
71
pub fn replace_all(
72
ca: &StringChunked,
73
patterns: &ListChunked,
74
replace_with: &ListChunked,
75
ascii_case_insensitive: bool,
76
leftmost: bool,
77
) -> PolarsResult<StringChunked> {
78
let mut length = 1;
79
for (argument_idx, (argument, l)) in [
80
("self", ca.len()),
81
("patterns", patterns.len()),
82
("replace_with", replace_with.len()),
83
]
84
.into_iter()
85
.enumerate()
86
{
87
if l != 1 {
88
if l != length && length != 1 {
89
polars_bail!(
90
length_mismatch = "str.replace_many",
91
l,
92
length,
93
argument = argument,
94
argument_idx = argument_idx
95
);
96
}
97
length = l;
98
}
99
}
100
101
polars_ensure!(
102
patterns.len() == 1 && replace_with.len() == 1,
103
nyi = "`str.replace_many` with a pattern per row"
104
);
105
106
if patterns.has_nulls() || replace_with.has_nulls() {
107
return Ok(StringChunked::full_null(ca.name().clone(), ca.len()));
108
}
109
110
let patterns = patterns.explode(ExplodeOptions {
111
empty_as_null: false,
112
keep_nulls: true,
113
})?;
114
let patterns = patterns.str()?;
115
let replace_with = replace_with.explode(ExplodeOptions {
116
empty_as_null: false,
117
keep_nulls: true,
118
})?;
119
let replace_with = replace_with.str()?;
120
121
let replace_with = if replace_with.len() == 1 && patterns.len() > 1 {
122
replace_with.new_from_index(0, patterns.len())
123
} else {
124
replace_with.clone()
125
};
126
127
polars_ensure!(patterns.len() == replace_with.len(), InvalidOperation: "expected the same amount of patterns as replacement strings");
128
polars_ensure!(patterns.null_count() == 0 && replace_with.null_count() == 0, InvalidOperation: "'patterns'/'replace_with' should not have nulls");
129
let replace_with = replace_with
130
.downcast_iter()
131
.flatten()
132
.flatten()
133
.collect::<Vec<_>>();
134
135
let ac = build_ac(patterns, ascii_case_insensitive, leftmost)?;
136
137
Ok(unary_elementwise(ca, |opt_val| {
138
opt_val.map(|val| ac.replace_all(val, replace_with.as_slice()))
139
}))
140
}
141
142
fn push_str(
143
val: &str,
144
builder: &mut ListStringChunkedBuilder,
145
ac: &AhoCorasick,
146
overlapping: bool,
147
) {
148
if overlapping {
149
let iter = ac.find_overlapping_iter(val);
150
let iter = iter.map(|m| &val[m.start()..m.end()]);
151
builder.append_values_iter(iter);
152
} else {
153
let iter = ac.find_iter(val);
154
let iter = iter.map(|m| &val[m.start()..m.end()]);
155
builder.append_values_iter(iter);
156
}
157
}
158
159
pub fn extract_many(
160
ca: &StringChunked,
161
patterns: &ListChunked,
162
ascii_case_insensitive: bool,
163
overlapping: bool,
164
leftmost: bool,
165
) -> PolarsResult<ListChunked> {
166
// ensure that either overlapping == false, or overlapping == true and leftmost == false
167
polars_ensure!(!overlapping | !leftmost, InvalidOperation: "can not match overlapping patterns when leftmost == True");
168
match (ca.len(), patterns.len()) {
169
(1, _) => match ca.get(0) {
170
None => Ok(ListChunked::full_null_with_dtype(
171
ca.name().clone(),
172
ca.len(),
173
&DataType::String,
174
)),
175
Some(val) => {
176
let mut builder =
177
ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);
178
179
for pat in patterns.amortized_iter() {
180
match pat {
181
None => builder.append_null(),
182
Some(pat) => {
183
let pat = pat.as_ref();
184
let pat = pat.str()?;
185
let pat = pat.rechunk();
186
let pat = pat.downcast_as_array();
187
let ac = build_ac_arr(pat, ascii_case_insensitive, leftmost)?;
188
push_str(val, &mut builder, &ac, overlapping);
189
},
190
}
191
}
192
Ok(builder.finish())
193
},
194
},
195
(_, 1) => {
196
let patterns = patterns.explode(ExplodeOptions {
197
empty_as_null: false,
198
keep_nulls: true,
199
})?;
200
let patterns = patterns.str()?;
201
let ac = build_ac(patterns, ascii_case_insensitive, leftmost)?;
202
let mut builder =
203
ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);
204
205
for arr in ca.downcast_iter() {
206
for opt_val in arr.into_iter() {
207
if let Some(val) = opt_val {
208
push_str(val, &mut builder, &ac, overlapping);
209
} else {
210
builder.append_null();
211
}
212
}
213
}
214
Ok(builder.finish())
215
},
216
(a, b) if a == b => {
217
let mut builder =
218
ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);
219
let (ca, patterns) = align_chunks_binary(ca, patterns);
220
221
for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) {
222
for z in arr.into_iter().zip(pat_arr.into_iter()) {
223
match z {
224
(None, _) | (_, None) => builder.append_null(),
225
(Some(val), Some(pat)) => {
226
let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
227
let ac = build_ac_arr(pat, ascii_case_insensitive, leftmost)?;
228
push_str(val, &mut builder, &ac, overlapping);
229
},
230
}
231
}
232
}
233
Ok(builder.finish())
234
},
235
(a, b) => polars_bail!(length_mismatch = "str.extract_many", a, b),
236
}
237
}
238
239
type B = ListPrimitiveChunkedBuilder<UInt32Type>;
240
fn push_idx(val: &str, builder: &mut B, ac: &AhoCorasick, overlapping: bool) {
241
if overlapping {
242
let iter = ac.find_overlapping_iter(val);
243
let iter = iter.map(|m| m.start() as u32);
244
builder.append_values_iter(iter);
245
} else {
246
let iter = ac.find_iter(val);
247
let iter = iter.map(|m| m.start() as u32);
248
builder.append_values_iter(iter);
249
}
250
}
251
252
pub fn find_many(
253
ca: &StringChunked,
254
patterns: &ListChunked,
255
ascii_case_insensitive: bool,
256
overlapping: bool,
257
leftmost: bool,
258
) -> PolarsResult<ListChunked> {
259
polars_ensure!(!overlapping | !leftmost, InvalidOperation: "can not match overlapping patterns when leftmost == True");
260
type B = ListPrimitiveChunkedBuilder<UInt32Type>;
261
match (ca.len(), patterns.len()) {
262
(1, _) => match ca.get(0) {
263
None => Ok(ListChunked::full_null_with_dtype(
264
ca.name().clone(),
265
patterns.len(),
266
&DataType::UInt32,
267
)),
268
Some(val) => {
269
let mut builder = B::new(
270
ca.name().clone(),
271
patterns.len(),
272
patterns.len() * 2,
273
DataType::UInt32,
274
);
275
for pat in patterns.amortized_iter() {
276
match pat {
277
None => builder.append_null(),
278
Some(pat) => {
279
let pat = pat.as_ref();
280
let pat = pat.str()?;
281
let pat = pat.rechunk();
282
let pat = pat.downcast_as_array();
283
let ac = build_ac_arr(pat, ascii_case_insensitive, leftmost)?;
284
push_idx(val, &mut builder, &ac, overlapping);
285
},
286
}
287
}
288
Ok(builder.finish())
289
},
290
},
291
(_, 1) => {
292
let patterns = patterns.explode(ExplodeOptions {
293
empty_as_null: false,
294
keep_nulls: true,
295
})?;
296
let patterns = patterns.str()?;
297
let ac = build_ac(patterns, ascii_case_insensitive, leftmost)?;
298
let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);
299
300
for opt_val in ca.iter() {
301
if let Some(val) = opt_val {
302
push_idx(val, &mut builder, &ac, overlapping);
303
} else {
304
builder.append_null();
305
}
306
}
307
Ok(builder.finish())
308
},
309
(a, b) if a == b => {
310
let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);
311
let (ca, patterns) = align_chunks_binary(ca, patterns);
312
313
for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) {
314
for z in arr.into_iter().zip(pat_arr.into_iter()) {
315
match z {
316
(None, _) | (_, None) => builder.append_null(),
317
(Some(val), Some(pat)) => {
318
let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
319
let ac = build_ac_arr(pat, ascii_case_insensitive, leftmost)?;
320
push_idx(val, &mut builder, &ac, overlapping);
321
},
322
}
323
}
324
}
325
Ok(builder.finish())
326
},
327
(a, b) => polars_bail!(length_mismatch = "str.find_many", a, b),
328
}
329
}
330
331