Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/function_expr/mod.rs
8383 views
1
#[cfg(feature = "dtype-array")]
2
mod array;
3
mod binary;
4
#[cfg(feature = "bitwise")]
5
mod bitwise;
6
mod boolean;
7
#[cfg(feature = "business")]
8
mod business;
9
#[cfg(feature = "dtype-categorical")]
10
mod cat;
11
#[cfg(feature = "cov")]
12
mod correlation;
13
#[cfg(feature = "cum_agg")]
14
mod cum;
15
#[cfg(feature = "temporal")]
16
mod datetime;
17
#[cfg(feature = "dtype-extension")]
18
mod extension;
19
#[cfg(feature = "fused")]
20
mod fused;
21
mod list;
22
#[cfg(feature = "ffi_plugin")]
23
pub mod plugin;
24
mod pow;
25
#[cfg(feature = "random")]
26
mod random;
27
#[cfg(feature = "range")]
28
mod range;
29
#[cfg(feature = "rolling_window")]
30
mod rolling;
31
#[cfg(feature = "rolling_window_by")]
32
mod rolling_by;
33
mod row_encode;
34
pub(super) mod schema;
35
#[cfg(feature = "strings")]
36
mod strings;
37
#[cfg(feature = "dtype-struct")]
38
mod struct_;
39
#[cfg(feature = "trigonometry")]
40
mod trigonometry;
41
42
use std::fmt::{Display, Formatter};
43
use std::hash::{Hash, Hasher};
44
45
#[cfg(feature = "dtype-array")]
46
pub use array::IRArrayFunction;
47
#[cfg(feature = "cov")]
48
pub use correlation::IRCorrelationMethod;
49
#[cfg(feature = "fused")]
50
pub use fused::FusedOperator;
51
pub use list::IRListFunction;
52
pub use polars_core::datatypes::ReshapeDimension;
53
use polars_core::prelude::*;
54
use polars_core::series::IsSorted;
55
use polars_core::series::ops::NullBehavior;
56
use polars_core::utils::SuperTypeFlags;
57
#[cfg(feature = "random")]
58
pub use random::IRRandomMethod;
59
use schema::FieldsMapper;
60
61
pub use self::binary::IRBinaryFunction;
62
#[cfg(feature = "bitwise")]
63
pub use self::bitwise::IRBitwiseFunction;
64
pub use self::boolean::IRBooleanFunction;
65
#[cfg(feature = "business")]
66
pub use self::business::IRBusinessFunction;
67
#[cfg(feature = "dtype-categorical")]
68
pub use self::cat::IRCategoricalFunction;
69
#[cfg(feature = "temporal")]
70
pub use self::datetime::IRTemporalFunction;
71
#[cfg(feature = "dtype-extension")]
72
pub use self::extension::IRExtensionFunction;
73
pub use self::pow::IRPowFunction;
74
#[cfg(feature = "range")]
75
pub use self::range::IRRangeFunction;
76
#[cfg(feature = "rolling_window")]
77
pub use self::rolling::IRRollingFunction;
78
#[cfg(feature = "rolling_window_by")]
79
pub use self::rolling_by::IRRollingFunctionBy;
80
pub use self::row_encode::RowEncodingVariant;
81
#[cfg(feature = "strings")]
82
pub use self::strings::IRStringFunction;
83
#[cfg(all(feature = "strings", feature = "regex", feature = "timezones"))]
84
pub use self::strings::TZ_AWARE_RE;
85
#[cfg(feature = "dtype-struct")]
86
pub use self::struct_::IRStructFunction;
87
#[cfg(feature = "trigonometry")]
88
pub use self::trigonometry::IRTrigonometricFunction;
89
use super::*;
90
91
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
92
#[derive(Clone, PartialEq, Debug)]
93
pub enum IRFunctionExpr {
94
// Namespaces
95
#[cfg(feature = "dtype-array")]
96
ArrayExpr(IRArrayFunction),
97
BinaryExpr(IRBinaryFunction),
98
#[cfg(feature = "dtype-categorical")]
99
Categorical(IRCategoricalFunction),
100
#[cfg(feature = "dtype-extension")]
101
Extension(IRExtensionFunction),
102
ListExpr(IRListFunction),
103
#[cfg(feature = "strings")]
104
StringExpr(IRStringFunction),
105
#[cfg(feature = "dtype-struct")]
106
StructExpr(IRStructFunction),
107
#[cfg(feature = "temporal")]
108
TemporalExpr(IRTemporalFunction),
109
#[cfg(feature = "bitwise")]
110
Bitwise(IRBitwiseFunction),
111
112
// Other expressions
113
Boolean(IRBooleanFunction),
114
#[cfg(feature = "business")]
115
Business(IRBusinessFunction),
116
#[cfg(feature = "abs")]
117
Abs,
118
Negate,
119
#[cfg(feature = "hist")]
120
Hist {
121
bin_count: Option<usize>,
122
include_category: bool,
123
include_breakpoint: bool,
124
},
125
NullCount,
126
Pow(IRPowFunction),
127
#[cfg(feature = "row_hash")]
128
Hash(u64, u64, u64, u64),
129
#[cfg(feature = "arg_where")]
130
ArgWhere,
131
#[cfg(feature = "index_of")]
132
IndexOf,
133
#[cfg(feature = "search_sorted")]
134
SearchSorted {
135
side: SearchSortedSide,
136
descending: bool,
137
},
138
#[cfg(feature = "range")]
139
Range(IRRangeFunction),
140
#[cfg(feature = "trigonometry")]
141
Trigonometry(IRTrigonometricFunction),
142
#[cfg(feature = "trigonometry")]
143
Atan2,
144
#[cfg(feature = "sign")]
145
Sign,
146
FillNull,
147
FillNullWithStrategy(FillNullStrategy),
148
#[cfg(feature = "rolling_window")]
149
RollingExpr {
150
function: IRRollingFunction,
151
options: RollingOptionsFixedWindow,
152
},
153
#[cfg(feature = "rolling_window_by")]
154
RollingExprBy {
155
function_by: IRRollingFunctionBy,
156
options: RollingOptionsDynamicWindow,
157
},
158
Rechunk,
159
Append {
160
upcast: bool,
161
},
162
ShiftAndFill,
163
Shift,
164
DropNans,
165
DropNulls,
166
#[cfg(feature = "mode")]
167
Mode {
168
maintain_order: bool,
169
},
170
#[cfg(feature = "moment")]
171
Skew(bool),
172
#[cfg(feature = "moment")]
173
Kurtosis(bool, bool),
174
#[cfg(feature = "dtype-array")]
175
Reshape(Vec<ReshapeDimension>),
176
#[cfg(feature = "repeat_by")]
177
RepeatBy,
178
ArgUnique,
179
ArgMin,
180
ArgMax,
181
ArgSort {
182
descending: bool,
183
nulls_last: bool,
184
},
185
MinBy,
186
MaxBy,
187
Product,
188
#[cfg(feature = "rank")]
189
Rank {
190
options: RankOptions,
191
seed: Option<u64>,
192
},
193
Repeat,
194
#[cfg(feature = "round_series")]
195
Clip {
196
has_min: bool,
197
has_max: bool,
198
},
199
#[cfg(feature = "dtype-struct")]
200
AsStruct,
201
#[cfg(feature = "top_k")]
202
TopK {
203
descending: bool,
204
},
205
#[cfg(feature = "top_k")]
206
TopKBy {
207
descending: Vec<bool>,
208
},
209
#[cfg(feature = "cum_agg")]
210
CumCount {
211
reverse: bool,
212
},
213
#[cfg(feature = "cum_agg")]
214
CumSum {
215
reverse: bool,
216
},
217
#[cfg(feature = "cum_agg")]
218
CumProd {
219
reverse: bool,
220
},
221
#[cfg(feature = "cum_agg")]
222
CumMin {
223
reverse: bool,
224
},
225
#[cfg(feature = "cum_agg")]
226
CumMax {
227
reverse: bool,
228
},
229
Reverse,
230
#[cfg(feature = "dtype-struct")]
231
ValueCounts {
232
sort: bool,
233
parallel: bool,
234
name: PlSmallStr,
235
normalize: bool,
236
},
237
#[cfg(feature = "unique_counts")]
238
UniqueCounts,
239
#[cfg(feature = "approx_unique")]
240
ApproxNUnique,
241
Coalesce,
242
#[cfg(feature = "diff")]
243
Diff(NullBehavior),
244
#[cfg(feature = "pct_change")]
245
PctChange,
246
#[cfg(feature = "interpolate")]
247
Interpolate(InterpolationMethod),
248
#[cfg(feature = "interpolate_by")]
249
InterpolateBy,
250
#[cfg(feature = "log")]
251
Entropy {
252
base: f64,
253
normalize: bool,
254
},
255
#[cfg(feature = "log")]
256
Log,
257
#[cfg(feature = "log")]
258
Log1p,
259
#[cfg(feature = "log")]
260
Exp,
261
Unique(/* maintain_order */ bool),
262
#[cfg(feature = "round_series")]
263
Round {
264
decimals: u32,
265
mode: RoundMode,
266
},
267
#[cfg(feature = "round_series")]
268
RoundSF {
269
digits: i32,
270
},
271
#[cfg(feature = "round_series")]
272
Floor,
273
#[cfg(feature = "round_series")]
274
Ceil,
275
#[cfg(feature = "fused")]
276
Fused(fused::FusedOperator),
277
ConcatExpr(bool),
278
#[cfg(feature = "cov")]
279
Correlation {
280
method: correlation::IRCorrelationMethod,
281
},
282
#[cfg(feature = "peaks")]
283
PeakMin,
284
#[cfg(feature = "peaks")]
285
PeakMax,
286
#[cfg(feature = "cutqcut")]
287
Cut {
288
breaks: Vec<f64>,
289
labels: Option<Vec<PlSmallStr>>,
290
left_closed: bool,
291
include_breaks: bool,
292
},
293
#[cfg(feature = "cutqcut")]
294
QCut {
295
probs: Vec<f64>,
296
labels: Option<Vec<PlSmallStr>>,
297
left_closed: bool,
298
allow_duplicates: bool,
299
include_breaks: bool,
300
},
301
#[cfg(feature = "rle")]
302
RLE,
303
#[cfg(feature = "rle")]
304
RLEID,
305
ToPhysical,
306
#[cfg(feature = "random")]
307
Random {
308
method: IRRandomMethod,
309
seed: Option<u64>,
310
},
311
SetSortedFlag(IsSorted),
312
#[cfg(feature = "ffi_plugin")]
313
/// Creating this node is unsafe
314
/// This will lead to calls over FFI.
315
FfiPlugin {
316
flags: FunctionOptions,
317
/// Shared library.
318
lib: PlSmallStr,
319
/// Identifier in the shared lib.
320
symbol: PlSmallStr,
321
/// Pickle serialized keyword arguments.
322
kwargs: Arc<[u8]>,
323
},
324
325
FoldHorizontal {
326
callback: PlanCallback<(Series, Series), Series>,
327
returns_scalar: bool,
328
return_dtype: Option<DataType>,
329
},
330
ReduceHorizontal {
331
callback: PlanCallback<(Series, Series), Series>,
332
returns_scalar: bool,
333
return_dtype: Option<DataType>,
334
},
335
#[cfg(feature = "dtype-struct")]
336
CumReduceHorizontal {
337
callback: PlanCallback<(Series, Series), Series>,
338
returns_scalar: bool,
339
return_dtype: Option<DataType>,
340
},
341
#[cfg(feature = "dtype-struct")]
342
CumFoldHorizontal {
343
callback: PlanCallback<(Series, Series), Series>,
344
returns_scalar: bool,
345
return_dtype: Option<DataType>,
346
include_init: bool,
347
},
348
349
MaxHorizontal,
350
MinHorizontal,
351
SumHorizontal {
352
ignore_nulls: bool,
353
},
354
MeanHorizontal {
355
ignore_nulls: bool,
356
},
357
#[cfg(feature = "ewma")]
358
EwmMean {
359
options: EWMOptions,
360
},
361
#[cfg(feature = "ewma_by")]
362
EwmMeanBy {
363
half_life: Duration,
364
},
365
#[cfg(feature = "ewma")]
366
EwmStd {
367
options: EWMOptions,
368
},
369
#[cfg(feature = "ewma")]
370
EwmVar {
371
options: EWMOptions,
372
},
373
#[cfg(feature = "replace")]
374
Replace,
375
#[cfg(feature = "replace")]
376
ReplaceStrict {
377
return_dtype: Option<DataType>,
378
},
379
GatherEvery {
380
n: usize,
381
offset: usize,
382
},
383
#[cfg(feature = "reinterpret")]
384
Reinterpret(bool),
385
ExtendConstant,
386
387
RowEncode(Vec<DataType>, RowEncodingVariant),
388
#[cfg(feature = "dtype-struct")]
389
RowDecode(Vec<Field>, RowEncodingVariant),
390
}
391
392
impl Hash for IRFunctionExpr {
393
fn hash<H: Hasher>(&self, state: &mut H) {
394
std::mem::discriminant(self).hash(state);
395
use IRFunctionExpr::*;
396
match self {
397
// Namespaces
398
#[cfg(feature = "dtype-array")]
399
ArrayExpr(f) => f.hash(state),
400
BinaryExpr(f) => f.hash(state),
401
#[cfg(feature = "dtype-categorical")]
402
Categorical(f) => f.hash(state),
403
#[cfg(feature = "dtype-extension")]
404
Extension(f) => f.hash(state),
405
ListExpr(f) => f.hash(state),
406
#[cfg(feature = "strings")]
407
StringExpr(f) => f.hash(state),
408
#[cfg(feature = "dtype-struct")]
409
StructExpr(f) => f.hash(state),
410
#[cfg(feature = "temporal")]
411
TemporalExpr(f) => f.hash(state),
412
#[cfg(feature = "bitwise")]
413
Bitwise(f) => f.hash(state),
414
415
// Other expressions
416
Boolean(f) => f.hash(state),
417
#[cfg(feature = "business")]
418
Business(f) => f.hash(state),
419
Pow(f) => f.hash(state),
420
#[cfg(feature = "index_of")]
421
IndexOf => {},
422
#[cfg(feature = "search_sorted")]
423
SearchSorted { side, descending } => {
424
side.hash(state);
425
descending.hash(state);
426
},
427
#[cfg(feature = "random")]
428
Random { method, .. } => method.hash(state),
429
#[cfg(feature = "cov")]
430
Correlation { method, .. } => method.hash(state),
431
#[cfg(feature = "range")]
432
Range(f) => f.hash(state),
433
#[cfg(feature = "trigonometry")]
434
Trigonometry(f) => f.hash(state),
435
#[cfg(feature = "fused")]
436
Fused(f) => f.hash(state),
437
#[cfg(feature = "diff")]
438
Diff(null_behavior) => null_behavior.hash(state),
439
#[cfg(feature = "interpolate")]
440
Interpolate(f) => f.hash(state),
441
#[cfg(feature = "interpolate_by")]
442
InterpolateBy => {},
443
#[cfg(feature = "ffi_plugin")]
444
FfiPlugin {
445
flags: _,
446
lib,
447
symbol,
448
kwargs,
449
} => {
450
kwargs.hash(state);
451
lib.hash(state);
452
symbol.hash(state);
453
},
454
455
FoldHorizontal {
456
callback,
457
returns_scalar,
458
return_dtype,
459
}
460
| ReduceHorizontal {
461
callback,
462
returns_scalar,
463
return_dtype,
464
} => {
465
callback.hash(state);
466
returns_scalar.hash(state);
467
return_dtype.hash(state);
468
},
469
#[cfg(feature = "dtype-struct")]
470
CumReduceHorizontal {
471
callback,
472
returns_scalar,
473
return_dtype,
474
} => {
475
callback.hash(state);
476
returns_scalar.hash(state);
477
return_dtype.hash(state);
478
},
479
#[cfg(feature = "dtype-struct")]
480
CumFoldHorizontal {
481
callback,
482
returns_scalar,
483
return_dtype,
484
include_init,
485
} => {
486
callback.hash(state);
487
returns_scalar.hash(state);
488
return_dtype.hash(state);
489
include_init.hash(state);
490
},
491
492
SumHorizontal { ignore_nulls } | MeanHorizontal { ignore_nulls } => {
493
ignore_nulls.hash(state)
494
},
495
MaxHorizontal | MinHorizontal | DropNans | DropNulls | Reverse | ArgUnique | ArgMin
496
| ArgMax | Product | Shift | ShiftAndFill | Rechunk | MinBy | MaxBy => {},
497
Append { upcast } => {
498
upcast.hash(state);
499
},
500
ArgSort {
501
descending,
502
nulls_last,
503
} => {
504
descending.hash(state);
505
nulls_last.hash(state);
506
},
507
#[cfg(feature = "mode")]
508
Mode { maintain_order } => {
509
maintain_order.hash(state);
510
},
511
#[cfg(feature = "abs")]
512
Abs => {},
513
Negate => {},
514
NullCount => {},
515
#[cfg(feature = "arg_where")]
516
ArgWhere => {},
517
#[cfg(feature = "trigonometry")]
518
Atan2 => {},
519
#[cfg(feature = "dtype-struct")]
520
AsStruct => {},
521
#[cfg(feature = "sign")]
522
Sign => {},
523
#[cfg(feature = "row_hash")]
524
Hash(a, b, c, d) => (a, b, c, d).hash(state),
525
FillNull => {},
526
#[cfg(feature = "rolling_window")]
527
RollingExpr { function, options } => {
528
function.hash(state);
529
options.hash(state);
530
},
531
#[cfg(feature = "rolling_window_by")]
532
RollingExprBy {
533
function_by,
534
options,
535
} => {
536
function_by.hash(state);
537
options.hash(state);
538
},
539
#[cfg(feature = "moment")]
540
Skew(a) => a.hash(state),
541
#[cfg(feature = "moment")]
542
Kurtosis(a, b) => {
543
a.hash(state);
544
b.hash(state);
545
},
546
Repeat => {},
547
#[cfg(feature = "rank")]
548
Rank { options, seed } => {
549
options.hash(state);
550
seed.hash(state);
551
},
552
#[cfg(feature = "round_series")]
553
Clip { has_min, has_max } => {
554
has_min.hash(state);
555
has_max.hash(state);
556
},
557
#[cfg(feature = "top_k")]
558
TopK { descending } => descending.hash(state),
559
#[cfg(feature = "cum_agg")]
560
CumCount { reverse } => reverse.hash(state),
561
#[cfg(feature = "cum_agg")]
562
CumSum { reverse } => reverse.hash(state),
563
#[cfg(feature = "cum_agg")]
564
CumProd { reverse } => reverse.hash(state),
565
#[cfg(feature = "cum_agg")]
566
CumMin { reverse } => reverse.hash(state),
567
#[cfg(feature = "cum_agg")]
568
CumMax { reverse } => reverse.hash(state),
569
#[cfg(feature = "dtype-struct")]
570
ValueCounts {
571
sort,
572
parallel,
573
name,
574
normalize,
575
} => {
576
sort.hash(state);
577
parallel.hash(state);
578
name.hash(state);
579
normalize.hash(state);
580
},
581
#[cfg(feature = "unique_counts")]
582
UniqueCounts => {},
583
#[cfg(feature = "approx_unique")]
584
ApproxNUnique => {},
585
Coalesce => {},
586
#[cfg(feature = "pct_change")]
587
PctChange => {},
588
#[cfg(feature = "log")]
589
Entropy { base, normalize } => {
590
base.to_bits().hash(state);
591
normalize.hash(state);
592
},
593
#[cfg(feature = "log")]
594
Log => {},
595
#[cfg(feature = "log")]
596
Log1p => {},
597
#[cfg(feature = "log")]
598
Exp => {},
599
Unique(a) => a.hash(state),
600
#[cfg(feature = "round_series")]
601
Round { decimals, mode } => {
602
decimals.hash(state);
603
mode.hash(state);
604
},
605
#[cfg(feature = "round_series")]
606
IRFunctionExpr::RoundSF { digits } => digits.hash(state),
607
#[cfg(feature = "round_series")]
608
IRFunctionExpr::Floor => {},
609
#[cfg(feature = "round_series")]
610
Ceil => {},
611
ConcatExpr(a) => a.hash(state),
612
#[cfg(feature = "peaks")]
613
PeakMin => {},
614
#[cfg(feature = "peaks")]
615
PeakMax => {},
616
#[cfg(feature = "cutqcut")]
617
Cut {
618
breaks,
619
labels,
620
left_closed,
621
include_breaks,
622
} => {
623
let slice = bytemuck::cast_slice::<_, u64>(breaks);
624
slice.hash(state);
625
labels.hash(state);
626
left_closed.hash(state);
627
include_breaks.hash(state);
628
},
629
#[cfg(feature = "dtype-array")]
630
Reshape(dims) => dims.hash(state),
631
#[cfg(feature = "repeat_by")]
632
RepeatBy => {},
633
#[cfg(feature = "cutqcut")]
634
QCut {
635
probs,
636
labels,
637
left_closed,
638
allow_duplicates,
639
include_breaks,
640
} => {
641
let slice = bytemuck::cast_slice::<_, u64>(probs);
642
slice.hash(state);
643
labels.hash(state);
644
left_closed.hash(state);
645
allow_duplicates.hash(state);
646
include_breaks.hash(state);
647
},
648
#[cfg(feature = "rle")]
649
RLE => {},
650
#[cfg(feature = "rle")]
651
RLEID => {},
652
ToPhysical => {},
653
SetSortedFlag(is_sorted) => is_sorted.hash(state),
654
#[cfg(feature = "ewma")]
655
EwmMean { options } => options.hash(state),
656
#[cfg(feature = "ewma_by")]
657
EwmMeanBy { half_life } => (half_life).hash(state),
658
#[cfg(feature = "ewma")]
659
EwmStd { options } => options.hash(state),
660
#[cfg(feature = "ewma")]
661
EwmVar { options } => options.hash(state),
662
#[cfg(feature = "hist")]
663
Hist {
664
bin_count,
665
include_category,
666
include_breakpoint,
667
} => {
668
bin_count.hash(state);
669
include_category.hash(state);
670
include_breakpoint.hash(state);
671
},
672
#[cfg(feature = "replace")]
673
Replace => {},
674
#[cfg(feature = "replace")]
675
ReplaceStrict { return_dtype } => return_dtype.hash(state),
676
FillNullWithStrategy(strategy) => strategy.hash(state),
677
GatherEvery { n, offset } => (n, offset).hash(state),
678
#[cfg(feature = "reinterpret")]
679
Reinterpret(signed) => signed.hash(state),
680
ExtendConstant => {},
681
#[cfg(feature = "top_k")]
682
TopKBy { descending } => descending.hash(state),
683
684
RowEncode(dts, variants) => {
685
dts.hash(state);
686
variants.hash(state);
687
},
688
#[cfg(feature = "dtype-struct")]
689
RowDecode(fs, variants) => {
690
fs.hash(state);
691
variants.hash(state);
692
},
693
}
694
}
695
}
696
697
impl Display for IRFunctionExpr {
698
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
699
use IRFunctionExpr::*;
700
let s = match self {
701
// Namespaces
702
#[cfg(feature = "dtype-array")]
703
ArrayExpr(func) => return write!(f, "{func}"),
704
BinaryExpr(func) => return write!(f, "{func}"),
705
#[cfg(feature = "dtype-categorical")]
706
Categorical(func) => return write!(f, "{func}"),
707
#[cfg(feature = "dtype-extension")]
708
Extension(func) => return write!(f, "{func}"),
709
ListExpr(func) => return write!(f, "{func}"),
710
#[cfg(feature = "strings")]
711
StringExpr(func) => return write!(f, "{func}"),
712
#[cfg(feature = "dtype-struct")]
713
StructExpr(func) => return write!(f, "{func}"),
714
#[cfg(feature = "temporal")]
715
TemporalExpr(func) => return write!(f, "{func}"),
716
#[cfg(feature = "bitwise")]
717
Bitwise(func) => return write!(f, "bitwise_{func}"),
718
719
// Other expressions
720
Boolean(func) => return write!(f, "{func}"),
721
#[cfg(feature = "business")]
722
Business(func) => return write!(f, "{func}"),
723
#[cfg(feature = "abs")]
724
Abs => "abs",
725
Negate => "negate",
726
NullCount => "null_count",
727
Pow(func) => return write!(f, "{func}"),
728
#[cfg(feature = "row_hash")]
729
Hash(_, _, _, _) => "hash",
730
#[cfg(feature = "arg_where")]
731
ArgWhere => "arg_where",
732
#[cfg(feature = "index_of")]
733
IndexOf => "index_of",
734
#[cfg(feature = "search_sorted")]
735
SearchSorted { .. } => "search_sorted",
736
#[cfg(feature = "range")]
737
Range(func) => return write!(f, "{func}"),
738
#[cfg(feature = "trigonometry")]
739
Trigonometry(func) => return write!(f, "{func}"),
740
#[cfg(feature = "trigonometry")]
741
Atan2 => return write!(f, "arctan2"),
742
#[cfg(feature = "sign")]
743
Sign => "sign",
744
FillNull => "fill_null",
745
#[cfg(feature = "rolling_window")]
746
RollingExpr { function, .. } => return write!(f, "{function}"),
747
#[cfg(feature = "rolling_window_by")]
748
RollingExprBy { function_by, .. } => return write!(f, "{function_by}"),
749
Rechunk => "rechunk",
750
Append { .. } => "append",
751
ShiftAndFill => "shift_and_fill",
752
DropNans => "drop_nans",
753
DropNulls => "drop_nulls",
754
#[cfg(feature = "mode")]
755
Mode { maintain_order } => {
756
if *maintain_order {
757
"mode_stable"
758
} else {
759
"mode"
760
}
761
},
762
#[cfg(feature = "moment")]
763
Skew(_) => "skew",
764
#[cfg(feature = "moment")]
765
Kurtosis(..) => "kurtosis",
766
ArgUnique => "arg_unique",
767
ArgMin => "arg_min",
768
ArgMax => "arg_max",
769
ArgSort { .. } => "arg_sort",
770
MinBy => "min_by",
771
MaxBy => "max_by",
772
Product => "product",
773
Repeat => "repeat",
774
#[cfg(feature = "rank")]
775
Rank { .. } => "rank",
776
#[cfg(feature = "round_series")]
777
Clip { has_min, has_max } => match (has_min, has_max) {
778
(true, true) => "clip",
779
(false, true) => "clip_max",
780
(true, false) => "clip_min",
781
_ => unreachable!(),
782
},
783
#[cfg(feature = "dtype-struct")]
784
AsStruct => "as_struct",
785
#[cfg(feature = "top_k")]
786
TopK { descending } => {
787
if *descending {
788
"bottom_k"
789
} else {
790
"top_k"
791
}
792
},
793
#[cfg(feature = "top_k")]
794
TopKBy { .. } => "top_k_by",
795
Shift => "shift",
796
#[cfg(feature = "cum_agg")]
797
CumCount { .. } => "cum_count",
798
#[cfg(feature = "cum_agg")]
799
CumSum { .. } => "cum_sum",
800
#[cfg(feature = "cum_agg")]
801
CumProd { .. } => "cum_prod",
802
#[cfg(feature = "cum_agg")]
803
CumMin { .. } => "cum_min",
804
#[cfg(feature = "cum_agg")]
805
CumMax { .. } => "cum_max",
806
#[cfg(feature = "dtype-struct")]
807
ValueCounts { .. } => "value_counts",
808
#[cfg(feature = "unique_counts")]
809
UniqueCounts => "unique_counts",
810
Reverse => "reverse",
811
#[cfg(feature = "approx_unique")]
812
ApproxNUnique => "approx_n_unique",
813
Coalesce => "coalesce",
814
#[cfg(feature = "diff")]
815
Diff(_) => "diff",
816
#[cfg(feature = "pct_change")]
817
PctChange => "pct_change",
818
#[cfg(feature = "interpolate")]
819
Interpolate(_) => "interpolate",
820
#[cfg(feature = "interpolate_by")]
821
InterpolateBy => "interpolate_by",
822
#[cfg(feature = "log")]
823
Entropy { .. } => "entropy",
824
#[cfg(feature = "log")]
825
Log => "log",
826
#[cfg(feature = "log")]
827
Log1p => "log1p",
828
#[cfg(feature = "log")]
829
Exp => "exp",
830
Unique(stable) => {
831
if *stable {
832
"unique_stable"
833
} else {
834
"unique"
835
}
836
},
837
#[cfg(feature = "round_series")]
838
Round { .. } => "round",
839
#[cfg(feature = "round_series")]
840
RoundSF { .. } => "round_sig_figs",
841
#[cfg(feature = "round_series")]
842
Floor => "floor",
843
#[cfg(feature = "round_series")]
844
Ceil => "ceil",
845
#[cfg(feature = "fused")]
846
Fused(fused) => return Display::fmt(fused, f),
847
ConcatExpr(_) => "concat_expr",
848
#[cfg(feature = "cov")]
849
Correlation { method, .. } => return Display::fmt(method, f),
850
#[cfg(feature = "peaks")]
851
PeakMin => "peak_min",
852
#[cfg(feature = "peaks")]
853
PeakMax => "peak_max",
854
#[cfg(feature = "cutqcut")]
855
Cut { .. } => "cut",
856
#[cfg(feature = "cutqcut")]
857
QCut { .. } => "qcut",
858
#[cfg(feature = "dtype-array")]
859
Reshape(_) => "reshape",
860
#[cfg(feature = "repeat_by")]
861
RepeatBy => "repeat_by",
862
#[cfg(feature = "rle")]
863
RLE => "rle",
864
#[cfg(feature = "rle")]
865
RLEID => "rle_id",
866
ToPhysical => "to_physical",
867
#[cfg(feature = "random")]
868
Random { method, .. } => method.into(),
869
SetSortedFlag(_) => "set_sorted",
870
#[cfg(feature = "ffi_plugin")]
871
FfiPlugin { lib, symbol, .. } => return write!(f, "{lib}:{symbol}"),
872
873
FoldHorizontal { .. } => "fold",
874
ReduceHorizontal { .. } => "reduce",
875
#[cfg(feature = "dtype-struct")]
876
CumReduceHorizontal { .. } => "cum_reduce",
877
#[cfg(feature = "dtype-struct")]
878
CumFoldHorizontal { .. } => "cum_fold",
879
880
MaxHorizontal => "max_horizontal",
881
MinHorizontal => "min_horizontal",
882
SumHorizontal { .. } => "sum_horizontal",
883
MeanHorizontal { .. } => "mean_horizontal",
884
#[cfg(feature = "ewma")]
885
EwmMean { .. } => "ewm_mean",
886
#[cfg(feature = "ewma_by")]
887
EwmMeanBy { .. } => "ewm_mean_by",
888
#[cfg(feature = "ewma")]
889
EwmStd { .. } => "ewm_std",
890
#[cfg(feature = "ewma")]
891
EwmVar { .. } => "ewm_var",
892
#[cfg(feature = "hist")]
893
Hist { .. } => "hist",
894
#[cfg(feature = "replace")]
895
Replace => "replace",
896
#[cfg(feature = "replace")]
897
ReplaceStrict { .. } => "replace_strict",
898
FillNullWithStrategy(_) => "fill_null_with_strategy",
899
GatherEvery { .. } => "gather_every",
900
#[cfg(feature = "reinterpret")]
901
Reinterpret(_) => "reinterpret",
902
ExtendConstant => "extend_constant",
903
904
RowEncode(..) => "row_encode",
905
#[cfg(feature = "dtype-struct")]
906
RowDecode(..) => "row_decode",
907
};
908
write!(f, "{s}")
909
}
910
}
911
912
/// Wraps a column function into a `SpecialEq<Arc<_>>`.
///
/// * `wrap!(e)` wraps the expression `e` as-is.
/// * `wrap!(e, a, b, ..)` wraps a `move` closure over `&mut [Column]` that calls
///   `e(columns, a, b, ..)`; the extra argument expressions are evaluated on every call.
#[macro_export]
macro_rules! wrap {
    ($e:expr) => {
        SpecialEq::new(Arc::new($e))
    };

    ($e:expr, $($args:expr),*) => {{
        let udf = move |columns: &mut [Column]| $e(columns, $($args),*);
        SpecialEq::new(Arc::new(udf))
    }};
}
926
927
/// Builds a `SpecialEq<Arc<_>>` around `$func(columns, args..)`.
///
/// The closure receives all expression inputs as one `&mut [Column]` slice, whose first
/// element is the root expression. Any extra argument expressions are evaluated on every
/// call and passed after the slice.
#[macro_export]
macro_rules! map_as_slice {
    ($func:path) => {{
        let udf = move |columns: &mut [Column]| $func(columns);
        SpecialEq::new(Arc::new(udf))
    }};

    ($func:path, $($args:expr),*) => {{
        let udf = move |columns: &mut [Column]| $func(columns, $($args),*);
        SpecialEq::new(Arc::new(udf))
    }};
}
948
949
/// Builds a `SpecialEq<Arc<_>>` around `$func(first, args..)`, where `first` is the first
/// input `Column` taken by value.
///
/// The column is moved out with `std::mem::take`, which leaves `Column::default()` behind
/// in the input slice. Indexing `[0]` panics if the slice is empty. Any extra argument
/// expressions are evaluated on every call.
#[macro_export]
macro_rules! map_owned {
    ($func:path) => {{
        let udf = move |columns: &mut [Column]| {
            let first = std::mem::take(&mut columns[0]);
            $func(first)
        };
        SpecialEq::new(Arc::new(udf))
    }};

    ($func:path, $($args:expr),*) => {{
        let udf = move |columns: &mut [Column]| {
            let first = std::mem::take(&mut columns[0]);
            $func(first, $($args),*)
        };
        SpecialEq::new(Arc::new(udf))
    }};
}
971
972
/// Builds a `SpecialEq<Arc<_>>` around `$func(&first, args..)`, where `first` is the
/// first input `Column`, borrowed immutably.
///
/// Indexing `[0]` panics if the slice is empty. Any extra argument expressions are
/// evaluated on every call.
#[macro_export]
macro_rules! map {
    ($func:path) => {{
        let udf = move |columns: &mut [Column]| {
            let first = &columns[0];
            $func(first)
        };
        SpecialEq::new(Arc::new(udf))
    }};

    ($func:path, $($args:expr),*) => {{
        let udf = move |columns: &mut [Column]| {
            let first = &columns[0];
            $func(first, $($args),*)
        };
        SpecialEq::new(Arc::new(udf))
    }};
}
993
994
impl IRFunctionExpr {
995
pub fn function_options(&self) -> FunctionOptions {
996
use IRFunctionExpr as F;
997
match self {
998
#[cfg(feature = "dtype-array")]
999
F::ArrayExpr(e) => e.function_options(),
1000
F::BinaryExpr(e) => e.function_options(),
1001
#[cfg(feature = "dtype-categorical")]
1002
F::Categorical(e) => e.function_options(),
1003
#[cfg(feature = "dtype-extension")]
1004
F::Extension(e) => e.function_options(),
1005
F::ListExpr(e) => e.function_options(),
1006
#[cfg(feature = "strings")]
1007
F::StringExpr(e) => e.function_options(),
1008
#[cfg(feature = "dtype-struct")]
1009
F::StructExpr(e) => e.function_options(),
1010
#[cfg(feature = "temporal")]
1011
F::TemporalExpr(e) => e.function_options(),
1012
#[cfg(feature = "bitwise")]
1013
F::Bitwise(e) => e.function_options(),
1014
F::Boolean(e) => e.function_options(),
1015
#[cfg(feature = "business")]
1016
F::Business(e) => e.function_options(),
1017
F::Pow(e) => e.function_options(),
1018
#[cfg(feature = "range")]
1019
F::Range(e) => e.function_options(),
1020
#[cfg(feature = "abs")]
1021
F::Abs => FunctionOptions::elementwise(),
1022
F::Negate => FunctionOptions::elementwise(),
1023
#[cfg(feature = "hist")]
1024
F::Hist { .. } => FunctionOptions::groupwise(),
1025
F::NullCount => FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING),
1026
#[cfg(feature = "row_hash")]
1027
F::Hash(_, _, _, _) => FunctionOptions::elementwise(),
1028
#[cfg(feature = "arg_where")]
1029
F::ArgWhere => FunctionOptions::groupwise(),
1030
#[cfg(feature = "index_of")]
1031
F::IndexOf => {
1032
FunctionOptions::aggregation().with_casting_rules(CastingRules::FirstArgLossless)
1033
},
1034
#[cfg(feature = "search_sorted")]
1035
F::SearchSorted { .. } => FunctionOptions::groupwise().with_supertyping(
1036
(SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(),
1037
),
1038
#[cfg(feature = "trigonometry")]
1039
F::Trigonometry(_) => FunctionOptions::elementwise(),
1040
#[cfg(feature = "trigonometry")]
1041
F::Atan2 => FunctionOptions::elementwise(),
1042
#[cfg(feature = "sign")]
1043
F::Sign => FunctionOptions::elementwise(),
1044
F::FillNull => FunctionOptions::elementwise().with_supertyping(Default::default()),
1045
F::FillNullWithStrategy(strategy) if strategy.is_elementwise() => {
1046
FunctionOptions::elementwise()
1047
},
1048
F::FillNullWithStrategy(_) => FunctionOptions::length_preserving(),
1049
#[cfg(feature = "rolling_window")]
1050
F::RollingExpr { .. } => FunctionOptions::length_preserving(),
1051
#[cfg(feature = "rolling_window_by")]
1052
F::RollingExprBy { .. } => FunctionOptions::length_preserving(),
1053
F::Rechunk => FunctionOptions::length_preserving(),
1054
F::Append { .. } => FunctionOptions::groupwise(),
1055
F::ShiftAndFill => FunctionOptions::length_preserving(),
1056
F::Shift => FunctionOptions::length_preserving(),
1057
F::DropNans => {
1058
FunctionOptions::row_separable().flag(FunctionFlags::NON_ORDER_PRODUCING)
1059
},
1060
F::DropNulls => FunctionOptions::row_separable()
1061
.flag(FunctionFlags::ALLOW_EMPTY_INPUTS | FunctionFlags::NON_ORDER_PRODUCING),
1062
#[cfg(feature = "mode")]
1063
F::Mode { maintain_order } => FunctionOptions::groupwise().with_flags(|f| {
1064
let f = f | FunctionFlags::NON_ORDER_PRODUCING;
1065
1066
if !*maintain_order {
1067
f | FunctionFlags::NON_ORDER_OBSERVING | FunctionFlags::TERMINATES_INPUT_ORDER
1068
} else {
1069
f
1070
}
1071
}),
1072
#[cfg(feature = "moment")]
1073
F::Skew(_) => FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING),
1074
#[cfg(feature = "moment")]
1075
F::Kurtosis(_, _) => {
1076
FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING)
1077
},
1078
#[cfg(feature = "dtype-array")]
1079
F::Reshape(dims) => {
1080
if dims.len() == 1 && dims[0] == ReshapeDimension::Infer {
1081
FunctionOptions::row_separable()
1082
} else {
1083
FunctionOptions::groupwise()
1084
}
1085
},
1086
#[cfg(feature = "repeat_by")]
1087
F::RepeatBy => FunctionOptions::elementwise(),
1088
F::ArgUnique => FunctionOptions::groupwise(),
1089
F::ArgMin | F::ArgMax => FunctionOptions::aggregation(),
1090
F::ArgSort { .. } => FunctionOptions::length_preserving(),
1091
F::MinBy | F::MaxBy => FunctionOptions::aggregation(),
1092
F::Product => FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING),
1093
#[cfg(feature = "rank")]
1094
F::Rank { .. } => FunctionOptions::length_preserving(),
1095
F::Repeat => {
1096
FunctionOptions::groupwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
1097
},
1098
#[cfg(feature = "round_series")]
1099
F::Clip { .. } => FunctionOptions::elementwise(),
1100
#[cfg(feature = "dtype-struct")]
1101
F::AsStruct => FunctionOptions::elementwise().with_flags(|f| {
1102
f | FunctionFlags::PASS_NAME_TO_APPLY | FunctionFlags::INPUT_WILDCARD_EXPANSION
1103
}),
1104
#[cfg(feature = "top_k")]
1105
F::TopK { .. } => FunctionOptions::groupwise(),
1106
#[cfg(feature = "top_k")]
1107
F::TopKBy { .. } => FunctionOptions::groupwise(),
1108
#[cfg(feature = "cum_agg")]
1109
F::CumCount { .. }
1110
| F::CumSum { .. }
1111
| F::CumProd { .. }
1112
| F::CumMin { .. }
1113
| F::CumMax { .. } => FunctionOptions::length_preserving(),
1114
F::Reverse => FunctionOptions::length_preserving()
1115
.with_flags(|f| f | FunctionFlags::NON_ORDER_OBSERVING),
1116
#[cfg(feature = "dtype-struct")]
1117
F::ValueCounts { sort, .. } => FunctionOptions::groupwise().with_flags(|mut f| {
1118
if !sort {
1119
f |= FunctionFlags::TERMINATES_INPUT_ORDER | FunctionFlags::NON_ORDER_PRODUCING
1120
}
1121
f | FunctionFlags::PASS_NAME_TO_APPLY | FunctionFlags::NON_ORDER_OBSERVING
1122
}),
1123
#[cfg(feature = "unique_counts")]
1124
F::UniqueCounts => FunctionOptions::groupwise(),
1125
#[cfg(feature = "approx_unique")]
1126
F::ApproxNUnique => {
1127
FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING)
1128
},
1129
F::Coalesce => FunctionOptions::elementwise()
1130
.with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION)
1131
.with_supertyping(Default::default()),
1132
#[cfg(feature = "diff")]
1133
F::Diff(NullBehavior::Drop) => FunctionOptions::groupwise(),
1134
#[cfg(feature = "diff")]
1135
F::Diff(NullBehavior::Ignore) => FunctionOptions::length_preserving(),
1136
#[cfg(feature = "pct_change")]
1137
F::PctChange => FunctionOptions::length_preserving(),
1138
#[cfg(feature = "interpolate")]
1139
F::Interpolate(_) => FunctionOptions::length_preserving(),
1140
#[cfg(feature = "interpolate_by")]
1141
F::InterpolateBy => FunctionOptions::length_preserving(),
1142
#[cfg(feature = "log")]
1143
F::Log | F::Log1p | F::Exp => FunctionOptions::elementwise(),
1144
#[cfg(feature = "log")]
1145
F::Entropy { .. } => {
1146
FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING)
1147
},
1148
F::Unique(maintain_order) => FunctionOptions::groupwise().with_flags(|f| {
1149
let f = f | FunctionFlags::NON_ORDER_PRODUCING;
1150
1151
if !*maintain_order {
1152
f | FunctionFlags::NON_ORDER_OBSERVING | FunctionFlags::TERMINATES_INPUT_ORDER
1153
} else {
1154
f
1155
}
1156
}),
1157
#[cfg(feature = "round_series")]
1158
F::Round { .. } | F::RoundSF { .. } | F::Floor | F::Ceil => {
1159
FunctionOptions::elementwise()
1160
},
1161
#[cfg(feature = "fused")]
1162
F::Fused(_) => FunctionOptions::elementwise(),
1163
F::ConcatExpr(_) => FunctionOptions::groupwise()
1164
.with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION)
1165
.with_supertyping(Default::default()),
1166
#[cfg(feature = "cov")]
1167
F::Correlation { .. } => {
1168
FunctionOptions::aggregation().with_supertyping(Default::default())
1169
},
1170
#[cfg(feature = "peaks")]
1171
F::PeakMin | F::PeakMax => FunctionOptions::length_preserving(),
1172
#[cfg(feature = "cutqcut")]
1173
F::Cut { .. } | F::QCut { .. } => FunctionOptions::length_preserving()
1174
.with_flags(|f| f | FunctionFlags::PASS_NAME_TO_APPLY),
1175
#[cfg(feature = "rle")]
1176
F::RLE => FunctionOptions::groupwise(),
1177
#[cfg(feature = "rle")]
1178
F::RLEID => FunctionOptions::length_preserving(),
1179
F::ToPhysical => FunctionOptions::elementwise(),
1180
#[cfg(feature = "random")]
1181
F::Random {
1182
method: IRRandomMethod::Sample { .. },
1183
..
1184
} => FunctionOptions::groupwise(),
1185
#[cfg(feature = "random")]
1186
F::Random {
1187
method: IRRandomMethod::Shuffle,
1188
..
1189
} => FunctionOptions::length_preserving(),
1190
F::SetSortedFlag(_) => FunctionOptions::elementwise(),
1191
#[cfg(feature = "ffi_plugin")]
1192
F::FfiPlugin { flags, .. } => *flags,
1193
F::MaxHorizontal | F::MinHorizontal => FunctionOptions::elementwise().with_flags(|f| {
1194
f | FunctionFlags::INPUT_WILDCARD_EXPANSION | FunctionFlags::ALLOW_RENAME
1195
}),
1196
F::MeanHorizontal { .. } | F::SumHorizontal { .. } => FunctionOptions::elementwise()
1197
.with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION),
1198
1199
F::FoldHorizontal { returns_scalar, .. }
1200
| F::ReduceHorizontal { returns_scalar, .. } => FunctionOptions::groupwise()
1201
.with_flags(|mut f| {
1202
f |= FunctionFlags::INPUT_WILDCARD_EXPANSION;
1203
if *returns_scalar {
1204
f |= FunctionFlags::RETURNS_SCALAR;
1205
}
1206
f
1207
}),
1208
#[cfg(feature = "dtype-struct")]
1209
F::CumFoldHorizontal { returns_scalar, .. }
1210
| F::CumReduceHorizontal { returns_scalar, .. } => FunctionOptions::groupwise()
1211
.with_flags(|mut f| {
1212
f |= FunctionFlags::INPUT_WILDCARD_EXPANSION;
1213
if *returns_scalar {
1214
f |= FunctionFlags::RETURNS_SCALAR;
1215
}
1216
f
1217
}),
1218
#[cfg(feature = "ewma")]
1219
F::EwmMean { .. } | F::EwmStd { .. } | F::EwmVar { .. } => {
1220
FunctionOptions::length_preserving()
1221
},
1222
#[cfg(feature = "ewma_by")]
1223
F::EwmMeanBy { .. } => FunctionOptions::length_preserving(),
1224
#[cfg(feature = "replace")]
1225
F::Replace => FunctionOptions::elementwise(),
1226
#[cfg(feature = "replace")]
1227
F::ReplaceStrict { .. } => FunctionOptions::elementwise(),
1228
F::GatherEvery { .. } => FunctionOptions::groupwise(),
1229
#[cfg(feature = "reinterpret")]
1230
F::Reinterpret(_) => FunctionOptions::elementwise(),
1231
F::ExtendConstant => FunctionOptions::groupwise(),
1232
1233
F::RowEncode(..) => FunctionOptions::elementwise(),
1234
#[cfg(feature = "dtype-struct")]
1235
F::RowDecode(..) => FunctionOptions::elementwise(),
1236
}
1237
}
1238
}
1239
1240