Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/strings/extract.rs
6939 views
1
use std::iter::zip;
2
3
#[cfg(feature = "extract_groups")]
4
use arrow::array::{Array, StructArray};
5
use arrow::array::{MutablePlString, Utf8ViewArray};
6
use polars_core::prelude::arity::{try_binary_mut_with_options, try_unary_mut_with_options};
7
use regex::Regex;
8
9
use super::*;
10
11
/// Extracts capture groups `1..=names.len()` of `reg` from every string in `arr`.
///
/// One string column is built per entry in `names`, and the columns are wrapped
/// in a struct array of type `dtype` with the same length as `arr`.
///
/// A row that is null, or that does not match `reg`, gets a null in every
/// column. The struct-level validity is copied from `arr`, so a non-null input
/// without a match yields a valid struct whose fields are all null, while a
/// null input yields a null struct.
#[cfg(feature = "extract_groups")]
fn extract_groups_array(
    arr: &Utf8ViewArray,
    reg: &Regex,
    names: &[&str],
    dtype: ArrowDataType,
) -> PolarsResult<ArrayRef> {
    let n_rows = arr.len();
    let mut columns = names
        .iter()
        .map(|_| MutablePlString::with_capacity(n_rows))
        .collect::<Vec<_>>();

    // Reused across rows so the capture offsets are not reallocated per string.
    let mut spans = reg.capture_locations();
    for row in arr {
        match row {
            Some(text) if reg.captures_read(&mut spans, text).is_some() => {
                // Group 0 is the whole match; the struct fields start at group 1.
                for (offset, column) in columns.iter_mut().enumerate() {
                    let group = spans.get(offset + 1).map(|(lo, hi)| &text[lo..hi]);
                    column.push(group);
                }
            },
            // Null input and "no match" both produce nulls in every field; the
            // two cases are told apart by the struct validity copied below.
            _ => {
                for column in columns.iter_mut() {
                    column.push_null();
                }
            },
        }
    }

    let children = columns
        .into_iter()
        .map(|column| column.freeze().boxed())
        .collect();
    let out = StructArray::new(dtype, n_rows, children, arr.validity().cloned());
    Ok(out.boxed())
}
41
42
/// Extracts all capture groups of `pat` from each string in `ca` into a struct series.
///
/// `dtype` must be a `DataType::Struct`; its field names become the struct
/// field names and its length decides how many groups are read per row.
///
/// When the compiled pattern reports a `captures_len()` of 1 (only the implicit
/// whole-match group), a struct series without fields is returned.
///
/// # Errors
/// Returns an error if `pat` fails to compile, if `dtype` cannot be converted
/// to an arrow dtype, or if building the resulting series fails.
///
/// # Panics
/// Panics if `dtype` is not a struct (an implementation error in the caller).
#[cfg(feature = "extract_groups")]
pub(super) fn extract_groups(
    ca: &StringChunked,
    pat: &str,
    dtype: &DataType,
) -> PolarsResult<Series> {
    let regex = polars_utils::regex_cache::compile_regex(pat)?;
    if regex.captures_len() == 1 {
        // No explicit capture groups: nothing to put in the struct.
        let empty = StructChunked::from_series(ca.name().clone(), ca.len(), [].iter())?;
        return Ok(empty.into_series());
    }

    let arrow_dtype = dtype.try_to_arrow(CompatLevel::newest())?;
    let fields = match dtype {
        DataType::Struct(fields) => fields,
        // Implementation error if it isn't a struct.
        _ => unreachable!(),
    };
    let field_names: Vec<&str> = fields.iter().map(|field| field.name.as_str()).collect();

    // Process chunk by chunk, stopping at the first failure.
    let mut chunks = Vec::new();
    for array in ca.downcast_iter() {
        let chunk = extract_groups_array(array, &regex, &field_names, arrow_dtype.clone())?;
        chunks.push(chunk);
    }

    Series::try_from((ca.name().clone(), chunks))
}
71
72
fn extract_group_reg_lit(
73
arr: &Utf8ViewArray,
74
reg: &Regex,
75
group_index: usize,
76
) -> PolarsResult<Utf8ViewArray> {
77
let mut builder = MutablePlString::with_capacity(arr.len());
78
79
let mut locs = reg.capture_locations();
80
for opt_v in arr {
81
if let Some(s) = opt_v {
82
if reg.captures_read(&mut locs, s).is_some() {
83
builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop]));
84
continue;
85
}
86
}
87
88
// Push null if either the string is null or there was no match.
89
builder.push_null();
90
}
91
92
Ok(builder.into())
93
}
94
95
fn extract_group_array_lit(
96
s: &str,
97
pat: &Utf8ViewArray,
98
group_index: usize,
99
) -> PolarsResult<Utf8ViewArray> {
100
let mut builder = MutablePlString::with_capacity(pat.len());
101
102
for opt_pat in pat {
103
if let Some(pat) = opt_pat {
104
let reg = polars_utils::regex_cache::compile_regex(pat)?;
105
let mut locs = reg.capture_locations();
106
if reg.captures_read(&mut locs, s).is_some() {
107
builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop]));
108
continue;
109
}
110
}
111
112
// Push null if either the pat is null or there was no match.
113
builder.push_null();
114
}
115
116
Ok(builder.into())
117
}
118
119
fn extract_group_binary(
120
arr: &Utf8ViewArray,
121
pat: &Utf8ViewArray,
122
group_index: usize,
123
) -> PolarsResult<Utf8ViewArray> {
124
let mut builder = MutablePlString::with_capacity(arr.len());
125
126
for (opt_s, opt_pat) in zip(arr, pat) {
127
match (opt_s, opt_pat) {
128
(Some(s), Some(pat)) => {
129
let reg = polars_utils::regex_cache::compile_regex(pat)?;
130
let mut locs = reg.capture_locations();
131
if reg.captures_read(&mut locs, s).is_some() {
132
builder.push(locs.get(group_index).map(|(start, stop)| &s[start..stop]));
133
continue;
134
}
135
// Push null if there was no match.
136
builder.push_null()
137
},
138
_ => builder.push_null(),
139
}
140
}
141
142
Ok(builder.into())
143
}
144
145
pub(super) fn extract_group(
146
ca: &StringChunked,
147
pat: &StringChunked,
148
group_index: usize,
149
) -> PolarsResult<StringChunked> {
150
match (ca.len(), pat.len()) {
151
(_, 1) => {
152
if let Some(pat) = pat.get(0) {
153
let reg = polars_utils::regex_cache::compile_regex(pat)?;
154
try_unary_mut_with_options(ca, |arr| extract_group_reg_lit(arr, &reg, group_index))
155
} else {
156
Ok(StringChunked::full_null(ca.name().clone(), ca.len()))
157
}
158
},
159
(1, _) => {
160
if let Some(s) = ca.get(0) {
161
try_unary_mut_with_options(pat, |pat| extract_group_array_lit(s, pat, group_index))
162
} else {
163
Ok(StringChunked::full_null(ca.name().clone(), pat.len()))
164
}
165
},
166
(len_ca, len_pat) if len_ca == len_pat => try_binary_mut_with_options(
167
ca,
168
pat,
169
|ca, pat| extract_group_binary(ca, pat, group_index),
170
ca.name().clone(),
171
),
172
_ => {
173
polars_bail!(ComputeError: "ca(len: {}) and pat(len: {}) should either broadcast or have the same length", ca.len(), pat.len())
174
},
175
}
176
}
177
178