Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/strings/find_many.rs
6939 views
1
use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
2
use arrow::array::Utf8ViewArray;
3
use polars_core::prelude::arity::unary_elementwise;
4
use polars_core::prelude::*;
5
use polars_core::utils::align_chunks_binary;
6
7
fn build_ac(patterns: &StringChunked, ascii_case_insensitive: bool) -> PolarsResult<AhoCorasick> {
8
AhoCorasickBuilder::new()
9
.ascii_case_insensitive(ascii_case_insensitive)
10
.build(patterns.downcast_iter().flatten().flatten())
11
.map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e))
12
}
13
14
fn build_ac_arr(
15
patterns: &Utf8ViewArray,
16
ascii_case_insensitive: bool,
17
) -> PolarsResult<AhoCorasick> {
18
AhoCorasickBuilder::new()
19
.ascii_case_insensitive(ascii_case_insensitive)
20
.build(patterns.into_iter().flatten())
21
.map_err(|e| polars_err!(ComputeError: "could not build aho corasick automaton {}", e))
22
}
23
24
pub fn contains_any(
25
ca: &StringChunked,
26
patterns: &ListChunked,
27
ascii_case_insensitive: bool,
28
) -> PolarsResult<BooleanChunked> {
29
polars_ensure!(
30
ca.len() == patterns.len() || ca.len() == 1 || patterns.len() == 1,
31
length_mismatch = "str.contains_any",
32
ca.len(),
33
patterns.len()
34
);
35
polars_ensure!(
36
patterns.len() == 1,
37
nyi = "`str.contains_any` with a pattern per row"
38
);
39
40
if patterns.has_nulls() {
41
return Ok(BooleanChunked::full_null(ca.name().clone(), ca.len()));
42
}
43
44
let patterns = patterns.explode(true)?;
45
let patterns = patterns.str()?;
46
let ac = build_ac(patterns, ascii_case_insensitive)?;
47
48
Ok(unary_elementwise(ca, |opt_val| {
49
opt_val.map(|val| ac.find(val).is_some())
50
}))
51
}
52
53
pub fn replace_all(
54
ca: &StringChunked,
55
patterns: &ListChunked,
56
replace_with: &ListChunked,
57
ascii_case_insensitive: bool,
58
) -> PolarsResult<StringChunked> {
59
let mut length = 1;
60
for (argument_idx, (argument, l)) in [
61
("self", ca.len()),
62
("patterns", patterns.len()),
63
("replace_with", replace_with.len()),
64
]
65
.into_iter()
66
.enumerate()
67
{
68
if l != 1 {
69
if l != length && length != 1 {
70
polars_bail!(
71
length_mismatch = "str.replace_many",
72
l,
73
length,
74
argument = argument,
75
argument_idx = argument_idx
76
);
77
}
78
length = l;
79
}
80
}
81
82
polars_ensure!(
83
patterns.len() == 1 && replace_with.len() == 1,
84
nyi = "`str.replace_many` with a pattern per row"
85
);
86
87
if patterns.has_nulls() || replace_with.has_nulls() {
88
return Ok(StringChunked::full_null(ca.name().clone(), ca.len()));
89
}
90
91
let patterns = patterns.explode(true)?;
92
let patterns = patterns.str()?;
93
let replace_with = replace_with.explode(true)?;
94
let replace_with = replace_with.str()?;
95
96
let replace_with = if replace_with.len() == 1 && patterns.len() > 1 {
97
replace_with.new_from_index(0, patterns.len())
98
} else {
99
replace_with.clone()
100
};
101
102
polars_ensure!(patterns.len() == replace_with.len(), InvalidOperation: "expected the same amount of patterns as replacement strings");
103
polars_ensure!(patterns.null_count() == 0 && replace_with.null_count() == 0, InvalidOperation: "'patterns'/'replace_with' should not have nulls");
104
let replace_with = replace_with
105
.downcast_iter()
106
.flatten()
107
.flatten()
108
.collect::<Vec<_>>();
109
110
let ac = build_ac(patterns, ascii_case_insensitive)?;
111
112
Ok(unary_elementwise(ca, |opt_val| {
113
opt_val.map(|val| ac.replace_all(val, replace_with.as_slice()))
114
}))
115
}
116
117
fn push_str(
118
val: &str,
119
builder: &mut ListStringChunkedBuilder,
120
ac: &AhoCorasick,
121
overlapping: bool,
122
) {
123
if overlapping {
124
let iter = ac.find_overlapping_iter(val);
125
let iter = iter.map(|m| &val[m.start()..m.end()]);
126
builder.append_values_iter(iter);
127
} else {
128
let iter = ac.find_iter(val);
129
let iter = iter.map(|m| &val[m.start()..m.end()]);
130
builder.append_values_iter(iter);
131
}
132
}
133
134
pub fn extract_many(
135
ca: &StringChunked,
136
patterns: &ListChunked,
137
ascii_case_insensitive: bool,
138
overlapping: bool,
139
) -> PolarsResult<ListChunked> {
140
match (ca.len(), patterns.len()) {
141
(1, _) => match ca.get(0) {
142
None => Ok(ListChunked::full_null_with_dtype(
143
ca.name().clone(),
144
ca.len(),
145
&DataType::String,
146
)),
147
Some(val) => {
148
let mut builder =
149
ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);
150
151
for pat in patterns.amortized_iter() {
152
match pat {
153
None => builder.append_null(),
154
Some(pat) => {
155
let pat = pat.as_ref();
156
let pat = pat.str()?;
157
let pat = pat.rechunk();
158
let pat = pat.downcast_as_array();
159
let ac = build_ac_arr(pat, ascii_case_insensitive)?;
160
push_str(val, &mut builder, &ac, overlapping);
161
},
162
}
163
}
164
Ok(builder.finish())
165
},
166
},
167
(_, 1) => {
168
let patterns = patterns.explode(true)?;
169
let patterns = patterns.str()?;
170
let ac = build_ac(patterns, ascii_case_insensitive)?;
171
let mut builder =
172
ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);
173
174
for arr in ca.downcast_iter() {
175
for opt_val in arr.into_iter() {
176
if let Some(val) = opt_val {
177
push_str(val, &mut builder, &ac, overlapping);
178
} else {
179
builder.append_null();
180
}
181
}
182
}
183
Ok(builder.finish())
184
},
185
(a, b) if a == b => {
186
let mut builder =
187
ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2);
188
let (ca, patterns) = align_chunks_binary(ca, patterns);
189
190
for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) {
191
for z in arr.into_iter().zip(pat_arr.into_iter()) {
192
match z {
193
(None, _) | (_, None) => builder.append_null(),
194
(Some(val), Some(pat)) => {
195
let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
196
let ac = build_ac_arr(pat, ascii_case_insensitive)?;
197
push_str(val, &mut builder, &ac, overlapping);
198
},
199
}
200
}
201
}
202
Ok(builder.finish())
203
},
204
(a, b) => polars_bail!(length_mismatch = "str.extract_many", a, b),
205
}
206
}
207
208
type B = ListPrimitiveChunkedBuilder<UInt32Type>;
209
fn push_idx(val: &str, builder: &mut B, ac: &AhoCorasick, overlapping: bool) {
210
if overlapping {
211
let iter = ac.find_overlapping_iter(val);
212
let iter = iter.map(|m| m.start() as u32);
213
builder.append_values_iter(iter);
214
} else {
215
let iter = ac.find_iter(val);
216
let iter = iter.map(|m| m.start() as u32);
217
builder.append_values_iter(iter);
218
}
219
}
220
221
pub fn find_many(
222
ca: &StringChunked,
223
patterns: &ListChunked,
224
ascii_case_insensitive: bool,
225
overlapping: bool,
226
) -> PolarsResult<ListChunked> {
227
type B = ListPrimitiveChunkedBuilder<UInt32Type>;
228
match (ca.len(), patterns.len()) {
229
(1, _) => match ca.get(0) {
230
None => Ok(ListChunked::full_null_with_dtype(
231
ca.name().clone(),
232
patterns.len(),
233
&DataType::UInt32,
234
)),
235
Some(val) => {
236
let mut builder = B::new(
237
ca.name().clone(),
238
patterns.len(),
239
patterns.len() * 2,
240
DataType::UInt32,
241
);
242
for pat in patterns.amortized_iter() {
243
match pat {
244
None => builder.append_null(),
245
Some(pat) => {
246
let pat = pat.as_ref();
247
let pat = pat.str()?;
248
let pat = pat.rechunk();
249
let pat = pat.downcast_as_array();
250
let ac = build_ac_arr(pat, ascii_case_insensitive)?;
251
push_idx(val, &mut builder, &ac, overlapping);
252
},
253
}
254
}
255
Ok(builder.finish())
256
},
257
},
258
(_, 1) => {
259
let patterns = patterns.explode(true)?;
260
let patterns = patterns.str()?;
261
let ac = build_ac(patterns, ascii_case_insensitive)?;
262
let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);
263
264
for opt_val in ca.iter() {
265
if let Some(val) = opt_val {
266
push_idx(val, &mut builder, &ac, overlapping);
267
} else {
268
builder.append_null();
269
}
270
}
271
Ok(builder.finish())
272
},
273
(a, b) if a == b => {
274
let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32);
275
let (ca, patterns) = align_chunks_binary(ca, patterns);
276
277
for (arr, pat_arr) in ca.downcast_iter().zip(patterns.downcast_iter()) {
278
for z in arr.into_iter().zip(pat_arr.into_iter()) {
279
match z {
280
(None, _) | (_, None) => builder.append_null(),
281
(Some(val), Some(pat)) => {
282
let pat = pat.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
283
let ac = build_ac_arr(pat, ascii_case_insensitive)?;
284
push_idx(val, &mut builder, &ac, overlapping);
285
},
286
}
287
}
288
}
289
Ok(builder.finish())
290
},
291
(a, b) => polars_bail!(length_mismatch = "str.find_many", a, b),
292
}
293
}
294
295