Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs
7889 views
1
use arrow::legacy::error::PolarsResult;
2
use polars_utils::arena::Node;
3
use polars_utils::format_pl_smallstr;
4
use polars_utils::option::OptionTry;
5
6
use super::expr_to_ir::ExprToIRContext;
7
use super::*;
8
use crate::constants::get_literal_name;
9
use crate::dsl::{Expr, FunctionExpr};
10
use crate::plans::conversion::dsl_to_ir::expr_to_ir::{to_expr_ir, to_expr_irs};
11
use crate::plans::{AExpr, IRFunctionExpr};
12
13
pub(super) fn convert_functions(
14
input: Vec<Expr>,
15
function: FunctionExpr,
16
ctx: &mut ExprToIRContext,
17
) -> PolarsResult<(Node, PlSmallStr)> {
18
use {FunctionExpr as F, IRFunctionExpr as I};
19
20
#[cfg(feature = "dtype-struct")]
21
if matches!(
22
function,
23
FunctionExpr::StructExpr(StructFunction::WithFields)
24
) {
25
let mut input = input.into_iter();
26
let struct_input = to_expr_ir(input.next().unwrap(), ctx)?;
27
let dtype = struct_input.to_expr(ctx.arena).to_field(ctx.schema)?.dtype;
28
let DataType::Struct(fields) = &dtype else {
29
polars_bail!(op = "struct.with_fields", dtype);
30
};
31
32
let struct_name = struct_input.output_name().clone();
33
let struct_node = struct_input.node();
34
let struct_schema = Schema::from_iter(fields.iter().cloned());
35
36
let mut e = Vec::with_capacity(input.len());
37
e.push(struct_input);
38
39
let prev = ctx.with_fields.replace((struct_node, struct_schema));
40
for i in input {
41
e.push(to_expr_ir(i, ctx)?);
42
}
43
ctx.with_fields = prev;
44
45
let function = IRFunctionExpr::StructExpr(IRStructFunction::WithFields);
46
let options = function.function_options();
47
let out = ctx.arena.add(AExpr::Function {
48
input: e,
49
function,
50
options,
51
});
52
53
return Ok((out, struct_name));
54
}
55
56
let input_is_empty = input.is_empty();
57
58
// Converts inputs
59
let e = to_expr_irs(input, ctx)?;
60
let mut set_elementwise = false;
61
62
// Return before converting inputs
63
let ir_function = match function {
64
#[cfg(feature = "dtype-array")]
65
F::ArrayExpr(array_function) => {
66
use {ArrayFunction as A, IRArrayFunction as IA};
67
I::ArrayExpr(match array_function {
68
A::Length => IA::Length,
69
A::Min => IA::Min,
70
A::Max => IA::Max,
71
A::Sum => IA::Sum,
72
A::ToList => IA::ToList,
73
A::Unique(stable) => IA::Unique(stable),
74
A::NUnique => IA::NUnique,
75
A::Std(v) => IA::Std(v),
76
A::Var(v) => IA::Var(v),
77
A::Mean => IA::Mean,
78
A::Median => IA::Median,
79
#[cfg(feature = "array_any_all")]
80
A::Any => IA::Any,
81
#[cfg(feature = "array_any_all")]
82
A::All => IA::All,
83
A::Sort(sort_options) => IA::Sort(sort_options),
84
A::Reverse => IA::Reverse,
85
A::ArgMin => IA::ArgMin,
86
A::ArgMax => IA::ArgMax,
87
A::Get(v) => IA::Get(v),
88
A::Join(v) => IA::Join(v),
89
#[cfg(feature = "is_in")]
90
A::Contains { nulls_equal } => IA::Contains { nulls_equal },
91
#[cfg(feature = "array_count")]
92
A::CountMatches => IA::CountMatches,
93
A::Shift => IA::Shift,
94
A::Explode(options) => IA::Explode(options),
95
A::Concat => IA::Concat,
96
A::Slice(offset, length) => IA::Slice(offset, length),
97
#[cfg(feature = "array_to_struct")]
98
A::ToStruct(ng) => IA::ToStruct(ng),
99
})
100
},
101
F::BinaryExpr(binary_function) => {
102
use {BinaryFunction as B, IRBinaryFunction as IB};
103
I::BinaryExpr(match binary_function {
104
B::Contains => IB::Contains,
105
B::StartsWith => IB::StartsWith,
106
B::EndsWith => IB::EndsWith,
107
#[cfg(feature = "binary_encoding")]
108
B::HexDecode(v) => IB::HexDecode(v),
109
#[cfg(feature = "binary_encoding")]
110
B::HexEncode => IB::HexEncode,
111
#[cfg(feature = "binary_encoding")]
112
B::Base64Decode(v) => IB::Base64Decode(v),
113
#[cfg(feature = "binary_encoding")]
114
B::Base64Encode => IB::Base64Encode,
115
B::Size => IB::Size,
116
#[cfg(feature = "binary_encoding")]
117
B::Reinterpret(dtype_expr, v) => {
118
let dtype = dtype_expr.into_datatype(ctx.schema)?;
119
let can_reinterpret_to =
120
|dt: &DataType| dt.is_primitive_numeric() || dt.is_temporal();
121
polars_ensure!(
122
can_reinterpret_to(&dtype) || (
123
dtype.is_array() && dtype.inner_dtype().map(can_reinterpret_to) == Some(true)
124
),
125
InvalidOperation:
126
"cannot reinterpret binary to dtype {:?}. Only numeric or temporal dtype, or Arrays of these, are supported. Hint: To reinterpret to a nested Array, first reinterpret to a linear Array, and then use reshape",
127
dtype
128
);
129
IB::Reinterpret(dtype, v)
130
},
131
B::Slice => IB::Slice,
132
B::Head => IB::Head,
133
B::Tail => IB::Tail,
134
})
135
},
136
#[cfg(feature = "dtype-categorical")]
137
F::Categorical(categorical_function) => {
138
use {CategoricalFunction as C, IRCategoricalFunction as IC};
139
I::Categorical(match categorical_function {
140
C::GetCategories => IC::GetCategories,
141
#[cfg(feature = "strings")]
142
C::LenBytes => IC::LenBytes,
143
#[cfg(feature = "strings")]
144
C::LenChars => IC::LenChars,
145
#[cfg(feature = "strings")]
146
C::StartsWith(v) => IC::StartsWith(v),
147
#[cfg(feature = "strings")]
148
C::EndsWith(v) => IC::EndsWith(v),
149
#[cfg(feature = "strings")]
150
C::Slice(s, e) => IC::Slice(s, e),
151
})
152
},
153
#[cfg(feature = "dtype-extension")]
154
F::Extension(extension_function) => {
155
use {ExtensionFunction as E, IRExtensionFunction as IE};
156
I::Extension(match extension_function {
157
E::To(dtype) => {
158
let concrete_dtype = dtype.into_datatype(ctx.schema)?;
159
polars_ensure!(matches!(concrete_dtype, DataType::Extension(_, _)),
160
InvalidOperation: "ext.to() requires an Extension dtype, got {concrete_dtype:?}"
161
);
162
IE::To(concrete_dtype)
163
},
164
E::Storage => IE::Storage,
165
})
166
},
167
F::ListExpr(list_function) => {
168
use {IRListFunction as IL, ListFunction as L};
169
I::ListExpr(match list_function {
170
L::Concat => IL::Concat,
171
#[cfg(feature = "is_in")]
172
L::Contains { nulls_equal } => IL::Contains { nulls_equal },
173
#[cfg(feature = "list_drop_nulls")]
174
L::DropNulls => IL::DropNulls,
175
#[cfg(feature = "list_sample")]
176
L::Sample {
177
is_fraction,
178
with_replacement,
179
shuffle,
180
seed,
181
} => IL::Sample {
182
is_fraction,
183
with_replacement,
184
shuffle,
185
seed,
186
},
187
L::Slice => IL::Slice,
188
L::Shift => IL::Shift,
189
L::Get(v) => IL::Get(v),
190
#[cfg(feature = "list_gather")]
191
L::Gather(v) => IL::Gather(v),
192
#[cfg(feature = "list_gather")]
193
L::GatherEvery => IL::GatherEvery,
194
#[cfg(feature = "list_count")]
195
L::CountMatches => IL::CountMatches,
196
L::Sum => IL::Sum,
197
L::Length => IL::Length,
198
L::Max => IL::Max,
199
L::Min => IL::Min,
200
L::Mean => IL::Mean,
201
L::Median => IL::Median,
202
L::Std(v) => IL::Std(v),
203
L::Var(v) => IL::Var(v),
204
L::ArgMin => IL::ArgMin,
205
L::ArgMax => IL::ArgMax,
206
#[cfg(feature = "diff")]
207
L::Diff { n, null_behavior } => IL::Diff { n, null_behavior },
208
L::Sort(sort_options) => IL::Sort(sort_options),
209
L::Reverse => IL::Reverse,
210
L::Unique(v) => IL::Unique(v),
211
L::NUnique => IL::NUnique,
212
#[cfg(feature = "list_sets")]
213
L::SetOperation(set_operation) => IL::SetOperation(set_operation),
214
#[cfg(feature = "list_any_all")]
215
L::Any => IL::Any,
216
#[cfg(feature = "list_any_all")]
217
L::All => IL::All,
218
L::Join(v) => IL::Join(v),
219
#[cfg(feature = "dtype-array")]
220
L::ToArray(v) => IL::ToArray(v),
221
#[cfg(feature = "list_to_struct")]
222
L::ToStruct(list_to_struct_args) => IL::ToStruct(list_to_struct_args),
223
})
224
},
225
#[cfg(feature = "strings")]
226
F::StringExpr(string_function) => {
227
use {IRStringFunction as IS, StringFunction as S};
228
I::StringExpr(match string_function {
229
S::Format { format, insertions } => {
230
if input_is_empty {
231
polars_ensure!(
232
insertions.is_empty(),
233
ComputeError: "StringFormat didn't get any inputs, format: \"{}\"",
234
format
235
);
236
237
let out = ctx
238
.arena
239
.add(AExpr::Literal(LiteralValue::Scalar(Scalar::from(format))));
240
241
return Ok((out, get_literal_name()));
242
} else {
243
IS::Format { format, insertions }
244
}
245
},
246
#[cfg(feature = "concat_str")]
247
S::ConcatHorizontal {
248
delimiter,
249
ignore_nulls,
250
} => IS::ConcatHorizontal {
251
delimiter,
252
ignore_nulls,
253
},
254
#[cfg(feature = "concat_str")]
255
S::ConcatVertical {
256
delimiter,
257
ignore_nulls,
258
} => IS::ConcatVertical {
259
delimiter,
260
ignore_nulls,
261
},
262
#[cfg(feature = "regex")]
263
S::Contains { literal, strict } => IS::Contains { literal, strict },
264
S::CountMatches(v) => IS::CountMatches(v),
265
S::EndsWith => IS::EndsWith,
266
S::Extract(v) => IS::Extract(v),
267
S::ExtractAll => IS::ExtractAll,
268
#[cfg(feature = "extract_groups")]
269
S::ExtractGroups { dtype, pat } => IS::ExtractGroups { dtype, pat },
270
#[cfg(feature = "regex")]
271
S::Find { literal, strict } => IS::Find { literal, strict },
272
#[cfg(feature = "string_to_integer")]
273
S::ToInteger { dtype, strict } => IS::ToInteger { dtype, strict },
274
S::LenBytes => IS::LenBytes,
275
S::LenChars => IS::LenChars,
276
S::Lowercase => IS::Lowercase,
277
#[cfg(feature = "extract_jsonpath")]
278
S::JsonDecode(dtype) => IS::JsonDecode(dtype.into_datatype(ctx.schema)?),
279
#[cfg(feature = "extract_jsonpath")]
280
S::JsonPathMatch => IS::JsonPathMatch,
281
#[cfg(feature = "regex")]
282
S::Replace { n, literal } => IS::Replace { n, literal },
283
#[cfg(feature = "string_normalize")]
284
S::Normalize { form } => IS::Normalize { form },
285
#[cfg(feature = "string_reverse")]
286
S::Reverse => IS::Reverse,
287
#[cfg(feature = "string_pad")]
288
S::PadStart { fill_char } => IS::PadStart { fill_char },
289
#[cfg(feature = "string_pad")]
290
S::PadEnd { fill_char } => IS::PadEnd { fill_char },
291
S::Slice => IS::Slice,
292
S::Head => IS::Head,
293
S::Tail => IS::Tail,
294
#[cfg(feature = "string_encoding")]
295
S::HexEncode => IS::HexEncode,
296
#[cfg(feature = "binary_encoding")]
297
S::HexDecode(v) => IS::HexDecode(v),
298
#[cfg(feature = "string_encoding")]
299
S::Base64Encode => IS::Base64Encode,
300
#[cfg(feature = "binary_encoding")]
301
S::Base64Decode(v) => IS::Base64Decode(v),
302
S::StartsWith => IS::StartsWith,
303
S::StripChars => IS::StripChars,
304
S::StripCharsStart => IS::StripCharsStart,
305
S::StripCharsEnd => IS::StripCharsEnd,
306
S::StripPrefix => IS::StripPrefix,
307
S::StripSuffix => IS::StripSuffix,
308
#[cfg(feature = "dtype-struct")]
309
S::SplitExact { n, inclusive } => IS::SplitExact { n, inclusive },
310
#[cfg(feature = "dtype-struct")]
311
S::SplitN(v) => IS::SplitN(v),
312
#[cfg(feature = "temporal")]
313
S::Strptime(data_type, strptime_options) => {
314
let is_column_independent = is_column_independent_aexpr(e[0].node(), ctx.arena);
315
set_elementwise = is_column_independent;
316
let dtype = data_type.into_datatype(ctx.schema)?;
317
polars_ensure!(
318
matches!(dtype,
319
DataType::Date |
320
DataType::Datetime(_, _) |
321
DataType::Time
322
),
323
InvalidOperation: "`strptime` expects a `date`, `datetime` or `time` got {dtype}"
324
);
325
IS::Strptime(dtype, strptime_options)
326
},
327
S::Split(v) => IS::Split(v),
328
#[cfg(feature = "dtype-decimal")]
329
S::ToDecimal { scale } => IS::ToDecimal { scale },
330
#[cfg(feature = "nightly")]
331
S::Titlecase => IS::Titlecase,
332
S::Uppercase => IS::Uppercase,
333
#[cfg(feature = "string_pad")]
334
S::ZFill => IS::ZFill,
335
#[cfg(feature = "find_many")]
336
S::ContainsAny {
337
ascii_case_insensitive,
338
} => IS::ContainsAny {
339
ascii_case_insensitive,
340
},
341
#[cfg(feature = "find_many")]
342
S::ReplaceMany {
343
ascii_case_insensitive,
344
leftmost,
345
} => IS::ReplaceMany {
346
ascii_case_insensitive,
347
leftmost,
348
},
349
#[cfg(feature = "find_many")]
350
S::ExtractMany {
351
ascii_case_insensitive,
352
overlapping,
353
leftmost,
354
} => IS::ExtractMany {
355
ascii_case_insensitive,
356
overlapping,
357
leftmost,
358
},
359
#[cfg(feature = "find_many")]
360
S::FindMany {
361
ascii_case_insensitive,
362
overlapping,
363
leftmost,
364
} => IS::FindMany {
365
ascii_case_insensitive,
366
overlapping,
367
leftmost,
368
},
369
#[cfg(feature = "regex")]
370
S::EscapeRegex => IS::EscapeRegex,
371
})
372
},
373
#[cfg(feature = "dtype-struct")]
374
F::StructExpr(struct_function) => {
375
use {IRStructFunction as IS, StructFunction as S};
376
I::StructExpr(match struct_function {
377
S::FieldByName(pl_small_str) => IS::FieldByName(pl_small_str),
378
S::RenameFields(pl_small_strs) => IS::RenameFields(pl_small_strs),
379
S::PrefixFields(pl_small_str) => IS::PrefixFields(pl_small_str),
380
S::SuffixFields(pl_small_str) => IS::SuffixFields(pl_small_str),
381
S::SelectFields(_) => unreachable!("handled by expression expansion"),
382
#[cfg(feature = "json")]
383
S::JsonEncode => IS::JsonEncode,
384
S::WithFields => unreachable!("handled before"),
385
S::MapFieldNames(f) => IS::MapFieldNames(f),
386
})
387
},
388
#[cfg(feature = "temporal")]
389
F::TemporalExpr(temporal_function) => {
390
use {IRTemporalFunction as IT, TemporalFunction as T};
391
I::TemporalExpr(match temporal_function {
392
T::Millennium => IT::Millennium,
393
T::Century => IT::Century,
394
T::Year => IT::Year,
395
T::IsLeapYear => IT::IsLeapYear,
396
T::IsoYear => IT::IsoYear,
397
T::Quarter => IT::Quarter,
398
T::Month => IT::Month,
399
T::DaysInMonth => IT::DaysInMonth,
400
T::Week => IT::Week,
401
T::WeekDay => IT::WeekDay,
402
T::Day => IT::Day,
403
T::OrdinalDay => IT::OrdinalDay,
404
T::Time => IT::Time,
405
T::Date => IT::Date,
406
T::Datetime => IT::Datetime,
407
#[cfg(feature = "dtype-duration")]
408
T::Duration(time_unit) => IT::Duration(time_unit),
409
T::Hour => IT::Hour,
410
T::Minute => IT::Minute,
411
T::Second => IT::Second,
412
T::Millisecond => IT::Millisecond,
413
T::Microsecond => IT::Microsecond,
414
T::Nanosecond => IT::Nanosecond,
415
#[cfg(feature = "dtype-duration")]
416
T::TotalDays { fractional } => IT::TotalDays { fractional },
417
#[cfg(feature = "dtype-duration")]
418
T::TotalHours { fractional } => IT::TotalHours { fractional },
419
#[cfg(feature = "dtype-duration")]
420
T::TotalMinutes { fractional } => IT::TotalMinutes { fractional },
421
#[cfg(feature = "dtype-duration")]
422
T::TotalSeconds { fractional } => IT::TotalSeconds { fractional },
423
#[cfg(feature = "dtype-duration")]
424
T::TotalMilliseconds { fractional } => IT::TotalMilliseconds { fractional },
425
#[cfg(feature = "dtype-duration")]
426
T::TotalMicroseconds { fractional } => IT::TotalMicroseconds { fractional },
427
#[cfg(feature = "dtype-duration")]
428
T::TotalNanoseconds { fractional } => IT::TotalNanoseconds { fractional },
429
T::ToString(v) => IT::ToString(v),
430
T::CastTimeUnit(time_unit) => IT::CastTimeUnit(time_unit),
431
T::WithTimeUnit(time_unit) => IT::WithTimeUnit(time_unit),
432
#[cfg(feature = "timezones")]
433
T::ConvertTimeZone(time_zone) => IT::ConvertTimeZone(time_zone),
434
T::TimeStamp(time_unit) => IT::TimeStamp(time_unit),
435
T::Truncate => IT::Truncate,
436
#[cfg(feature = "offset_by")]
437
T::OffsetBy => IT::OffsetBy,
438
#[cfg(feature = "month_start")]
439
T::MonthStart => IT::MonthStart,
440
#[cfg(feature = "month_end")]
441
T::MonthEnd => IT::MonthEnd,
442
#[cfg(feature = "timezones")]
443
T::BaseUtcOffset => IT::BaseUtcOffset,
444
#[cfg(feature = "timezones")]
445
T::DSTOffset => IT::DSTOffset,
446
T::Round => IT::Round,
447
T::Replace => IT::Replace,
448
#[cfg(feature = "timezones")]
449
T::ReplaceTimeZone(time_zone, non_existent) => {
450
IT::ReplaceTimeZone(time_zone, non_existent)
451
},
452
T::Combine(time_unit) => IT::Combine(time_unit),
453
T::DatetimeFunction {
454
time_unit,
455
time_zone,
456
} => IT::DatetimeFunction {
457
time_unit,
458
time_zone,
459
},
460
})
461
},
462
#[cfg(feature = "bitwise")]
463
F::Bitwise(bitwise_function) => I::Bitwise(match bitwise_function {
464
BitwiseFunction::CountOnes => IRBitwiseFunction::CountOnes,
465
BitwiseFunction::CountZeros => IRBitwiseFunction::CountZeros,
466
BitwiseFunction::LeadingOnes => IRBitwiseFunction::LeadingOnes,
467
BitwiseFunction::LeadingZeros => IRBitwiseFunction::LeadingZeros,
468
BitwiseFunction::TrailingOnes => IRBitwiseFunction::TrailingOnes,
469
BitwiseFunction::TrailingZeros => IRBitwiseFunction::TrailingZeros,
470
BitwiseFunction::And => IRBitwiseFunction::And,
471
BitwiseFunction::Or => IRBitwiseFunction::Or,
472
BitwiseFunction::Xor => IRBitwiseFunction::Xor,
473
}),
474
F::Boolean(boolean_function) => {
475
use {BooleanFunction as B, IRBooleanFunction as IB};
476
I::Boolean(match boolean_function {
477
B::Any { ignore_nulls } => IB::Any { ignore_nulls },
478
B::All { ignore_nulls } => IB::All { ignore_nulls },
479
B::IsNull => IB::IsNull,
480
B::IsNotNull => IB::IsNotNull,
481
B::IsFinite => IB::IsFinite,
482
B::IsInfinite => IB::IsInfinite,
483
B::IsNan => IB::IsNan,
484
B::IsNotNan => IB::IsNotNan,
485
#[cfg(feature = "is_first_distinct")]
486
B::IsFirstDistinct => IB::IsFirstDistinct,
487
#[cfg(feature = "is_last_distinct")]
488
B::IsLastDistinct => IB::IsLastDistinct,
489
#[cfg(feature = "is_unique")]
490
B::IsUnique => IB::IsUnique,
491
#[cfg(feature = "is_unique")]
492
B::IsDuplicated => IB::IsDuplicated,
493
#[cfg(feature = "is_between")]
494
B::IsBetween { closed } => IB::IsBetween { closed },
495
#[cfg(feature = "is_in")]
496
B::IsIn { nulls_equal } => IB::IsIn { nulls_equal },
497
#[cfg(feature = "is_close")]
498
B::IsClose {
499
abs_tol,
500
rel_tol,
501
nans_equal,
502
} => IB::IsClose {
503
abs_tol,
504
rel_tol,
505
nans_equal,
506
},
507
B::AllHorizontal => {
508
let Some(fst) = e.first() else {
509
return Ok((
510
ctx.arena.add(AExpr::Literal(Scalar::from(true).into())),
511
format_pl_smallstr!("{}", IB::AllHorizontal),
512
));
513
};
514
515
if e.len() == 1 {
516
return Ok((
517
AExprBuilder::new_from_node(fst.node())
518
.cast(DataType::Boolean, ctx.arena)
519
.node(),
520
fst.output_name().clone(),
521
));
522
}
523
524
// Convert to binary expression as the optimizer understands those.
525
// Don't exceed 128 expressions as we might stackoverflow.
526
if e.len() < 128 {
527
let mut r = AExprBuilder::new_from_node(fst.node());
528
for expr in &e[1..] {
529
r = r.logical_and(expr.node(), ctx.arena);
530
}
531
return Ok((r.node(), fst.output_name().clone()));
532
}
533
534
IB::AllHorizontal
535
},
536
B::AnyHorizontal => {
537
// This can be created by col(*).is_null() on empty dataframes.
538
let Some(fst) = e.first() else {
539
return Ok((
540
ctx.arena.add(AExpr::Literal(Scalar::from(false).into())),
541
format_pl_smallstr!("{}", IB::AnyHorizontal),
542
));
543
};
544
545
if e.len() == 1 {
546
return Ok((
547
AExprBuilder::new_from_node(fst.node())
548
.cast(DataType::Boolean, ctx.arena)
549
.node(),
550
fst.output_name().clone(),
551
));
552
}
553
554
// Convert to binary expression as the optimizer understands those.
555
// Don't exceed 128 expressions as we might stackoverflow.
556
if e.len() < 128 {
557
let mut r = AExprBuilder::new_from_node(fst.node());
558
for expr in &e[1..] {
559
r = r.logical_or(expr.node(), ctx.arena);
560
}
561
return Ok((r.node(), fst.output_name().clone()));
562
}
563
564
IB::AnyHorizontal
565
},
566
B::Not => IB::Not,
567
})
568
},
569
#[cfg(feature = "business")]
570
F::Business(business_function) => I::Business(match business_function {
571
BusinessFunction::BusinessDayCount {
572
week_mask,
573
holidays,
574
} => IRBusinessFunction::BusinessDayCount {
575
week_mask,
576
holidays,
577
},
578
BusinessFunction::AddBusinessDay {
579
week_mask,
580
holidays,
581
roll,
582
} => IRBusinessFunction::AddBusinessDay {
583
week_mask,
584
holidays,
585
roll,
586
},
587
BusinessFunction::IsBusinessDay {
588
week_mask,
589
holidays,
590
} => IRBusinessFunction::IsBusinessDay {
591
week_mask,
592
holidays,
593
},
594
}),
595
#[cfg(feature = "abs")]
596
F::Abs => I::Abs,
597
F::Negate => I::Negate,
598
#[cfg(feature = "hist")]
599
F::Hist {
600
bin_count,
601
include_category,
602
include_breakpoint,
603
} => I::Hist {
604
bin_count,
605
include_category,
606
include_breakpoint,
607
},
608
F::NullCount => I::NullCount,
609
F::Pow(pow_function) => I::Pow(match pow_function {
610
PowFunction::Generic => IRPowFunction::Generic,
611
PowFunction::Sqrt => IRPowFunction::Sqrt,
612
PowFunction::Cbrt => IRPowFunction::Cbrt,
613
}),
614
#[cfg(feature = "row_hash")]
615
F::Hash(s0, s1, s2, s3) => I::Hash(s0, s1, s2, s3),
616
#[cfg(feature = "arg_where")]
617
F::ArgWhere => I::ArgWhere,
618
#[cfg(feature = "index_of")]
619
F::IndexOf => I::IndexOf,
620
#[cfg(feature = "search_sorted")]
621
F::SearchSorted { side, descending } => I::SearchSorted { side, descending },
622
#[cfg(feature = "range")]
623
F::Range(range_function) => I::Range(match range_function {
624
RangeFunction::IntRange { step, dtype } => {
625
let dtype = dtype.into_datatype(ctx.schema)?;
626
polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `int_range`");
627
polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar stop passed to `int_range`");
628
polars_ensure!(dtype.is_integer(), SchemaMismatch: "non-integer `dtype` passed to `int_range`: '{dtype}'");
629
IRRangeFunction::IntRange { step, dtype }
630
},
631
RangeFunction::IntRanges { dtype } => {
632
let dtype = dtype.into_datatype(ctx.schema)?;
633
polars_ensure!(dtype.is_integer(), SchemaMismatch: "non-integer `dtype` passed to `int_ranges`: '{dtype}'");
634
IRRangeFunction::IntRanges { dtype }
635
},
636
RangeFunction::LinearSpace { closed } => {
637
polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `linear_space`");
638
polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar end passed to `linear_space`");
639
polars_ensure!(e[2].is_scalar(ctx.arena), ShapeMismatch: "non-scalar num_samples passed to `linear_space`");
640
IRRangeFunction::LinearSpace { closed }
641
},
642
RangeFunction::LinearSpaces {
643
closed,
644
array_width,
645
} => IRRangeFunction::LinearSpaces {
646
closed,
647
array_width,
648
},
649
#[cfg(all(feature = "range", feature = "dtype-date"))]
650
RangeFunction::DateRange {
651
interval,
652
closed,
653
arg_type,
654
} => {
655
use DateRangeArgs::*;
656
let arg_names = match arg_type {
657
StartEndSamples => vec!["start", "end", "num_samples"],
658
StartEndInterval => vec!["start", "end"],
659
StartIntervalSamples => vec!["start", "num_samples"],
660
EndIntervalSamples => vec!["end", "num_samples"],
661
};
662
for (idx, &name) in arg_names.iter().enumerate() {
663
polars_ensure!(e[idx].is_scalar(ctx.arena), ShapeMismatch: "non-scalar {name} passed to `date_range`");
664
}
665
IRRangeFunction::DateRange {
666
interval,
667
closed,
668
arg_type,
669
}
670
},
671
#[cfg(all(feature = "range", feature = "dtype-date"))]
672
RangeFunction::DateRanges {
673
interval,
674
closed,
675
arg_type,
676
} => IRRangeFunction::DateRanges {
677
interval,
678
closed,
679
arg_type,
680
},
681
#[cfg(all(feature = "range", feature = "dtype-datetime"))]
682
RangeFunction::DatetimeRange {
683
interval,
684
closed,
685
time_unit,
686
time_zone,
687
arg_type,
688
} => {
689
use DateRangeArgs::*;
690
let arg_names = match arg_type {
691
StartEndSamples => vec!["start", "end", "num_samples"],
692
StartEndInterval => vec!["start", "end"],
693
StartIntervalSamples => vec!["start", "num_samples"],
694
EndIntervalSamples => vec!["end", "num_samples"],
695
};
696
for (idx, &name) in arg_names.iter().enumerate() {
697
polars_ensure!(e[idx].is_scalar(ctx.arena), ShapeMismatch: "non-scalar {name} passed to `datetime_range`");
698
}
699
IRRangeFunction::DatetimeRange {
700
interval,
701
closed,
702
time_unit,
703
time_zone,
704
arg_type,
705
}
706
},
707
#[cfg(all(feature = "range", feature = "dtype-datetime"))]
708
RangeFunction::DatetimeRanges {
709
interval,
710
closed,
711
time_unit,
712
time_zone,
713
arg_type,
714
} => IRRangeFunction::DatetimeRanges {
715
interval,
716
closed,
717
time_unit,
718
time_zone,
719
arg_type,
720
},
721
#[cfg(all(feature = "range", feature = "dtype-time"))]
722
RangeFunction::TimeRange { interval, closed } => {
723
polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `time_range`");
724
polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar end passed to `time_range`");
725
IRRangeFunction::TimeRange { interval, closed }
726
},
727
#[cfg(all(feature = "range", feature = "dtype-time"))]
728
RangeFunction::TimeRanges { interval, closed } => {
729
IRRangeFunction::TimeRanges { interval, closed }
730
},
731
}),
732
#[cfg(feature = "trigonometry")]
733
F::Trigonometry(trigonometric_function) => {
734
use {IRTrigonometricFunction as IT, TrigonometricFunction as T};
735
I::Trigonometry(match trigonometric_function {
736
T::Cos => IT::Cos,
737
T::Cot => IT::Cot,
738
T::Sin => IT::Sin,
739
T::Tan => IT::Tan,
740
T::ArcCos => IT::ArcCos,
741
T::ArcSin => IT::ArcSin,
742
T::ArcTan => IT::ArcTan,
743
T::Cosh => IT::Cosh,
744
T::Sinh => IT::Sinh,
745
T::Tanh => IT::Tanh,
746
T::ArcCosh => IT::ArcCosh,
747
T::ArcSinh => IT::ArcSinh,
748
T::ArcTanh => IT::ArcTanh,
749
T::Degrees => IT::Degrees,
750
T::Radians => IT::Radians,
751
})
752
},
753
#[cfg(feature = "trigonometry")]
754
F::Atan2 => I::Atan2,
755
#[cfg(feature = "sign")]
756
F::Sign => I::Sign,
757
F::FillNull => I::FillNull,
758
F::FillNullWithStrategy(fill_null_strategy) => I::FillNullWithStrategy(fill_null_strategy),
759
#[cfg(feature = "rolling_window")]
760
F::RollingExpr { function, options } => {
761
use RollingFunction as R;
762
use aexpr::IRRollingFunction as IR;
763
764
I::RollingExpr {
765
function: match function {
766
R::Min => IR::Min,
767
R::Max => IR::Max,
768
R::Mean => IR::Mean,
769
R::Sum => IR::Sum,
770
R::Quantile => IR::Quantile,
771
R::Var => IR::Var,
772
R::Std => IR::Std,
773
R::Rank => IR::Rank,
774
#[cfg(feature = "moment")]
775
R::Skew => IR::Skew,
776
#[cfg(feature = "moment")]
777
R::Kurtosis => IR::Kurtosis,
778
#[cfg(feature = "cov")]
779
R::CorrCov {
780
corr_cov_options,
781
is_corr,
782
} => IR::CorrCov {
783
corr_cov_options,
784
is_corr,
785
},
786
R::Map(f) => IR::Map(f),
787
},
788
options,
789
}
790
},
791
#[cfg(feature = "rolling_window_by")]
792
F::RollingExprBy {
793
function_by,
794
options,
795
} => {
796
use RollingFunctionBy as R;
797
use aexpr::IRRollingFunctionBy as IR;
798
799
I::RollingExprBy {
800
function_by: match function_by {
801
R::MinBy => IR::MinBy,
802
R::MaxBy => IR::MaxBy,
803
R::MeanBy => IR::MeanBy,
804
R::SumBy => IR::SumBy,
805
R::QuantileBy => IR::QuantileBy,
806
R::VarBy => IR::VarBy,
807
R::StdBy => IR::StdBy,
808
R::RankBy => IR::RankBy,
809
},
810
options,
811
}
812
},
813
F::Rechunk => I::Rechunk,
814
F::Append { upcast } => I::Append { upcast },
815
F::ShiftAndFill => {
816
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
817
polars_ensure!(&e[2].is_scalar(ctx.arena), ShapeMismatch: "'fill_value' must be a scalar value");
818
I::ShiftAndFill
819
},
820
F::Shift => {
821
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
822
I::Shift
823
},
824
F::DropNans => I::DropNans,
825
F::DropNulls => I::DropNulls,
826
#[cfg(feature = "mode")]
827
F::Mode { maintain_order } => I::Mode { maintain_order },
828
#[cfg(feature = "moment")]
829
F::Skew(v) => I::Skew(v),
830
#[cfg(feature = "moment")]
831
F::Kurtosis(l, r) => I::Kurtosis(l, r),
832
#[cfg(feature = "dtype-array")]
833
F::Reshape(reshape_dimensions) => I::Reshape(reshape_dimensions),
834
#[cfg(feature = "repeat_by")]
835
F::RepeatBy => I::RepeatBy,
836
F::ArgUnique => I::ArgUnique,
837
F::ArgMin => I::ArgMin,
838
F::ArgMax => I::ArgMax,
839
F::ArgSort {
840
descending,
841
nulls_last,
842
} => I::ArgSort {
843
descending,
844
nulls_last,
845
},
846
F::Product => I::Product,
847
#[cfg(feature = "rank")]
848
F::Rank { options, seed } => I::Rank { options, seed },
849
F::Repeat => {
850
polars_ensure!(&e[0].is_scalar(ctx.arena), ShapeMismatch: "'value' must be a scalar value");
851
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
852
I::Repeat
853
},
854
#[cfg(feature = "round_series")]
855
F::Clip { has_min, has_max } => I::Clip { has_min, has_max },
856
#[cfg(feature = "dtype-struct")]
857
F::AsStruct => I::AsStruct,
858
#[cfg(feature = "top_k")]
859
F::TopK { descending } => I::TopK { descending },
860
#[cfg(feature = "top_k")]
861
F::TopKBy { descending } => I::TopKBy { descending },
862
#[cfg(feature = "cum_agg")]
863
F::CumCount { reverse } => I::CumCount { reverse },
864
#[cfg(feature = "cum_agg")]
865
F::CumSum { reverse } => I::CumSum { reverse },
866
#[cfg(feature = "cum_agg")]
867
F::CumProd { reverse } => I::CumProd { reverse },
868
#[cfg(feature = "cum_agg")]
869
F::CumMin { reverse } => I::CumMin { reverse },
870
#[cfg(feature = "cum_agg")]
871
F::CumMax { reverse } => I::CumMax { reverse },
872
F::Reverse => I::Reverse,
873
#[cfg(feature = "dtype-struct")]
874
F::ValueCounts {
875
sort,
876
parallel,
877
name,
878
normalize,
879
} => I::ValueCounts {
880
sort,
881
parallel,
882
name,
883
normalize,
884
},
885
#[cfg(feature = "unique_counts")]
886
F::UniqueCounts => I::UniqueCounts,
887
#[cfg(feature = "approx_unique")]
888
F::ApproxNUnique => I::ApproxNUnique,
889
F::Coalesce => I::Coalesce,
890
#[cfg(feature = "diff")]
891
F::Diff(n) => {
892
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
893
I::Diff(n)
894
},
895
#[cfg(feature = "pct_change")]
896
F::PctChange => I::PctChange,
897
#[cfg(feature = "interpolate")]
898
F::Interpolate(interpolation_method) => I::Interpolate(interpolation_method),
899
#[cfg(feature = "interpolate_by")]
900
F::InterpolateBy => I::InterpolateBy,
901
#[cfg(feature = "log")]
902
F::Entropy { base, normalize } => I::Entropy { base, normalize },
903
#[cfg(feature = "log")]
904
F::Log => I::Log,
905
#[cfg(feature = "log")]
906
F::Log1p => I::Log1p,
907
#[cfg(feature = "log")]
908
F::Exp => I::Exp,
909
F::Unique(v) => I::Unique(v),
910
#[cfg(feature = "round_series")]
911
F::Round { decimals, mode } => I::Round { decimals, mode },
912
#[cfg(feature = "round_series")]
913
F::RoundSF { digits } => I::RoundSF { digits },
914
#[cfg(feature = "round_series")]
915
F::Floor => I::Floor,
916
#[cfg(feature = "round_series")]
917
F::Ceil => I::Ceil,
918
F::UpperBound => {
919
let field = e[0].field(ctx.schema, ctx.arena)?;
920
return Ok((
921
ctx.arena
922
.add(AExpr::Literal(field.dtype.to_physical().max()?.into())),
923
field.name,
924
));
925
},
926
F::LowerBound => {
927
let field = e[0].field(ctx.schema, ctx.arena)?;
928
return Ok((
929
ctx.arena
930
.add(AExpr::Literal(field.dtype.to_physical().min()?.into())),
931
field.name,
932
));
933
},
934
F::ConcatExpr(v) => I::ConcatExpr(v),
935
#[cfg(feature = "cov")]
936
F::Correlation { method } => {
937
use {CorrelationMethod as C, IRCorrelationMethod as IC};
938
I::Correlation {
939
method: match method {
940
C::Pearson => IC::Pearson,
941
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
942
C::SpearmanRank(v) => IC::SpearmanRank(v),
943
C::Covariance(v) => IC::Covariance(v),
944
},
945
}
946
},
947
#[cfg(feature = "peaks")]
948
F::PeakMin => I::PeakMin,
949
#[cfg(feature = "peaks")]
950
F::PeakMax => I::PeakMax,
951
#[cfg(feature = "cutqcut")]
952
F::Cut {
953
breaks,
954
labels,
955
left_closed,
956
include_breaks,
957
} => I::Cut {
958
breaks,
959
labels,
960
left_closed,
961
include_breaks,
962
},
963
#[cfg(feature = "cutqcut")]
964
F::QCut {
965
probs,
966
labels,
967
left_closed,
968
allow_duplicates,
969
include_breaks,
970
} => I::QCut {
971
probs,
972
labels,
973
left_closed,
974
allow_duplicates,
975
include_breaks,
976
},
977
#[cfg(feature = "rle")]
978
F::RLE => I::RLE,
979
#[cfg(feature = "rle")]
980
F::RLEID => I::RLEID,
981
F::ToPhysical => I::ToPhysical,
982
#[cfg(feature = "random")]
983
F::Random { method, seed } => {
984
use {IRRandomMethod as IR, RandomMethod as R};
985
I::Random {
986
method: match method {
987
R::Shuffle => IR::Shuffle,
988
R::Sample {
989
is_fraction,
990
with_replacement,
991
shuffle,
992
} => IR::Sample {
993
is_fraction,
994
with_replacement,
995
shuffle,
996
},
997
},
998
seed,
999
}
1000
},
1001
F::SetSortedFlag(is_sorted) => I::SetSortedFlag(is_sorted),
1002
#[cfg(feature = "ffi_plugin")]
1003
F::FfiPlugin {
1004
flags,
1005
lib,
1006
symbol,
1007
kwargs,
1008
} => I::FfiPlugin {
1009
flags,
1010
lib,
1011
symbol,
1012
kwargs,
1013
},
1014
1015
F::FoldHorizontal {
1016
callback,
1017
returns_scalar,
1018
return_dtype,
1019
} => I::FoldHorizontal {
1020
callback,
1021
returns_scalar,
1022
return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,
1023
},
1024
F::ReduceHorizontal {
1025
callback,
1026
returns_scalar,
1027
return_dtype,
1028
} => I::ReduceHorizontal {
1029
callback,
1030
returns_scalar,
1031
return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,
1032
},
1033
#[cfg(feature = "dtype-struct")]
1034
F::CumReduceHorizontal {
1035
callback,
1036
returns_scalar,
1037
return_dtype,
1038
} => I::CumReduceHorizontal {
1039
callback,
1040
returns_scalar,
1041
return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,
1042
},
1043
#[cfg(feature = "dtype-struct")]
1044
F::CumFoldHorizontal {
1045
callback,
1046
returns_scalar,
1047
return_dtype,
1048
include_init,
1049
} => I::CumFoldHorizontal {
1050
callback,
1051
returns_scalar,
1052
return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,
1053
include_init,
1054
},
1055
1056
F::MaxHorizontal => I::MaxHorizontal,
1057
F::MinHorizontal => I::MinHorizontal,
1058
F::SumHorizontal { ignore_nulls } => I::SumHorizontal { ignore_nulls },
1059
F::MeanHorizontal { ignore_nulls } => I::MeanHorizontal { ignore_nulls },
1060
#[cfg(feature = "ewma")]
1061
F::EwmMean { options } => I::EwmMean { options },
1062
#[cfg(feature = "ewma_by")]
1063
F::EwmMeanBy { half_life } => I::EwmMeanBy { half_life },
1064
#[cfg(feature = "ewma")]
1065
F::EwmStd { options } => I::EwmStd { options },
1066
#[cfg(feature = "ewma")]
1067
F::EwmVar { options } => I::EwmVar { options },
1068
#[cfg(feature = "replace")]
1069
F::Replace => I::Replace,
1070
#[cfg(feature = "replace")]
1071
F::ReplaceStrict { return_dtype } => I::ReplaceStrict {
1072
return_dtype: match return_dtype {
1073
Some(dtype) => Some(dtype.into_datatype(ctx.schema)?),
1074
None => None,
1075
},
1076
},
1077
F::GatherEvery { n, offset } => I::GatherEvery { n, offset },
1078
#[cfg(feature = "reinterpret")]
1079
F::Reinterpret(v) => I::Reinterpret(v),
1080
F::ExtendConstant => {
1081
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'value' must be a scalar value");
1082
polars_ensure!(&e[2].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
1083
I::ExtendConstant
1084
},
1085
1086
F::RowEncode(v) => {
1087
let dts = e
1088
.iter()
1089
.map(|e| Ok(e.dtype(ctx.schema, ctx.arena)?.clone()))
1090
.collect::<PolarsResult<Vec<_>>>()?;
1091
I::RowEncode(dts, v)
1092
},
1093
#[cfg(feature = "dtype-struct")]
1094
F::RowDecode(fs, v) => I::RowDecode(
1095
fs.into_iter()
1096
.map(|(name, dt_expr)| Ok(Field::new(name, dt_expr.into_datatype(ctx.schema)?)))
1097
.collect::<PolarsResult<Vec<_>>>()?,
1098
v,
1099
),
1100
};
1101
1102
let mut options = ir_function.function_options();
1103
if set_elementwise {
1104
options.set_elementwise();
1105
}
1106
1107
// Handles special case functions like `struct.field`.
1108
let output_name = match ir_function.output_name().and_then(|v| v.into_inner()) {
1109
Some(name) => name,
1110
None if e.is_empty() => format_pl_smallstr!("{}", &ir_function),
1111
None => e[0].output_name().clone(),
1112
};
1113
1114
let ae_function = AExpr::Function {
1115
input: e,
1116
function: ir_function,
1117
options,
1118
};
1119
Ok((ctx.arena.add(ae_function), output_name))
1120
}
1121
1122