Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/function_expr/strings.rs
7889 views
1
#[cfg(feature = "dtype-decimal")]
2
use polars_compute::decimal::DEC128_MAX_PREC;
3
#[cfg(feature = "dtype-struct")]
4
use polars_utils::format_pl_smallstr;
5
6
use super::*;
7
8
// Lazily compiled pattern that recognises a time-zone-aware strptime format:
// it matches when the format contains `%z`, `%:z`, `%::z`, `%:::z` or `%#z`,
// or when the whole format is exactly `%+`.
// `IRStringFunction::get_field` uses it to type a `Datetime` result as UTC
// when no explicit time zone was requested.
#[cfg(all(feature = "regex", feature = "timezones"))]
polars_utils::regex_cache::cached_regex! {
    pub static TZ_AWARE_RE = r"(%z)|(%:z)|(%::z)|(%:::z)|(%#z)|(^%\+$)";
}
12
13
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
14
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
15
pub enum IRStringFunction {
16
Format {
17
format: PlSmallStr,
18
insertions: Arc<[usize]>,
19
},
20
#[cfg(feature = "concat_str")]
21
ConcatHorizontal {
22
delimiter: PlSmallStr,
23
ignore_nulls: bool,
24
},
25
#[cfg(feature = "concat_str")]
26
ConcatVertical {
27
delimiter: PlSmallStr,
28
ignore_nulls: bool,
29
},
30
#[cfg(feature = "regex")]
31
Contains {
32
literal: bool,
33
strict: bool,
34
},
35
CountMatches(bool),
36
EndsWith,
37
Extract(usize),
38
ExtractAll,
39
#[cfg(feature = "extract_groups")]
40
ExtractGroups {
41
dtype: DataType,
42
pat: PlSmallStr,
43
},
44
#[cfg(feature = "regex")]
45
Find {
46
literal: bool,
47
strict: bool,
48
},
49
#[cfg(feature = "string_to_integer")]
50
ToInteger {
51
dtype: Option<DataType>,
52
strict: bool,
53
},
54
LenBytes,
55
LenChars,
56
Lowercase,
57
#[cfg(feature = "extract_jsonpath")]
58
JsonDecode(DataType),
59
#[cfg(feature = "extract_jsonpath")]
60
JsonPathMatch,
61
#[cfg(feature = "regex")]
62
Replace {
63
// negative is replace all
64
// how many matches to replace
65
n: i64,
66
literal: bool,
67
},
68
#[cfg(feature = "string_normalize")]
69
Normalize {
70
form: UnicodeForm,
71
},
72
#[cfg(feature = "string_reverse")]
73
Reverse,
74
#[cfg(feature = "string_pad")]
75
PadStart {
76
fill_char: char,
77
},
78
#[cfg(feature = "string_pad")]
79
PadEnd {
80
fill_char: char,
81
},
82
Slice,
83
Head,
84
Tail,
85
#[cfg(feature = "string_encoding")]
86
HexEncode,
87
#[cfg(feature = "binary_encoding")]
88
HexDecode(bool),
89
#[cfg(feature = "string_encoding")]
90
Base64Encode,
91
#[cfg(feature = "binary_encoding")]
92
Base64Decode(bool),
93
StartsWith,
94
StripChars,
95
StripCharsStart,
96
StripCharsEnd,
97
StripPrefix,
98
StripSuffix,
99
#[cfg(feature = "dtype-struct")]
100
SplitExact {
101
n: usize,
102
inclusive: bool,
103
},
104
#[cfg(feature = "dtype-struct")]
105
SplitN(usize),
106
#[cfg(feature = "temporal")]
107
// DataType can only be Date/Datetime/Time
108
Strptime(DataType, StrptimeOptions),
109
Split(bool),
110
#[cfg(feature = "dtype-decimal")]
111
ToDecimal {
112
scale: usize,
113
},
114
#[cfg(feature = "nightly")]
115
Titlecase,
116
Uppercase,
117
#[cfg(feature = "string_pad")]
118
ZFill,
119
#[cfg(feature = "find_many")]
120
ContainsAny {
121
ascii_case_insensitive: bool,
122
},
123
#[cfg(feature = "find_many")]
124
ReplaceMany {
125
ascii_case_insensitive: bool,
126
leftmost: bool,
127
},
128
#[cfg(feature = "find_many")]
129
ExtractMany {
130
ascii_case_insensitive: bool,
131
overlapping: bool,
132
leftmost: bool,
133
},
134
#[cfg(feature = "find_many")]
135
FindMany {
136
ascii_case_insensitive: bool,
137
overlapping: bool,
138
leftmost: bool,
139
},
140
#[cfg(feature = "regex")]
141
EscapeRegex,
142
}
143
144
impl IRStringFunction {
145
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
146
use IRStringFunction::*;
147
match self {
148
Format { .. } => mapper.with_dtype(DataType::String),
149
#[cfg(feature = "concat_str")]
150
ConcatVertical { .. } | ConcatHorizontal { .. } => mapper.with_dtype(DataType::String),
151
#[cfg(feature = "regex")]
152
Contains { .. } => mapper.with_dtype(DataType::Boolean),
153
CountMatches(_) => mapper.with_dtype(DataType::UInt32),
154
EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
155
Extract(_) => mapper.with_same_dtype(),
156
ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
157
#[cfg(feature = "extract_groups")]
158
ExtractGroups { dtype, .. } => mapper.with_dtype(dtype.clone()),
159
#[cfg(feature = "string_to_integer")]
160
ToInteger { dtype, .. } => mapper.with_dtype(dtype.clone().unwrap_or(DataType::Int64)),
161
#[cfg(feature = "regex")]
162
Find { .. } => mapper.with_dtype(DataType::UInt32),
163
#[cfg(feature = "extract_jsonpath")]
164
JsonDecode(dtype) => mapper.with_dtype(dtype.clone()),
165
#[cfg(feature = "extract_jsonpath")]
166
JsonPathMatch => mapper.with_dtype(DataType::String),
167
LenBytes => mapper.with_dtype(DataType::UInt32),
168
LenChars => mapper.with_dtype(DataType::UInt32),
169
#[cfg(feature = "regex")]
170
Replace { .. } => mapper.with_same_dtype(),
171
#[cfg(feature = "string_normalize")]
172
Normalize { .. } => mapper.with_same_dtype(),
173
#[cfg(feature = "string_reverse")]
174
Reverse => mapper.with_same_dtype(),
175
#[cfg(feature = "temporal")]
176
Strptime(dtype, options) => match dtype {
177
#[cfg(feature = "dtype-datetime")]
178
DataType::Datetime(time_unit, time_zone) => {
179
let mut time_zone = time_zone.clone();
180
#[cfg(all(feature = "regex", feature = "timezones"))]
181
if options
182
.format
183
.as_ref()
184
.is_some_and(|format| TZ_AWARE_RE.is_match(format.as_str()))
185
&& time_zone.is_none()
186
{
187
time_zone = Some(time_zone.unwrap_or(TimeZone::UTC));
188
}
189
mapper.with_dtype(DataType::Datetime(*time_unit, time_zone))
190
},
191
_ => mapper.with_dtype(dtype.clone()),
192
},
193
Split(_) => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
194
#[cfg(feature = "nightly")]
195
Titlecase => mapper.with_same_dtype(),
196
#[cfg(feature = "dtype-decimal")]
197
ToDecimal { scale } => mapper.with_dtype(DataType::Decimal(DEC128_MAX_PREC, *scale)),
198
#[cfg(feature = "string_encoding")]
199
HexEncode => mapper.with_same_dtype(),
200
#[cfg(feature = "binary_encoding")]
201
HexDecode(_) => mapper.with_dtype(DataType::Binary),
202
#[cfg(feature = "string_encoding")]
203
Base64Encode => mapper.with_same_dtype(),
204
#[cfg(feature = "binary_encoding")]
205
Base64Decode(_) => mapper.with_dtype(DataType::Binary),
206
Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
207
| StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(),
208
#[cfg(feature = "string_pad")]
209
PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(),
210
#[cfg(feature = "dtype-struct")]
211
SplitExact { n, .. } => mapper.with_dtype(DataType::Struct(
212
(0..n + 1)
213
.map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String))
214
.collect(),
215
)),
216
#[cfg(feature = "dtype-struct")]
217
SplitN(n) => mapper.with_dtype(DataType::Struct(
218
(0..*n)
219
.map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String))
220
.collect(),
221
)),
222
#[cfg(feature = "find_many")]
223
ContainsAny { .. } => mapper.with_dtype(DataType::Boolean),
224
#[cfg(feature = "find_many")]
225
ReplaceMany { .. } => mapper.with_same_dtype(),
226
#[cfg(feature = "find_many")]
227
ExtractMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
228
#[cfg(feature = "find_many")]
229
FindMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::UInt32))),
230
#[cfg(feature = "regex")]
231
EscapeRegex => mapper.with_same_dtype(),
232
}
233
}
234
235
pub fn function_options(&self) -> FunctionOptions {
236
use IRStringFunction as S;
237
match self {
238
S::Format { .. } => FunctionOptions::elementwise(),
239
#[cfg(feature = "concat_str")]
240
S::ConcatHorizontal { .. } => FunctionOptions::elementwise()
241
.with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION),
242
#[cfg(feature = "concat_str")]
243
S::ConcatVertical { .. } => FunctionOptions::aggregation(),
244
#[cfg(feature = "regex")]
245
S::Contains { .. } => {
246
FunctionOptions::elementwise().with_supertyping(Default::default())
247
},
248
S::CountMatches(_) => FunctionOptions::elementwise(),
249
S::EndsWith | S::StartsWith | S::Extract(_) => {
250
FunctionOptions::elementwise().with_supertyping(Default::default())
251
},
252
S::ExtractAll => FunctionOptions::elementwise(),
253
#[cfg(feature = "extract_groups")]
254
S::ExtractGroups { .. } => FunctionOptions::elementwise(),
255
#[cfg(feature = "string_to_integer")]
256
S::ToInteger { .. } => FunctionOptions::elementwise(),
257
#[cfg(feature = "regex")]
258
S::Find { .. } => FunctionOptions::elementwise().with_supertyping(Default::default()),
259
#[cfg(feature = "extract_jsonpath")]
260
S::JsonDecode { .. } => FunctionOptions::elementwise(),
261
#[cfg(feature = "extract_jsonpath")]
262
S::JsonPathMatch => FunctionOptions::elementwise(),
263
S::LenBytes | S::LenChars => FunctionOptions::elementwise(),
264
#[cfg(feature = "regex")]
265
S::Replace { .. } => {
266
FunctionOptions::elementwise().with_supertyping(Default::default())
267
},
268
#[cfg(feature = "string_normalize")]
269
S::Normalize { .. } => FunctionOptions::elementwise(),
270
#[cfg(feature = "string_reverse")]
271
S::Reverse => FunctionOptions::elementwise(),
272
#[cfg(feature = "temporal")]
273
S::Strptime(_, options) if options.format.is_some() => FunctionOptions::elementwise(),
274
#[cfg(feature = "temporal")]
275
S::Strptime(_, _) => FunctionOptions::elementwise_with_infer(),
276
S::Split(_) => FunctionOptions::elementwise(),
277
#[cfg(feature = "nightly")]
278
S::Titlecase => FunctionOptions::elementwise(),
279
#[cfg(feature = "dtype-decimal")]
280
S::ToDecimal { .. } => FunctionOptions::elementwise(),
281
#[cfg(feature = "string_encoding")]
282
S::HexEncode | S::Base64Encode => FunctionOptions::elementwise(),
283
#[cfg(feature = "binary_encoding")]
284
S::HexDecode(_) | S::Base64Decode(_) => FunctionOptions::elementwise(),
285
S::Uppercase | S::Lowercase => FunctionOptions::elementwise(),
286
S::StripChars
287
| S::StripCharsStart
288
| S::StripCharsEnd
289
| S::StripPrefix
290
| S::StripSuffix
291
| S::Head
292
| S::Tail => FunctionOptions::elementwise(),
293
S::Slice => FunctionOptions::elementwise(),
294
#[cfg(feature = "string_pad")]
295
S::PadStart { .. } | S::PadEnd { .. } | S::ZFill => FunctionOptions::elementwise(),
296
#[cfg(feature = "dtype-struct")]
297
S::SplitExact { .. } => FunctionOptions::elementwise(),
298
#[cfg(feature = "dtype-struct")]
299
S::SplitN(_) => FunctionOptions::elementwise(),
300
#[cfg(feature = "find_many")]
301
S::ContainsAny { .. } => FunctionOptions::elementwise(),
302
#[cfg(feature = "find_many")]
303
S::ReplaceMany { .. } => FunctionOptions::elementwise(),
304
#[cfg(feature = "find_many")]
305
S::ExtractMany { .. } => FunctionOptions::elementwise(),
306
#[cfg(feature = "find_many")]
307
S::FindMany { .. } => FunctionOptions::elementwise(),
308
#[cfg(feature = "regex")]
309
S::EscapeRegex => FunctionOptions::elementwise(),
310
}
311
}
312
}
313
314
impl Display for IRStringFunction {
315
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
316
use IRStringFunction::*;
317
let s = match self {
318
Format { .. } => "format",
319
#[cfg(feature = "regex")]
320
Contains { .. } => "contains",
321
CountMatches(_) => "count_matches",
322
EndsWith => "ends_with",
323
Extract(_) => "extract",
324
#[cfg(feature = "concat_str")]
325
ConcatHorizontal { .. } => "concat_horizontal",
326
#[cfg(feature = "concat_str")]
327
ConcatVertical { .. } => "concat_vertical",
328
ExtractAll => "extract_all",
329
#[cfg(feature = "extract_groups")]
330
ExtractGroups { .. } => "extract_groups",
331
#[cfg(feature = "string_to_integer")]
332
ToInteger { .. } => "to_integer",
333
#[cfg(feature = "regex")]
334
Find { .. } => "find",
335
Head => "head",
336
Tail => "tail",
337
#[cfg(feature = "extract_jsonpath")]
338
JsonDecode(..) => "json_decode",
339
#[cfg(feature = "extract_jsonpath")]
340
JsonPathMatch => "json_path_match",
341
LenBytes => "len_bytes",
342
Lowercase => "to_lowercase",
343
LenChars => "len_chars",
344
#[cfg(feature = "string_pad")]
345
PadEnd { .. } => "pad_end",
346
#[cfg(feature = "string_pad")]
347
PadStart { .. } => "pad_start",
348
#[cfg(feature = "regex")]
349
Replace { .. } => "replace",
350
#[cfg(feature = "string_normalize")]
351
Normalize { .. } => "normalize",
352
#[cfg(feature = "string_reverse")]
353
Reverse => "reverse",
354
#[cfg(feature = "string_encoding")]
355
HexEncode => "hex_encode",
356
#[cfg(feature = "binary_encoding")]
357
HexDecode(_) => "hex_decode",
358
#[cfg(feature = "string_encoding")]
359
Base64Encode => "base64_encode",
360
#[cfg(feature = "binary_encoding")]
361
Base64Decode(_) => "base64_decode",
362
Slice => "slice",
363
StartsWith => "starts_with",
364
StripChars => "strip_chars",
365
StripCharsStart => "strip_chars_start",
366
StripCharsEnd => "strip_chars_end",
367
StripPrefix => "strip_prefix",
368
StripSuffix => "strip_suffix",
369
#[cfg(feature = "dtype-struct")]
370
SplitExact { inclusive, .. } => {
371
if *inclusive {
372
"split_exact_inclusive"
373
} else {
374
"split_exact"
375
}
376
},
377
#[cfg(feature = "dtype-struct")]
378
SplitN(_) => "splitn",
379
#[cfg(feature = "temporal")]
380
Strptime(_, _) => "strptime",
381
Split(inclusive) => {
382
if *inclusive {
383
"split_inclusive"
384
} else {
385
"split"
386
}
387
},
388
#[cfg(feature = "nightly")]
389
Titlecase => "to_titlecase",
390
#[cfg(feature = "dtype-decimal")]
391
ToDecimal { .. } => "to_decimal",
392
Uppercase => "to_uppercase",
393
#[cfg(feature = "string_pad")]
394
ZFill => "zfill",
395
#[cfg(feature = "find_many")]
396
ContainsAny { .. } => "contains_any",
397
#[cfg(feature = "find_many")]
398
ReplaceMany { .. } => "replace_many",
399
#[cfg(feature = "find_many")]
400
ExtractMany { .. } => "extract_many",
401
#[cfg(feature = "find_many")]
402
FindMany { .. } => "extract_many",
403
#[cfg(feature = "regex")]
404
EscapeRegex => "escape_regex",
405
};
406
write!(f, "str.{s}")
407
}
408
}
409
410
impl From<IRStringFunction> for IRFunctionExpr {
    /// Wraps a string function in the `StringExpr` variant of `IRFunctionExpr`.
    fn from(func: IRStringFunction) -> Self {
        Self::StringExpr(func)
    }
}
415
416