Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs
8503 views
1
use arrow::legacy::error::PolarsResult;
2
use polars_utils::arena::Node;
3
use polars_utils::format_pl_smallstr;
4
use polars_utils::option::OptionTry;
5
6
use super::expr_to_ir::ExprToIRContext;
7
use super::*;
8
use crate::constants::get_literal_name;
9
use crate::dsl::{Expr, FunctionExpr};
10
use crate::plans::conversion::dsl_to_ir::expr_to_ir::to_expr_irs;
11
use crate::plans::{AExpr, IRFunctionExpr};
12
13
pub(super) fn convert_functions(
14
input: Vec<Expr>,
15
function: FunctionExpr,
16
ctx: &mut ExprToIRContext,
17
) -> PolarsResult<(Node, PlSmallStr)> {
18
use {FunctionExpr as F, IRFunctionExpr as I};
19
20
// Converts inputs
21
let input_is_empty = input.is_empty();
22
let e = to_expr_irs(input, ctx)?;
23
let mut set_elementwise = false;
24
25
// Return before converting inputs
26
let ir_function = match function {
27
#[cfg(feature = "dtype-array")]
28
F::ArrayExpr(array_function) => {
29
use {ArrayFunction as A, IRArrayFunction as IA};
30
I::ArrayExpr(match array_function {
31
A::Length => IA::Length,
32
A::Min => IA::Min,
33
A::Max => IA::Max,
34
A::Sum => IA::Sum,
35
A::ToList => IA::ToList,
36
A::Unique(stable) => IA::Unique(stable),
37
A::NUnique => IA::NUnique,
38
A::Std(v) => IA::Std(v),
39
A::Var(v) => IA::Var(v),
40
A::Mean => IA::Mean,
41
A::Median => IA::Median,
42
#[cfg(feature = "array_any_all")]
43
A::Any => IA::Any,
44
#[cfg(feature = "array_any_all")]
45
A::All => IA::All,
46
A::Sort(sort_options) => IA::Sort(sort_options),
47
A::Reverse => IA::Reverse,
48
A::ArgMin => IA::ArgMin,
49
A::ArgMax => IA::ArgMax,
50
A::Get(v) => IA::Get(v),
51
A::Join(v) => IA::Join(v),
52
#[cfg(feature = "is_in")]
53
A::Contains { nulls_equal } => IA::Contains { nulls_equal },
54
#[cfg(feature = "array_count")]
55
A::CountMatches => IA::CountMatches,
56
A::Shift => IA::Shift,
57
A::Explode(options) => IA::Explode(options),
58
A::Concat => IA::Concat,
59
A::Slice(offset, length) => IA::Slice(offset, length),
60
#[cfg(feature = "array_to_struct")]
61
A::ToStruct(ng) => IA::ToStruct(ng),
62
})
63
},
64
F::BinaryExpr(binary_function) => {
65
use {BinaryFunction as B, IRBinaryFunction as IB};
66
I::BinaryExpr(match binary_function {
67
B::Contains => IB::Contains,
68
B::StartsWith => IB::StartsWith,
69
B::EndsWith => IB::EndsWith,
70
#[cfg(feature = "binary_encoding")]
71
B::HexDecode(v) => IB::HexDecode(v),
72
#[cfg(feature = "binary_encoding")]
73
B::HexEncode => IB::HexEncode,
74
#[cfg(feature = "binary_encoding")]
75
B::Base64Decode(v) => IB::Base64Decode(v),
76
#[cfg(feature = "binary_encoding")]
77
B::Base64Encode => IB::Base64Encode,
78
B::Size => IB::Size,
79
#[cfg(feature = "binary_encoding")]
80
B::Reinterpret(dtype_expr, v) => {
81
let dtype = dtype_expr.into_datatype(ctx.schema)?;
82
let can_reinterpret_to =
83
|dt: &DataType| dt.is_primitive_numeric() || dt.is_temporal();
84
polars_ensure!(
85
can_reinterpret_to(&dtype) || (
86
dtype.is_array() && dtype.inner_dtype().map(can_reinterpret_to) == Some(true)
87
),
88
InvalidOperation:
89
"cannot reinterpret binary to dtype {:?}. Only numeric or temporal dtype, or Arrays of these, are supported. Hint: To reinterpret to a nested Array, first reinterpret to a linear Array, and then use reshape",
90
dtype
91
);
92
IB::Reinterpret(dtype, v)
93
},
94
B::Slice => IB::Slice,
95
B::Head => IB::Head,
96
B::Tail => IB::Tail,
97
B::Get(null_on_oob) => IB::Get(null_on_oob),
98
})
99
},
100
#[cfg(feature = "dtype-categorical")]
101
F::Categorical(categorical_function) => {
102
use {CategoricalFunction as C, IRCategoricalFunction as IC};
103
I::Categorical(match categorical_function {
104
C::GetCategories => IC::GetCategories,
105
#[cfg(feature = "strings")]
106
C::LenBytes => IC::LenBytes,
107
#[cfg(feature = "strings")]
108
C::LenChars => IC::LenChars,
109
#[cfg(feature = "strings")]
110
C::StartsWith(v) => IC::StartsWith(v),
111
#[cfg(feature = "strings")]
112
C::EndsWith(v) => IC::EndsWith(v),
113
#[cfg(feature = "strings")]
114
C::Slice(s, e) => IC::Slice(s, e),
115
})
116
},
117
#[cfg(feature = "dtype-extension")]
118
F::Extension(extension_function) => {
119
use {ExtensionFunction as E, IRExtensionFunction as IE};
120
I::Extension(match extension_function {
121
E::To(dtype) => {
122
let concrete_dtype = dtype.into_datatype(ctx.schema)?;
123
polars_ensure!(matches!(concrete_dtype, DataType::Extension(_, _)),
124
InvalidOperation: "ext.to() requires an Extension dtype, got {concrete_dtype:?}"
125
);
126
IE::To(concrete_dtype)
127
},
128
E::Storage => IE::Storage,
129
})
130
},
131
F::ListExpr(list_function) => {
132
use {IRListFunction as IL, ListFunction as L};
133
I::ListExpr(match list_function {
134
L::Concat => IL::Concat,
135
#[cfg(feature = "is_in")]
136
L::Contains { nulls_equal } => IL::Contains { nulls_equal },
137
#[cfg(feature = "list_drop_nulls")]
138
L::DropNulls => IL::DropNulls,
139
#[cfg(feature = "list_sample")]
140
L::Sample {
141
is_fraction,
142
with_replacement,
143
shuffle,
144
seed,
145
} => IL::Sample {
146
is_fraction,
147
with_replacement,
148
shuffle,
149
seed,
150
},
151
L::Slice => IL::Slice,
152
L::Shift => IL::Shift,
153
L::Get(v) => IL::Get(v),
154
#[cfg(feature = "list_gather")]
155
L::Gather(v) => IL::Gather(v),
156
#[cfg(feature = "list_gather")]
157
L::GatherEvery => IL::GatherEvery,
158
#[cfg(feature = "list_count")]
159
L::CountMatches => IL::CountMatches,
160
L::Sum => IL::Sum,
161
L::Length => IL::Length,
162
L::Max => IL::Max,
163
L::Min => IL::Min,
164
L::Mean => IL::Mean,
165
L::Median => IL::Median,
166
L::Std(v) => IL::Std(v),
167
L::Var(v) => IL::Var(v),
168
L::ArgMin => IL::ArgMin,
169
L::ArgMax => IL::ArgMax,
170
#[cfg(feature = "diff")]
171
L::Diff { n, null_behavior } => IL::Diff { n, null_behavior },
172
L::Sort(sort_options) => IL::Sort(sort_options),
173
L::Reverse => IL::Reverse,
174
L::Unique(v) => IL::Unique(v),
175
L::NUnique => IL::NUnique,
176
#[cfg(feature = "list_sets")]
177
L::SetOperation(set_operation) => IL::SetOperation(set_operation),
178
#[cfg(feature = "list_any_all")]
179
L::Any => IL::Any,
180
#[cfg(feature = "list_any_all")]
181
L::All => IL::All,
182
L::Join(v) => IL::Join(v),
183
#[cfg(feature = "dtype-array")]
184
L::ToArray(v) => IL::ToArray(v),
185
#[cfg(feature = "list_to_struct")]
186
L::ToStruct(list_to_struct_args) => IL::ToStruct(list_to_struct_args),
187
})
188
},
189
#[cfg(feature = "strings")]
190
F::StringExpr(string_function) => {
191
use {IRStringFunction as IS, StringFunction as S};
192
I::StringExpr(match string_function {
193
S::Format { format, insertions } => {
194
if input_is_empty {
195
polars_ensure!(
196
insertions.is_empty(),
197
ComputeError: "StringFormat didn't get any inputs, format: \"{}\"",
198
format
199
);
200
201
let out = ctx
202
.arena
203
.add(AExpr::Literal(LiteralValue::Scalar(Scalar::from(format))));
204
205
return Ok((out, get_literal_name()));
206
} else {
207
IS::Format { format, insertions }
208
}
209
},
210
#[cfg(feature = "concat_str")]
211
S::ConcatHorizontal {
212
delimiter,
213
ignore_nulls,
214
} => IS::ConcatHorizontal {
215
delimiter,
216
ignore_nulls,
217
},
218
#[cfg(feature = "concat_str")]
219
S::ConcatVertical {
220
delimiter,
221
ignore_nulls,
222
} => IS::ConcatVertical {
223
delimiter,
224
ignore_nulls,
225
},
226
#[cfg(feature = "regex")]
227
S::Contains { literal, strict } => IS::Contains { literal, strict },
228
S::CountMatches(v) => IS::CountMatches(v),
229
S::EndsWith => IS::EndsWith,
230
S::Extract(v) => IS::Extract(v),
231
S::ExtractAll => IS::ExtractAll,
232
#[cfg(feature = "extract_groups")]
233
S::ExtractGroups { dtype, pat } => IS::ExtractGroups { dtype, pat },
234
#[cfg(feature = "regex")]
235
S::Find { literal, strict } => IS::Find { literal, strict },
236
#[cfg(feature = "string_to_integer")]
237
S::ToInteger { dtype, strict } => IS::ToInteger { dtype, strict },
238
S::LenBytes => IS::LenBytes,
239
S::LenChars => IS::LenChars,
240
S::Lowercase => IS::Lowercase,
241
#[cfg(feature = "extract_jsonpath")]
242
S::JsonDecode(dtype) => IS::JsonDecode(dtype.into_datatype(ctx.schema)?),
243
#[cfg(feature = "extract_jsonpath")]
244
S::JsonPathMatch => IS::JsonPathMatch,
245
#[cfg(feature = "regex")]
246
S::Replace { n, literal } => IS::Replace { n, literal },
247
#[cfg(feature = "string_normalize")]
248
S::Normalize { form } => IS::Normalize { form },
249
#[cfg(feature = "string_reverse")]
250
S::Reverse => IS::Reverse,
251
#[cfg(feature = "string_pad")]
252
S::PadStart { fill_char } => IS::PadStart { fill_char },
253
#[cfg(feature = "string_pad")]
254
S::PadEnd { fill_char } => IS::PadEnd { fill_char },
255
S::Slice => IS::Slice,
256
S::Head => IS::Head,
257
S::Tail => IS::Tail,
258
#[cfg(feature = "string_encoding")]
259
S::HexEncode => IS::HexEncode,
260
#[cfg(feature = "binary_encoding")]
261
S::HexDecode(v) => IS::HexDecode(v),
262
#[cfg(feature = "string_encoding")]
263
S::Base64Encode => IS::Base64Encode,
264
#[cfg(feature = "binary_encoding")]
265
S::Base64Decode(v) => IS::Base64Decode(v),
266
S::StartsWith => IS::StartsWith,
267
S::StripChars => IS::StripChars,
268
S::StripCharsStart => IS::StripCharsStart,
269
S::StripCharsEnd => IS::StripCharsEnd,
270
S::StripPrefix => IS::StripPrefix,
271
S::StripSuffix => IS::StripSuffix,
272
#[cfg(feature = "dtype-struct")]
273
S::SplitExact { n, inclusive } => IS::SplitExact { n, inclusive },
274
#[cfg(feature = "dtype-struct")]
275
S::SplitN(v) => IS::SplitN(v),
276
#[cfg(feature = "regex")]
277
S::SplitRegex { inclusive, strict } => IS::SplitRegex { inclusive, strict },
278
#[cfg(feature = "temporal")]
279
S::Strptime(data_type, strptime_options) => {
280
let is_column_independent = is_column_independent_aexpr(e[0].node(), ctx.arena);
281
set_elementwise = is_column_independent;
282
let dtype = data_type.into_datatype(ctx.schema)?;
283
polars_ensure!(
284
matches!(dtype,
285
DataType::Date |
286
DataType::Datetime(_, _) |
287
DataType::Time
288
),
289
InvalidOperation: "`strptime` expects a `date`, `datetime` or `time` got {dtype}"
290
);
291
IS::Strptime(dtype, strptime_options)
292
},
293
S::Split(v) => IS::Split(v),
294
#[cfg(feature = "dtype-decimal")]
295
S::ToDecimal { scale } => IS::ToDecimal { scale },
296
#[cfg(feature = "nightly")]
297
S::Titlecase => IS::Titlecase,
298
S::Uppercase => IS::Uppercase,
299
#[cfg(feature = "string_pad")]
300
S::ZFill => IS::ZFill,
301
#[cfg(feature = "find_many")]
302
S::ContainsAny {
303
ascii_case_insensitive,
304
} => IS::ContainsAny {
305
ascii_case_insensitive,
306
},
307
#[cfg(feature = "find_many")]
308
S::ReplaceMany {
309
ascii_case_insensitive,
310
leftmost,
311
} => IS::ReplaceMany {
312
ascii_case_insensitive,
313
leftmost,
314
},
315
#[cfg(feature = "find_many")]
316
S::ExtractMany {
317
ascii_case_insensitive,
318
overlapping,
319
leftmost,
320
} => IS::ExtractMany {
321
ascii_case_insensitive,
322
overlapping,
323
leftmost,
324
},
325
#[cfg(feature = "find_many")]
326
S::FindMany {
327
ascii_case_insensitive,
328
overlapping,
329
leftmost,
330
} => IS::FindMany {
331
ascii_case_insensitive,
332
overlapping,
333
leftmost,
334
},
335
#[cfg(feature = "regex")]
336
S::EscapeRegex => IS::EscapeRegex,
337
})
338
},
339
#[cfg(feature = "dtype-struct")]
340
F::StructExpr(struct_function) => {
341
use {IRStructFunction as IS, StructFunction as S};
342
I::StructExpr(match struct_function {
343
S::FieldByName(pl_small_str) => IS::FieldByName(pl_small_str),
344
S::RenameFields(pl_small_strs) => IS::RenameFields(pl_small_strs),
345
S::PrefixFields(pl_small_str) => IS::PrefixFields(pl_small_str),
346
S::SuffixFields(pl_small_str) => IS::SuffixFields(pl_small_str),
347
S::SelectFields(_) => unreachable!("handled by expression expansion"),
348
#[cfg(feature = "json")]
349
S::JsonEncode => IS::JsonEncode,
350
S::MapFieldNames(f) => IS::MapFieldNames(f),
351
})
352
},
353
#[cfg(feature = "temporal")]
354
F::TemporalExpr(temporal_function) => {
355
use {IRTemporalFunction as IT, TemporalFunction as T};
356
I::TemporalExpr(match temporal_function {
357
T::Millennium => IT::Millennium,
358
T::Century => IT::Century,
359
T::Year => IT::Year,
360
T::IsLeapYear => IT::IsLeapYear,
361
T::IsoYear => IT::IsoYear,
362
T::Quarter => IT::Quarter,
363
T::Month => IT::Month,
364
T::DaysInMonth => IT::DaysInMonth,
365
T::Week => IT::Week,
366
T::WeekDay => IT::WeekDay,
367
T::Day => IT::Day,
368
T::OrdinalDay => IT::OrdinalDay,
369
T::Time => IT::Time,
370
T::Date => IT::Date,
371
T::Datetime => IT::Datetime,
372
#[cfg(feature = "dtype-duration")]
373
T::Duration(time_unit) => IT::Duration(time_unit),
374
T::Hour => IT::Hour,
375
T::Minute => IT::Minute,
376
T::Second => IT::Second,
377
T::Millisecond => IT::Millisecond,
378
T::Microsecond => IT::Microsecond,
379
T::Nanosecond => IT::Nanosecond,
380
#[cfg(feature = "dtype-duration")]
381
T::TotalDays { fractional } => IT::TotalDays { fractional },
382
#[cfg(feature = "dtype-duration")]
383
T::TotalHours { fractional } => IT::TotalHours { fractional },
384
#[cfg(feature = "dtype-duration")]
385
T::TotalMinutes { fractional } => IT::TotalMinutes { fractional },
386
#[cfg(feature = "dtype-duration")]
387
T::TotalSeconds { fractional } => IT::TotalSeconds { fractional },
388
#[cfg(feature = "dtype-duration")]
389
T::TotalMilliseconds { fractional } => IT::TotalMilliseconds { fractional },
390
#[cfg(feature = "dtype-duration")]
391
T::TotalMicroseconds { fractional } => IT::TotalMicroseconds { fractional },
392
#[cfg(feature = "dtype-duration")]
393
T::TotalNanoseconds { fractional } => IT::TotalNanoseconds { fractional },
394
T::ToString(v) => IT::ToString(v),
395
T::CastTimeUnit(time_unit) => IT::CastTimeUnit(time_unit),
396
T::WithTimeUnit(time_unit) => IT::WithTimeUnit(time_unit),
397
#[cfg(feature = "timezones")]
398
T::ConvertTimeZone(time_zone) => IT::ConvertTimeZone(time_zone),
399
T::TimeStamp(time_unit) => IT::TimeStamp(time_unit),
400
T::Truncate => IT::Truncate,
401
#[cfg(feature = "offset_by")]
402
T::OffsetBy => IT::OffsetBy,
403
#[cfg(feature = "month_start")]
404
T::MonthStart => IT::MonthStart,
405
#[cfg(feature = "month_end")]
406
T::MonthEnd => IT::MonthEnd,
407
#[cfg(feature = "timezones")]
408
T::BaseUtcOffset => IT::BaseUtcOffset,
409
#[cfg(feature = "timezones")]
410
T::DSTOffset => IT::DSTOffset,
411
T::Round => IT::Round,
412
T::Replace => IT::Replace,
413
#[cfg(feature = "timezones")]
414
T::ReplaceTimeZone(time_zone, non_existent) => {
415
IT::ReplaceTimeZone(time_zone, non_existent)
416
},
417
T::Combine(time_unit) => IT::Combine(time_unit),
418
T::DatetimeFunction {
419
time_unit,
420
time_zone,
421
} => IT::DatetimeFunction {
422
time_unit,
423
time_zone,
424
},
425
})
426
},
427
#[cfg(feature = "bitwise")]
428
F::Bitwise(bitwise_function) => I::Bitwise(match bitwise_function {
429
BitwiseFunction::CountOnes => IRBitwiseFunction::CountOnes,
430
BitwiseFunction::CountZeros => IRBitwiseFunction::CountZeros,
431
BitwiseFunction::LeadingOnes => IRBitwiseFunction::LeadingOnes,
432
BitwiseFunction::LeadingZeros => IRBitwiseFunction::LeadingZeros,
433
BitwiseFunction::TrailingOnes => IRBitwiseFunction::TrailingOnes,
434
BitwiseFunction::TrailingZeros => IRBitwiseFunction::TrailingZeros,
435
BitwiseFunction::And => IRBitwiseFunction::And,
436
BitwiseFunction::Or => IRBitwiseFunction::Or,
437
BitwiseFunction::Xor => IRBitwiseFunction::Xor,
438
}),
439
F::Boolean(boolean_function) => {
440
use {BooleanFunction as B, IRBooleanFunction as IB};
441
I::Boolean(match boolean_function {
442
B::Any { ignore_nulls } => IB::Any { ignore_nulls },
443
B::All { ignore_nulls } => IB::All { ignore_nulls },
444
B::IsNull => IB::IsNull,
445
B::IsNotNull => IB::IsNotNull,
446
B::IsFinite => IB::IsFinite,
447
B::IsInfinite => IB::IsInfinite,
448
B::IsNan => IB::IsNan,
449
B::IsNotNan => IB::IsNotNan,
450
#[cfg(feature = "is_first_distinct")]
451
B::IsFirstDistinct => IB::IsFirstDistinct,
452
#[cfg(feature = "is_last_distinct")]
453
B::IsLastDistinct => IB::IsLastDistinct,
454
#[cfg(feature = "is_unique")]
455
B::IsUnique => IB::IsUnique,
456
#[cfg(feature = "is_unique")]
457
B::IsDuplicated => IB::IsDuplicated,
458
#[cfg(feature = "is_between")]
459
B::IsBetween { closed } => IB::IsBetween { closed },
460
#[cfg(feature = "is_in")]
461
B::IsIn { nulls_equal } => IB::IsIn { nulls_equal },
462
#[cfg(feature = "is_close")]
463
B::IsClose {
464
abs_tol,
465
rel_tol,
466
nans_equal,
467
} => IB::IsClose {
468
abs_tol,
469
rel_tol,
470
nans_equal,
471
},
472
B::AllHorizontal => {
473
let Some(fst) = e.first() else {
474
return Ok((
475
ctx.arena.add(AExpr::Literal(Scalar::from(true).into())),
476
format_pl_smallstr!("{}", IB::AllHorizontal),
477
));
478
};
479
480
if e.len() == 1 {
481
return Ok((
482
AExprBuilder::new_from_node(fst.node())
483
.cast(DataType::Boolean, ctx.arena)
484
.node(),
485
fst.output_name().clone(),
486
));
487
}
488
489
// Convert to binary expression as the optimizer understands those.
490
// Don't exceed 128 expressions as we might stackoverflow.
491
if e.len() < 128 {
492
let mut r = AExprBuilder::new_from_node(fst.node());
493
for expr in &e[1..] {
494
r = r.logical_and(expr.node(), ctx.arena);
495
}
496
return Ok((r.node(), fst.output_name().clone()));
497
}
498
499
IB::AllHorizontal
500
},
501
B::AnyHorizontal => {
502
// This can be created by col(*).is_null() on empty dataframes.
503
let Some(fst) = e.first() else {
504
return Ok((
505
ctx.arena.add(AExpr::Literal(Scalar::from(false).into())),
506
format_pl_smallstr!("{}", IB::AnyHorizontal),
507
));
508
};
509
510
if e.len() == 1 {
511
return Ok((
512
AExprBuilder::new_from_node(fst.node())
513
.cast(DataType::Boolean, ctx.arena)
514
.node(),
515
fst.output_name().clone(),
516
));
517
}
518
519
// Convert to binary expression as the optimizer understands those.
520
// Don't exceed 128 expressions as we might stackoverflow.
521
if e.len() < 128 {
522
let mut r = AExprBuilder::new_from_node(fst.node());
523
for expr in &e[1..] {
524
r = r.logical_or(expr.node(), ctx.arena);
525
}
526
return Ok((r.node(), fst.output_name().clone()));
527
}
528
529
IB::AnyHorizontal
530
},
531
B::Not => IB::Not,
532
})
533
},
534
#[cfg(feature = "business")]
535
F::Business(business_function) => I::Business(match business_function {
536
BusinessFunction::BusinessDayCount {
537
week_mask,
538
holidays,
539
} => IRBusinessFunction::BusinessDayCount {
540
week_mask,
541
holidays,
542
},
543
BusinessFunction::AddBusinessDay {
544
week_mask,
545
holidays,
546
roll,
547
} => IRBusinessFunction::AddBusinessDay {
548
week_mask,
549
holidays,
550
roll,
551
},
552
BusinessFunction::IsBusinessDay {
553
week_mask,
554
holidays,
555
} => IRBusinessFunction::IsBusinessDay {
556
week_mask,
557
holidays,
558
},
559
}),
560
#[cfg(feature = "abs")]
561
F::Abs => I::Abs,
562
F::Negate => I::Negate,
563
#[cfg(feature = "hist")]
564
F::Hist {
565
bin_count,
566
include_category,
567
include_breakpoint,
568
} => I::Hist {
569
bin_count,
570
include_category,
571
include_breakpoint,
572
},
573
F::NullCount => I::NullCount,
574
F::Pow(pow_function) => I::Pow(match pow_function {
575
PowFunction::Generic => IRPowFunction::Generic,
576
PowFunction::Sqrt => IRPowFunction::Sqrt,
577
PowFunction::Cbrt => IRPowFunction::Cbrt,
578
}),
579
#[cfg(feature = "row_hash")]
580
F::Hash(s0, s1, s2, s3) => I::Hash(s0, s1, s2, s3),
581
#[cfg(feature = "arg_where")]
582
F::ArgWhere => I::ArgWhere,
583
#[cfg(feature = "index_of")]
584
F::IndexOf => I::IndexOf,
585
#[cfg(feature = "search_sorted")]
586
F::SearchSorted { side, descending } => I::SearchSorted { side, descending },
587
#[cfg(feature = "range")]
588
F::Range(range_function) => I::Range(match range_function {
589
RangeFunction::IntRange { step, dtype } => {
590
let dtype = dtype.into_datatype(ctx.schema)?;
591
polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `int_range`");
592
polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar stop passed to `int_range`");
593
polars_ensure!(dtype.is_integer(), SchemaMismatch: "non-integer `dtype` passed to `int_range`: '{dtype}'");
594
IRRangeFunction::IntRange { step, dtype }
595
},
596
RangeFunction::IntRanges { dtype } => {
597
let dtype = dtype.into_datatype(ctx.schema)?;
598
polars_ensure!(dtype.is_integer(), SchemaMismatch: "non-integer `dtype` passed to `int_ranges`: '{dtype}'");
599
IRRangeFunction::IntRanges { dtype }
600
},
601
RangeFunction::LinearSpace { closed } => {
602
polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `linear_space`");
603
polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar end passed to `linear_space`");
604
polars_ensure!(e[2].is_scalar(ctx.arena), ShapeMismatch: "non-scalar num_samples passed to `linear_space`");
605
IRRangeFunction::LinearSpace { closed }
606
},
607
RangeFunction::LinearSpaces {
608
closed,
609
array_width,
610
} => IRRangeFunction::LinearSpaces {
611
closed,
612
array_width,
613
},
614
#[cfg(all(feature = "range", feature = "dtype-date"))]
615
RangeFunction::DateRange {
616
interval,
617
closed,
618
arg_type,
619
} => {
620
use DateRangeArgs::*;
621
let arg_names = match arg_type {
622
StartEndSamples => vec!["start", "end", "num_samples"],
623
StartEndInterval => vec!["start", "end"],
624
StartIntervalSamples => vec!["start", "num_samples"],
625
EndIntervalSamples => vec!["end", "num_samples"],
626
};
627
for (idx, &name) in arg_names.iter().enumerate() {
628
polars_ensure!(e[idx].is_scalar(ctx.arena), ShapeMismatch: "non-scalar {name} passed to `date_range`");
629
}
630
IRRangeFunction::DateRange {
631
interval,
632
closed,
633
arg_type,
634
}
635
},
636
#[cfg(all(feature = "range", feature = "dtype-date"))]
637
RangeFunction::DateRanges {
638
interval,
639
closed,
640
arg_type,
641
} => IRRangeFunction::DateRanges {
642
interval,
643
closed,
644
arg_type,
645
},
646
#[cfg(all(feature = "range", feature = "dtype-datetime"))]
647
RangeFunction::DatetimeRange {
648
interval,
649
closed,
650
time_unit,
651
time_zone,
652
arg_type,
653
} => {
654
use DateRangeArgs::*;
655
let arg_names = match arg_type {
656
StartEndSamples => vec!["start", "end", "num_samples"],
657
StartEndInterval => vec!["start", "end"],
658
StartIntervalSamples => vec!["start", "num_samples"],
659
EndIntervalSamples => vec!["end", "num_samples"],
660
};
661
for (idx, &name) in arg_names.iter().enumerate() {
662
polars_ensure!(e[idx].is_scalar(ctx.arena), ShapeMismatch: "non-scalar {name} passed to `datetime_range`");
663
}
664
IRRangeFunction::DatetimeRange {
665
interval,
666
closed,
667
time_unit,
668
time_zone,
669
arg_type,
670
}
671
},
672
#[cfg(all(feature = "range", feature = "dtype-datetime"))]
673
RangeFunction::DatetimeRanges {
674
interval,
675
closed,
676
time_unit,
677
time_zone,
678
arg_type,
679
} => IRRangeFunction::DatetimeRanges {
680
interval,
681
closed,
682
time_unit,
683
time_zone,
684
arg_type,
685
},
686
#[cfg(all(feature = "range", feature = "dtype-time"))]
687
RangeFunction::TimeRange { interval, closed } => {
688
polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `time_range`");
689
polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar end passed to `time_range`");
690
IRRangeFunction::TimeRange { interval, closed }
691
},
692
#[cfg(all(feature = "range", feature = "dtype-time"))]
693
RangeFunction::TimeRanges { interval, closed } => {
694
IRRangeFunction::TimeRanges { interval, closed }
695
},
696
}),
697
#[cfg(feature = "trigonometry")]
698
F::Trigonometry(trigonometric_function) => {
699
use {IRTrigonometricFunction as IT, TrigonometricFunction as T};
700
I::Trigonometry(match trigonometric_function {
701
T::Cos => IT::Cos,
702
T::Cot => IT::Cot,
703
T::Sin => IT::Sin,
704
T::Tan => IT::Tan,
705
T::ArcCos => IT::ArcCos,
706
T::ArcSin => IT::ArcSin,
707
T::ArcTan => IT::ArcTan,
708
T::Cosh => IT::Cosh,
709
T::Sinh => IT::Sinh,
710
T::Tanh => IT::Tanh,
711
T::ArcCosh => IT::ArcCosh,
712
T::ArcSinh => IT::ArcSinh,
713
T::ArcTanh => IT::ArcTanh,
714
T::Degrees => IT::Degrees,
715
T::Radians => IT::Radians,
716
})
717
},
718
#[cfg(feature = "trigonometry")]
719
F::Atan2 => I::Atan2,
720
#[cfg(feature = "sign")]
721
F::Sign => I::Sign,
722
F::FillNull => I::FillNull,
723
F::FillNullWithStrategy(fill_null_strategy) => I::FillNullWithStrategy(fill_null_strategy),
724
#[cfg(feature = "rolling_window")]
725
F::RollingExpr { function, options } => {
726
use RollingFunction as R;
727
use aexpr::IRRollingFunction as IR;
728
729
I::RollingExpr {
730
function: match function {
731
R::Min => IR::Min,
732
R::Max => IR::Max,
733
R::Mean => IR::Mean,
734
R::Sum => IR::Sum,
735
R::Quantile => IR::Quantile,
736
R::Var => IR::Var,
737
R::Std => IR::Std,
738
R::Rank => IR::Rank,
739
#[cfg(feature = "moment")]
740
R::Skew => IR::Skew,
741
#[cfg(feature = "moment")]
742
R::Kurtosis => IR::Kurtosis,
743
#[cfg(feature = "cov")]
744
R::CorrCov {
745
corr_cov_options,
746
is_corr,
747
} => IR::CorrCov {
748
corr_cov_options,
749
is_corr,
750
},
751
R::Map(f) => IR::Map(f),
752
},
753
options,
754
}
755
},
756
#[cfg(feature = "rolling_window_by")]
757
F::RollingExprBy {
758
function_by,
759
options,
760
} => {
761
use RollingFunctionBy as R;
762
use aexpr::IRRollingFunctionBy as IR;
763
764
I::RollingExprBy {
765
function_by: match function_by {
766
R::MinBy => IR::MinBy,
767
R::MaxBy => IR::MaxBy,
768
R::MeanBy => IR::MeanBy,
769
R::SumBy => IR::SumBy,
770
R::QuantileBy => IR::QuantileBy,
771
R::VarBy => IR::VarBy,
772
R::StdBy => IR::StdBy,
773
R::RankBy => IR::RankBy,
774
},
775
options,
776
}
777
},
778
F::Rechunk => I::Rechunk,
779
F::Append { upcast } => I::Append { upcast },
780
F::ShiftAndFill => {
781
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
782
polars_ensure!(&e[2].is_scalar(ctx.arena), ShapeMismatch: "'fill_value' must be a scalar value");
783
I::ShiftAndFill
784
},
785
F::Shift => {
786
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
787
I::Shift
788
},
789
F::DropNans => I::DropNans,
790
F::DropNulls => I::DropNulls,
791
#[cfg(feature = "mode")]
792
F::Mode { maintain_order } => I::Mode { maintain_order },
793
#[cfg(feature = "moment")]
794
F::Skew(v) => I::Skew(v),
795
#[cfg(feature = "moment")]
796
F::Kurtosis(l, r) => I::Kurtosis(l, r),
797
#[cfg(feature = "dtype-array")]
798
F::Reshape(reshape_dimensions) => I::Reshape(reshape_dimensions),
799
#[cfg(feature = "repeat_by")]
800
F::RepeatBy => I::RepeatBy,
801
F::ArgUnique => I::ArgUnique,
802
F::ArgMin => I::ArgMin,
803
F::ArgMax => I::ArgMax,
804
F::ArgSort {
805
descending,
806
nulls_last,
807
} => I::ArgSort {
808
descending,
809
nulls_last,
810
},
811
F::MinBy => I::MinBy,
812
F::MaxBy => I::MaxBy,
813
F::Product => I::Product,
814
#[cfg(feature = "rank")]
815
F::Rank { options, seed } => I::Rank { options, seed },
816
F::Repeat => {
817
polars_ensure!(&e[0].is_scalar(ctx.arena), ShapeMismatch: "'value' must be a scalar value");
818
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
819
I::Repeat
820
},
821
#[cfg(feature = "round_series")]
822
F::Clip { has_min, has_max } => I::Clip { has_min, has_max },
823
#[cfg(feature = "dtype-struct")]
824
F::AsStruct => I::AsStruct,
825
#[cfg(feature = "top_k")]
826
F::TopK { descending } => I::TopK { descending },
827
#[cfg(feature = "top_k")]
828
F::TopKBy { descending } => I::TopKBy { descending },
829
#[cfg(feature = "cum_agg")]
830
F::CumCount { reverse } => I::CumCount { reverse },
831
#[cfg(feature = "cum_agg")]
832
F::CumSum { reverse } => I::CumSum { reverse },
833
#[cfg(feature = "cum_agg")]
834
F::CumProd { reverse } => I::CumProd { reverse },
835
#[cfg(feature = "cum_agg")]
836
F::CumMin { reverse } => I::CumMin { reverse },
837
#[cfg(feature = "cum_agg")]
838
F::CumMax { reverse } => I::CumMax { reverse },
839
F::Reverse => I::Reverse,
840
#[cfg(feature = "dtype-struct")]
841
F::ValueCounts {
842
sort,
843
parallel,
844
name,
845
normalize,
846
} => I::ValueCounts {
847
sort,
848
parallel,
849
name,
850
normalize,
851
},
852
#[cfg(feature = "unique_counts")]
853
F::UniqueCounts => I::UniqueCounts,
854
#[cfg(feature = "approx_unique")]
855
F::ApproxNUnique => I::ApproxNUnique,
856
F::Coalesce => I::Coalesce,
857
#[cfg(feature = "diff")]
858
F::Diff(n) => {
859
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
860
I::Diff(n)
861
},
862
#[cfg(feature = "pct_change")]
863
F::PctChange => I::PctChange,
864
#[cfg(feature = "interpolate")]
865
F::Interpolate(interpolation_method) => I::Interpolate(interpolation_method),
866
#[cfg(feature = "interpolate_by")]
867
F::InterpolateBy => I::InterpolateBy,
868
#[cfg(feature = "log")]
869
F::Entropy { base, normalize } => I::Entropy { base, normalize },
870
#[cfg(feature = "log")]
871
F::Log => I::Log,
872
#[cfg(feature = "log")]
873
F::Log1p => I::Log1p,
874
#[cfg(feature = "log")]
875
F::Exp => I::Exp,
876
F::Unique(v) => I::Unique(v),
877
#[cfg(feature = "round_series")]
878
F::Round { decimals, mode } => I::Round { decimals, mode },
879
#[cfg(feature = "round_series")]
880
F::RoundSF { digits } => I::RoundSF { digits },
881
#[cfg(feature = "round_series")]
882
F::Truncate { decimals } => I::Truncate { decimals },
883
#[cfg(feature = "round_series")]
884
F::Floor => I::Floor,
885
#[cfg(feature = "round_series")]
886
F::Ceil => I::Ceil,
887
F::UpperBound => {
888
let field = e[0].field(ctx.schema, ctx.arena)?;
889
return Ok((
890
ctx.arena
891
.add(AExpr::Literal(field.dtype.to_physical().max()?.into())),
892
field.name,
893
));
894
},
895
F::LowerBound => {
896
let field = e[0].field(ctx.schema, ctx.arena)?;
897
return Ok((
898
ctx.arena
899
.add(AExpr::Literal(field.dtype.to_physical().min()?.into())),
900
field.name,
901
));
902
},
903
F::ConcatExpr(v) => I::ConcatExpr(v),
904
#[cfg(feature = "cov")]
905
F::Correlation { method } => {
906
use {CorrelationMethod as C, IRCorrelationMethod as IC};
907
I::Correlation {
908
method: match method {
909
C::Pearson => IC::Pearson,
910
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
911
C::SpearmanRank(v) => IC::SpearmanRank(v),
912
C::Covariance(v) => IC::Covariance(v),
913
},
914
}
915
},
916
#[cfg(feature = "peaks")]
917
F::PeakMin => I::PeakMin,
918
#[cfg(feature = "peaks")]
919
F::PeakMax => I::PeakMax,
920
#[cfg(feature = "cutqcut")]
921
F::Cut {
922
breaks,
923
labels,
924
left_closed,
925
include_breaks,
926
} => I::Cut {
927
breaks,
928
labels,
929
left_closed,
930
include_breaks,
931
},
932
#[cfg(feature = "cutqcut")]
933
F::QCut {
934
probs,
935
labels,
936
left_closed,
937
allow_duplicates,
938
include_breaks,
939
} => I::QCut {
940
probs,
941
labels,
942
left_closed,
943
allow_duplicates,
944
include_breaks,
945
},
946
#[cfg(feature = "rle")]
947
F::RLE => I::RLE,
948
#[cfg(feature = "rle")]
949
F::RLEID => I::RLEID,
950
F::ToPhysical => I::ToPhysical,
951
#[cfg(feature = "random")]
952
F::Random { method, seed } => {
953
use {IRRandomMethod as IR, RandomMethod as R};
954
I::Random {
955
method: match method {
956
R::Shuffle => IR::Shuffle,
957
R::Sample {
958
is_fraction,
959
with_replacement,
960
shuffle,
961
} => IR::Sample {
962
is_fraction,
963
with_replacement,
964
shuffle,
965
},
966
},
967
seed,
968
}
969
},
970
F::SetSortedFlag(is_sorted) => I::SetSortedFlag(is_sorted),
971
#[cfg(feature = "ffi_plugin")]
972
F::FfiPlugin {
973
flags,
974
lib,
975
symbol,
976
kwargs,
977
} => I::FfiPlugin {
978
flags,
979
lib,
980
symbol,
981
kwargs,
982
},
983
984
F::FoldHorizontal {
985
callback,
986
returns_scalar,
987
return_dtype,
988
} => I::FoldHorizontal {
989
callback,
990
returns_scalar,
991
return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,
992
},
993
F::ReduceHorizontal {
994
callback,
995
returns_scalar,
996
return_dtype,
997
} => I::ReduceHorizontal {
998
callback,
999
returns_scalar,
1000
return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,
1001
},
1002
#[cfg(feature = "dtype-struct")]
1003
F::CumReduceHorizontal {
1004
callback,
1005
returns_scalar,
1006
return_dtype,
1007
} => I::CumReduceHorizontal {
1008
callback,
1009
returns_scalar,
1010
return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,
1011
},
1012
#[cfg(feature = "dtype-struct")]
1013
F::CumFoldHorizontal {
1014
callback,
1015
returns_scalar,
1016
return_dtype,
1017
include_init,
1018
} => I::CumFoldHorizontal {
1019
callback,
1020
returns_scalar,
1021
return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,
1022
include_init,
1023
},
1024
1025
F::MaxHorizontal => I::MaxHorizontal,
1026
F::MinHorizontal => I::MinHorizontal,
1027
F::SumHorizontal { ignore_nulls } => I::SumHorizontal { ignore_nulls },
1028
F::MeanHorizontal { ignore_nulls } => I::MeanHorizontal { ignore_nulls },
1029
#[cfg(feature = "ewma")]
1030
F::EwmMean { options } => I::EwmMean { options },
1031
#[cfg(feature = "ewma_by")]
1032
F::EwmMeanBy { half_life } => I::EwmMeanBy { half_life },
1033
#[cfg(feature = "ewma")]
1034
F::EwmStd { options } => I::EwmStd { options },
1035
#[cfg(feature = "ewma")]
1036
F::EwmVar { options } => I::EwmVar { options },
1037
#[cfg(feature = "replace")]
1038
F::Replace => I::Replace,
1039
#[cfg(feature = "replace")]
1040
F::ReplaceStrict { return_dtype } => I::ReplaceStrict {
1041
return_dtype: match return_dtype {
1042
Some(dtype) => Some(dtype.into_datatype(ctx.schema)?),
1043
None => None,
1044
},
1045
},
1046
F::GatherEvery { n, offset } => I::GatherEvery { n, offset },
1047
#[cfg(feature = "reinterpret")]
1048
F::Reinterpret(v) => I::Reinterpret(v),
1049
F::ExtendConstant => {
1050
polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'value' must be a scalar value");
1051
polars_ensure!(&e[2].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");
1052
I::ExtendConstant
1053
},
1054
1055
F::RowEncode(v) => {
1056
let dts = e
1057
.iter()
1058
.map(|e| Ok(e.dtype(ctx.schema, ctx.arena)?.clone()))
1059
.collect::<PolarsResult<Vec<_>>>()?;
1060
I::RowEncode(dts, v)
1061
},
1062
#[cfg(feature = "dtype-struct")]
1063
F::RowDecode(fs, v) => I::RowDecode(
1064
fs.into_iter()
1065
.map(|(name, dt_expr)| Ok(Field::new(name, dt_expr.into_datatype(ctx.schema)?)))
1066
.collect::<PolarsResult<Vec<_>>>()?,
1067
v,
1068
),
1069
};
1070
1071
let mut options = ir_function.function_options();
1072
if set_elementwise {
1073
options.set_elementwise();
1074
}
1075
1076
// Handles special case functions like `struct.field`.
1077
let output_name = match ir_function.output_name().and_then(|v| v.into_inner()) {
1078
Some(name) => name,
1079
None if e.is_empty() => format_pl_smallstr!("{}", &ir_function),
1080
None => e[0].output_name().clone(),
1081
};
1082
1083
let ae_function = AExpr::Function {
1084
input: e,
1085
function: ir_function,
1086
options,
1087
};
1088
Ok((ctx.arena.add(ae_function), output_name))
1089
}
1090
1091