Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/function_expr/mod.rs
7889 views
1
#[cfg(feature = "dtype-array")]
2
mod array;
3
mod binary;
4
#[cfg(feature = "bitwise")]
5
mod bitwise;
6
mod boolean;
7
#[cfg(feature = "business")]
8
mod business;
9
#[cfg(feature = "dtype-categorical")]
10
mod cat;
11
#[cfg(feature = "cov")]
12
mod correlation;
13
#[cfg(feature = "cum_agg")]
14
mod cum;
15
#[cfg(feature = "temporal")]
16
mod datetime;
17
#[cfg(feature = "dtype-extension")]
18
mod extension;
19
#[cfg(feature = "fused")]
20
mod fused;
21
mod list;
22
#[cfg(feature = "ffi_plugin")]
23
pub mod plugin;
24
mod pow;
25
#[cfg(feature = "random")]
26
mod random;
27
#[cfg(feature = "range")]
28
mod range;
29
#[cfg(feature = "rolling_window")]
30
mod rolling;
31
#[cfg(feature = "rolling_window_by")]
32
mod rolling_by;
33
mod row_encode;
34
pub(super) mod schema;
35
#[cfg(feature = "strings")]
36
mod strings;
37
#[cfg(feature = "dtype-struct")]
38
mod struct_;
39
#[cfg(feature = "trigonometry")]
40
mod trigonometry;
41
42
use std::fmt::{Display, Formatter};
43
use std::hash::{Hash, Hasher};
44
45
#[cfg(feature = "dtype-array")]
46
pub use array::IRArrayFunction;
47
#[cfg(feature = "cov")]
48
pub use correlation::IRCorrelationMethod;
49
#[cfg(feature = "fused")]
50
pub use fused::FusedOperator;
51
pub use list::IRListFunction;
52
pub use polars_core::datatypes::ReshapeDimension;
53
use polars_core::prelude::*;
54
use polars_core::series::IsSorted;
55
use polars_core::series::ops::NullBehavior;
56
use polars_core::utils::SuperTypeFlags;
57
#[cfg(feature = "random")]
58
pub use random::IRRandomMethod;
59
use schema::FieldsMapper;
60
61
pub use self::binary::IRBinaryFunction;
62
#[cfg(feature = "bitwise")]
63
pub use self::bitwise::IRBitwiseFunction;
64
pub use self::boolean::IRBooleanFunction;
65
#[cfg(feature = "business")]
66
pub use self::business::IRBusinessFunction;
67
#[cfg(feature = "dtype-categorical")]
68
pub use self::cat::IRCategoricalFunction;
69
#[cfg(feature = "temporal")]
70
pub use self::datetime::IRTemporalFunction;
71
#[cfg(feature = "dtype-extension")]
72
pub use self::extension::IRExtensionFunction;
73
pub use self::pow::IRPowFunction;
74
#[cfg(feature = "range")]
75
pub use self::range::IRRangeFunction;
76
#[cfg(feature = "rolling_window")]
77
pub use self::rolling::IRRollingFunction;
78
#[cfg(feature = "rolling_window_by")]
79
pub use self::rolling_by::IRRollingFunctionBy;
80
pub use self::row_encode::RowEncodingVariant;
81
#[cfg(feature = "strings")]
82
pub use self::strings::IRStringFunction;
83
#[cfg(all(feature = "strings", feature = "regex", feature = "timezones"))]
84
pub use self::strings::TZ_AWARE_RE;
85
#[cfg(feature = "dtype-struct")]
86
pub use self::struct_::IRStructFunction;
87
#[cfg(feature = "trigonometry")]
88
pub use self::trigonometry::IRTrigonometricFunction;
89
use super::*;
90
91
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
92
#[derive(Clone, PartialEq, Debug)]
93
pub enum IRFunctionExpr {
94
// Namespaces
95
#[cfg(feature = "dtype-array")]
96
ArrayExpr(IRArrayFunction),
97
BinaryExpr(IRBinaryFunction),
98
#[cfg(feature = "dtype-categorical")]
99
Categorical(IRCategoricalFunction),
100
#[cfg(feature = "dtype-extension")]
101
Extension(IRExtensionFunction),
102
ListExpr(IRListFunction),
103
#[cfg(feature = "strings")]
104
StringExpr(IRStringFunction),
105
#[cfg(feature = "dtype-struct")]
106
StructExpr(IRStructFunction),
107
#[cfg(feature = "temporal")]
108
TemporalExpr(IRTemporalFunction),
109
#[cfg(feature = "bitwise")]
110
Bitwise(IRBitwiseFunction),
111
112
// Other expressions
113
Boolean(IRBooleanFunction),
114
#[cfg(feature = "business")]
115
Business(IRBusinessFunction),
116
#[cfg(feature = "abs")]
117
Abs,
118
Negate,
119
#[cfg(feature = "hist")]
120
Hist {
121
bin_count: Option<usize>,
122
include_category: bool,
123
include_breakpoint: bool,
124
},
125
NullCount,
126
Pow(IRPowFunction),
127
#[cfg(feature = "row_hash")]
128
Hash(u64, u64, u64, u64),
129
#[cfg(feature = "arg_where")]
130
ArgWhere,
131
#[cfg(feature = "index_of")]
132
IndexOf,
133
#[cfg(feature = "search_sorted")]
134
SearchSorted {
135
side: SearchSortedSide,
136
descending: bool,
137
},
138
#[cfg(feature = "range")]
139
Range(IRRangeFunction),
140
#[cfg(feature = "trigonometry")]
141
Trigonometry(IRTrigonometricFunction),
142
#[cfg(feature = "trigonometry")]
143
Atan2,
144
#[cfg(feature = "sign")]
145
Sign,
146
FillNull,
147
FillNullWithStrategy(FillNullStrategy),
148
#[cfg(feature = "rolling_window")]
149
RollingExpr {
150
function: IRRollingFunction,
151
options: RollingOptionsFixedWindow,
152
},
153
#[cfg(feature = "rolling_window_by")]
154
RollingExprBy {
155
function_by: IRRollingFunctionBy,
156
options: RollingOptionsDynamicWindow,
157
},
158
Rechunk,
159
Append {
160
upcast: bool,
161
},
162
ShiftAndFill,
163
Shift,
164
DropNans,
165
DropNulls,
166
#[cfg(feature = "mode")]
167
Mode {
168
maintain_order: bool,
169
},
170
#[cfg(feature = "moment")]
171
Skew(bool),
172
#[cfg(feature = "moment")]
173
Kurtosis(bool, bool),
174
#[cfg(feature = "dtype-array")]
175
Reshape(Vec<ReshapeDimension>),
176
#[cfg(feature = "repeat_by")]
177
RepeatBy,
178
ArgUnique,
179
ArgMin,
180
ArgMax,
181
ArgSort {
182
descending: bool,
183
nulls_last: bool,
184
},
185
Product,
186
#[cfg(feature = "rank")]
187
Rank {
188
options: RankOptions,
189
seed: Option<u64>,
190
},
191
Repeat,
192
#[cfg(feature = "round_series")]
193
Clip {
194
has_min: bool,
195
has_max: bool,
196
},
197
#[cfg(feature = "dtype-struct")]
198
AsStruct,
199
#[cfg(feature = "top_k")]
200
TopK {
201
descending: bool,
202
},
203
#[cfg(feature = "top_k")]
204
TopKBy {
205
descending: Vec<bool>,
206
},
207
#[cfg(feature = "cum_agg")]
208
CumCount {
209
reverse: bool,
210
},
211
#[cfg(feature = "cum_agg")]
212
CumSum {
213
reverse: bool,
214
},
215
#[cfg(feature = "cum_agg")]
216
CumProd {
217
reverse: bool,
218
},
219
#[cfg(feature = "cum_agg")]
220
CumMin {
221
reverse: bool,
222
},
223
#[cfg(feature = "cum_agg")]
224
CumMax {
225
reverse: bool,
226
},
227
Reverse,
228
#[cfg(feature = "dtype-struct")]
229
ValueCounts {
230
sort: bool,
231
parallel: bool,
232
name: PlSmallStr,
233
normalize: bool,
234
},
235
#[cfg(feature = "unique_counts")]
236
UniqueCounts,
237
#[cfg(feature = "approx_unique")]
238
ApproxNUnique,
239
Coalesce,
240
#[cfg(feature = "diff")]
241
Diff(NullBehavior),
242
#[cfg(feature = "pct_change")]
243
PctChange,
244
#[cfg(feature = "interpolate")]
245
Interpolate(InterpolationMethod),
246
#[cfg(feature = "interpolate_by")]
247
InterpolateBy,
248
#[cfg(feature = "log")]
249
Entropy {
250
base: f64,
251
normalize: bool,
252
},
253
#[cfg(feature = "log")]
254
Log,
255
#[cfg(feature = "log")]
256
Log1p,
257
#[cfg(feature = "log")]
258
Exp,
259
Unique(bool),
260
#[cfg(feature = "round_series")]
261
Round {
262
decimals: u32,
263
mode: RoundMode,
264
},
265
#[cfg(feature = "round_series")]
266
RoundSF {
267
digits: i32,
268
},
269
#[cfg(feature = "round_series")]
270
Floor,
271
#[cfg(feature = "round_series")]
272
Ceil,
273
#[cfg(feature = "fused")]
274
Fused(fused::FusedOperator),
275
ConcatExpr(bool),
276
#[cfg(feature = "cov")]
277
Correlation {
278
method: correlation::IRCorrelationMethod,
279
},
280
#[cfg(feature = "peaks")]
281
PeakMin,
282
#[cfg(feature = "peaks")]
283
PeakMax,
284
#[cfg(feature = "cutqcut")]
285
Cut {
286
breaks: Vec<f64>,
287
labels: Option<Vec<PlSmallStr>>,
288
left_closed: bool,
289
include_breaks: bool,
290
},
291
#[cfg(feature = "cutqcut")]
292
QCut {
293
probs: Vec<f64>,
294
labels: Option<Vec<PlSmallStr>>,
295
left_closed: bool,
296
allow_duplicates: bool,
297
include_breaks: bool,
298
},
299
#[cfg(feature = "rle")]
300
RLE,
301
#[cfg(feature = "rle")]
302
RLEID,
303
ToPhysical,
304
#[cfg(feature = "random")]
305
Random {
306
method: IRRandomMethod,
307
seed: Option<u64>,
308
},
309
SetSortedFlag(IsSorted),
310
#[cfg(feature = "ffi_plugin")]
311
/// Creating this node is unsafe
312
/// This will lead to calls over FFI.
313
FfiPlugin {
314
flags: FunctionOptions,
315
/// Shared library.
316
lib: PlSmallStr,
317
/// Identifier in the shared lib.
318
symbol: PlSmallStr,
319
/// Pickle serialized keyword arguments.
320
kwargs: Arc<[u8]>,
321
},
322
323
FoldHorizontal {
324
callback: PlanCallback<(Series, Series), Series>,
325
returns_scalar: bool,
326
return_dtype: Option<DataType>,
327
},
328
ReduceHorizontal {
329
callback: PlanCallback<(Series, Series), Series>,
330
returns_scalar: bool,
331
return_dtype: Option<DataType>,
332
},
333
#[cfg(feature = "dtype-struct")]
334
CumReduceHorizontal {
335
callback: PlanCallback<(Series, Series), Series>,
336
returns_scalar: bool,
337
return_dtype: Option<DataType>,
338
},
339
#[cfg(feature = "dtype-struct")]
340
CumFoldHorizontal {
341
callback: PlanCallback<(Series, Series), Series>,
342
returns_scalar: bool,
343
return_dtype: Option<DataType>,
344
include_init: bool,
345
},
346
347
MaxHorizontal,
348
MinHorizontal,
349
SumHorizontal {
350
ignore_nulls: bool,
351
},
352
MeanHorizontal {
353
ignore_nulls: bool,
354
},
355
#[cfg(feature = "ewma")]
356
EwmMean {
357
options: EWMOptions,
358
},
359
#[cfg(feature = "ewma_by")]
360
EwmMeanBy {
361
half_life: Duration,
362
},
363
#[cfg(feature = "ewma")]
364
EwmStd {
365
options: EWMOptions,
366
},
367
#[cfg(feature = "ewma")]
368
EwmVar {
369
options: EWMOptions,
370
},
371
#[cfg(feature = "replace")]
372
Replace,
373
#[cfg(feature = "replace")]
374
ReplaceStrict {
375
return_dtype: Option<DataType>,
376
},
377
GatherEvery {
378
n: usize,
379
offset: usize,
380
},
381
#[cfg(feature = "reinterpret")]
382
Reinterpret(bool),
383
ExtendConstant,
384
385
RowEncode(Vec<DataType>, RowEncodingVariant),
386
#[cfg(feature = "dtype-struct")]
387
RowDecode(Vec<Field>, RowEncodingVariant),
388
}
389
390
impl Hash for IRFunctionExpr {
391
fn hash<H: Hasher>(&self, state: &mut H) {
392
std::mem::discriminant(self).hash(state);
393
use IRFunctionExpr::*;
394
match self {
395
// Namespaces
396
#[cfg(feature = "dtype-array")]
397
ArrayExpr(f) => f.hash(state),
398
BinaryExpr(f) => f.hash(state),
399
#[cfg(feature = "dtype-categorical")]
400
Categorical(f) => f.hash(state),
401
#[cfg(feature = "dtype-extension")]
402
Extension(f) => f.hash(state),
403
ListExpr(f) => f.hash(state),
404
#[cfg(feature = "strings")]
405
StringExpr(f) => f.hash(state),
406
#[cfg(feature = "dtype-struct")]
407
StructExpr(f) => f.hash(state),
408
#[cfg(feature = "temporal")]
409
TemporalExpr(f) => f.hash(state),
410
#[cfg(feature = "bitwise")]
411
Bitwise(f) => f.hash(state),
412
413
// Other expressions
414
Boolean(f) => f.hash(state),
415
#[cfg(feature = "business")]
416
Business(f) => f.hash(state),
417
Pow(f) => f.hash(state),
418
#[cfg(feature = "index_of")]
419
IndexOf => {},
420
#[cfg(feature = "search_sorted")]
421
SearchSorted { side, descending } => {
422
side.hash(state);
423
descending.hash(state);
424
},
425
#[cfg(feature = "random")]
426
Random { method, .. } => method.hash(state),
427
#[cfg(feature = "cov")]
428
Correlation { method, .. } => method.hash(state),
429
#[cfg(feature = "range")]
430
Range(f) => f.hash(state),
431
#[cfg(feature = "trigonometry")]
432
Trigonometry(f) => f.hash(state),
433
#[cfg(feature = "fused")]
434
Fused(f) => f.hash(state),
435
#[cfg(feature = "diff")]
436
Diff(null_behavior) => null_behavior.hash(state),
437
#[cfg(feature = "interpolate")]
438
Interpolate(f) => f.hash(state),
439
#[cfg(feature = "interpolate_by")]
440
InterpolateBy => {},
441
#[cfg(feature = "ffi_plugin")]
442
FfiPlugin {
443
flags: _,
444
lib,
445
symbol,
446
kwargs,
447
} => {
448
kwargs.hash(state);
449
lib.hash(state);
450
symbol.hash(state);
451
},
452
453
FoldHorizontal {
454
callback,
455
returns_scalar,
456
return_dtype,
457
}
458
| ReduceHorizontal {
459
callback,
460
returns_scalar,
461
return_dtype,
462
} => {
463
callback.hash(state);
464
returns_scalar.hash(state);
465
return_dtype.hash(state);
466
},
467
#[cfg(feature = "dtype-struct")]
468
CumReduceHorizontal {
469
callback,
470
returns_scalar,
471
return_dtype,
472
} => {
473
callback.hash(state);
474
returns_scalar.hash(state);
475
return_dtype.hash(state);
476
},
477
#[cfg(feature = "dtype-struct")]
478
CumFoldHorizontal {
479
callback,
480
returns_scalar,
481
return_dtype,
482
include_init,
483
} => {
484
callback.hash(state);
485
returns_scalar.hash(state);
486
return_dtype.hash(state);
487
include_init.hash(state);
488
},
489
490
SumHorizontal { ignore_nulls } | MeanHorizontal { ignore_nulls } => {
491
ignore_nulls.hash(state)
492
},
493
MaxHorizontal | MinHorizontal | DropNans | DropNulls | Reverse | ArgUnique | ArgMin
494
| ArgMax | Product | Shift | ShiftAndFill | Rechunk => {},
495
Append { upcast } => {
496
upcast.hash(state);
497
},
498
ArgSort {
499
descending,
500
nulls_last,
501
} => {
502
descending.hash(state);
503
nulls_last.hash(state);
504
},
505
#[cfg(feature = "mode")]
506
Mode { maintain_order } => {
507
maintain_order.hash(state);
508
},
509
#[cfg(feature = "abs")]
510
Abs => {},
511
Negate => {},
512
NullCount => {},
513
#[cfg(feature = "arg_where")]
514
ArgWhere => {},
515
#[cfg(feature = "trigonometry")]
516
Atan2 => {},
517
#[cfg(feature = "dtype-struct")]
518
AsStruct => {},
519
#[cfg(feature = "sign")]
520
Sign => {},
521
#[cfg(feature = "row_hash")]
522
Hash(a, b, c, d) => (a, b, c, d).hash(state),
523
FillNull => {},
524
#[cfg(feature = "rolling_window")]
525
RollingExpr { function, options } => {
526
function.hash(state);
527
options.hash(state);
528
},
529
#[cfg(feature = "rolling_window_by")]
530
RollingExprBy {
531
function_by,
532
options,
533
} => {
534
function_by.hash(state);
535
options.hash(state);
536
},
537
#[cfg(feature = "moment")]
538
Skew(a) => a.hash(state),
539
#[cfg(feature = "moment")]
540
Kurtosis(a, b) => {
541
a.hash(state);
542
b.hash(state);
543
},
544
Repeat => {},
545
#[cfg(feature = "rank")]
546
Rank { options, seed } => {
547
options.hash(state);
548
seed.hash(state);
549
},
550
#[cfg(feature = "round_series")]
551
Clip { has_min, has_max } => {
552
has_min.hash(state);
553
has_max.hash(state);
554
},
555
#[cfg(feature = "top_k")]
556
TopK { descending } => descending.hash(state),
557
#[cfg(feature = "cum_agg")]
558
CumCount { reverse } => reverse.hash(state),
559
#[cfg(feature = "cum_agg")]
560
CumSum { reverse } => reverse.hash(state),
561
#[cfg(feature = "cum_agg")]
562
CumProd { reverse } => reverse.hash(state),
563
#[cfg(feature = "cum_agg")]
564
CumMin { reverse } => reverse.hash(state),
565
#[cfg(feature = "cum_agg")]
566
CumMax { reverse } => reverse.hash(state),
567
#[cfg(feature = "dtype-struct")]
568
ValueCounts {
569
sort,
570
parallel,
571
name,
572
normalize,
573
} => {
574
sort.hash(state);
575
parallel.hash(state);
576
name.hash(state);
577
normalize.hash(state);
578
},
579
#[cfg(feature = "unique_counts")]
580
UniqueCounts => {},
581
#[cfg(feature = "approx_unique")]
582
ApproxNUnique => {},
583
Coalesce => {},
584
#[cfg(feature = "pct_change")]
585
PctChange => {},
586
#[cfg(feature = "log")]
587
Entropy { base, normalize } => {
588
base.to_bits().hash(state);
589
normalize.hash(state);
590
},
591
#[cfg(feature = "log")]
592
Log => {},
593
#[cfg(feature = "log")]
594
Log1p => {},
595
#[cfg(feature = "log")]
596
Exp => {},
597
Unique(a) => a.hash(state),
598
#[cfg(feature = "round_series")]
599
Round { decimals, mode } => {
600
decimals.hash(state);
601
mode.hash(state);
602
},
603
#[cfg(feature = "round_series")]
604
IRFunctionExpr::RoundSF { digits } => digits.hash(state),
605
#[cfg(feature = "round_series")]
606
IRFunctionExpr::Floor => {},
607
#[cfg(feature = "round_series")]
608
Ceil => {},
609
ConcatExpr(a) => a.hash(state),
610
#[cfg(feature = "peaks")]
611
PeakMin => {},
612
#[cfg(feature = "peaks")]
613
PeakMax => {},
614
#[cfg(feature = "cutqcut")]
615
Cut {
616
breaks,
617
labels,
618
left_closed,
619
include_breaks,
620
} => {
621
let slice = bytemuck::cast_slice::<_, u64>(breaks);
622
slice.hash(state);
623
labels.hash(state);
624
left_closed.hash(state);
625
include_breaks.hash(state);
626
},
627
#[cfg(feature = "dtype-array")]
628
Reshape(dims) => dims.hash(state),
629
#[cfg(feature = "repeat_by")]
630
RepeatBy => {},
631
#[cfg(feature = "cutqcut")]
632
QCut {
633
probs,
634
labels,
635
left_closed,
636
allow_duplicates,
637
include_breaks,
638
} => {
639
let slice = bytemuck::cast_slice::<_, u64>(probs);
640
slice.hash(state);
641
labels.hash(state);
642
left_closed.hash(state);
643
allow_duplicates.hash(state);
644
include_breaks.hash(state);
645
},
646
#[cfg(feature = "rle")]
647
RLE => {},
648
#[cfg(feature = "rle")]
649
RLEID => {},
650
ToPhysical => {},
651
SetSortedFlag(is_sorted) => is_sorted.hash(state),
652
#[cfg(feature = "ewma")]
653
EwmMean { options } => options.hash(state),
654
#[cfg(feature = "ewma_by")]
655
EwmMeanBy { half_life } => (half_life).hash(state),
656
#[cfg(feature = "ewma")]
657
EwmStd { options } => options.hash(state),
658
#[cfg(feature = "ewma")]
659
EwmVar { options } => options.hash(state),
660
#[cfg(feature = "hist")]
661
Hist {
662
bin_count,
663
include_category,
664
include_breakpoint,
665
} => {
666
bin_count.hash(state);
667
include_category.hash(state);
668
include_breakpoint.hash(state);
669
},
670
#[cfg(feature = "replace")]
671
Replace => {},
672
#[cfg(feature = "replace")]
673
ReplaceStrict { return_dtype } => return_dtype.hash(state),
674
FillNullWithStrategy(strategy) => strategy.hash(state),
675
GatherEvery { n, offset } => (n, offset).hash(state),
676
#[cfg(feature = "reinterpret")]
677
Reinterpret(signed) => signed.hash(state),
678
ExtendConstant => {},
679
#[cfg(feature = "top_k")]
680
TopKBy { descending } => descending.hash(state),
681
682
RowEncode(dts, variants) => {
683
dts.hash(state);
684
variants.hash(state);
685
},
686
#[cfg(feature = "dtype-struct")]
687
RowDecode(fs, variants) => {
688
fs.hash(state);
689
variants.hash(state);
690
},
691
}
692
}
693
}
694
695
impl Display for IRFunctionExpr {
696
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
697
use IRFunctionExpr::*;
698
let s = match self {
699
// Namespaces
700
#[cfg(feature = "dtype-array")]
701
ArrayExpr(func) => return write!(f, "{func}"),
702
BinaryExpr(func) => return write!(f, "{func}"),
703
#[cfg(feature = "dtype-categorical")]
704
Categorical(func) => return write!(f, "{func}"),
705
#[cfg(feature = "dtype-extension")]
706
Extension(func) => return write!(f, "{func}"),
707
ListExpr(func) => return write!(f, "{func}"),
708
#[cfg(feature = "strings")]
709
StringExpr(func) => return write!(f, "{func}"),
710
#[cfg(feature = "dtype-struct")]
711
StructExpr(func) => return write!(f, "{func}"),
712
#[cfg(feature = "temporal")]
713
TemporalExpr(func) => return write!(f, "{func}"),
714
#[cfg(feature = "bitwise")]
715
Bitwise(func) => return write!(f, "bitwise_{func}"),
716
717
// Other expressions
718
Boolean(func) => return write!(f, "{func}"),
719
#[cfg(feature = "business")]
720
Business(func) => return write!(f, "{func}"),
721
#[cfg(feature = "abs")]
722
Abs => "abs",
723
Negate => "negate",
724
NullCount => "null_count",
725
Pow(func) => return write!(f, "{func}"),
726
#[cfg(feature = "row_hash")]
727
Hash(_, _, _, _) => "hash",
728
#[cfg(feature = "arg_where")]
729
ArgWhere => "arg_where",
730
#[cfg(feature = "index_of")]
731
IndexOf => "index_of",
732
#[cfg(feature = "search_sorted")]
733
SearchSorted { .. } => "search_sorted",
734
#[cfg(feature = "range")]
735
Range(func) => return write!(f, "{func}"),
736
#[cfg(feature = "trigonometry")]
737
Trigonometry(func) => return write!(f, "{func}"),
738
#[cfg(feature = "trigonometry")]
739
Atan2 => return write!(f, "arctan2"),
740
#[cfg(feature = "sign")]
741
Sign => "sign",
742
FillNull => "fill_null",
743
#[cfg(feature = "rolling_window")]
744
RollingExpr { function, .. } => return write!(f, "{function}"),
745
#[cfg(feature = "rolling_window_by")]
746
RollingExprBy { function_by, .. } => return write!(f, "{function_by}"),
747
Rechunk => "rechunk",
748
Append { .. } => "append",
749
ShiftAndFill => "shift_and_fill",
750
DropNans => "drop_nans",
751
DropNulls => "drop_nulls",
752
#[cfg(feature = "mode")]
753
Mode { maintain_order } => {
754
if *maintain_order {
755
"mode_stable"
756
} else {
757
"mode"
758
}
759
},
760
#[cfg(feature = "moment")]
761
Skew(_) => "skew",
762
#[cfg(feature = "moment")]
763
Kurtosis(..) => "kurtosis",
764
ArgUnique => "arg_unique",
765
ArgMin => "arg_min",
766
ArgMax => "arg_max",
767
ArgSort { .. } => "arg_sort",
768
Product => "product",
769
Repeat => "repeat",
770
#[cfg(feature = "rank")]
771
Rank { .. } => "rank",
772
#[cfg(feature = "round_series")]
773
Clip { has_min, has_max } => match (has_min, has_max) {
774
(true, true) => "clip",
775
(false, true) => "clip_max",
776
(true, false) => "clip_min",
777
_ => unreachable!(),
778
},
779
#[cfg(feature = "dtype-struct")]
780
AsStruct => "as_struct",
781
#[cfg(feature = "top_k")]
782
TopK { descending } => {
783
if *descending {
784
"bottom_k"
785
} else {
786
"top_k"
787
}
788
},
789
#[cfg(feature = "top_k")]
790
TopKBy { .. } => "top_k_by",
791
Shift => "shift",
792
#[cfg(feature = "cum_agg")]
793
CumCount { .. } => "cum_count",
794
#[cfg(feature = "cum_agg")]
795
CumSum { .. } => "cum_sum",
796
#[cfg(feature = "cum_agg")]
797
CumProd { .. } => "cum_prod",
798
#[cfg(feature = "cum_agg")]
799
CumMin { .. } => "cum_min",
800
#[cfg(feature = "cum_agg")]
801
CumMax { .. } => "cum_max",
802
#[cfg(feature = "dtype-struct")]
803
ValueCounts { .. } => "value_counts",
804
#[cfg(feature = "unique_counts")]
805
UniqueCounts => "unique_counts",
806
Reverse => "reverse",
807
#[cfg(feature = "approx_unique")]
808
ApproxNUnique => "approx_n_unique",
809
Coalesce => "coalesce",
810
#[cfg(feature = "diff")]
811
Diff(_) => "diff",
812
#[cfg(feature = "pct_change")]
813
PctChange => "pct_change",
814
#[cfg(feature = "interpolate")]
815
Interpolate(_) => "interpolate",
816
#[cfg(feature = "interpolate_by")]
817
InterpolateBy => "interpolate_by",
818
#[cfg(feature = "log")]
819
Entropy { .. } => "entropy",
820
#[cfg(feature = "log")]
821
Log => "log",
822
#[cfg(feature = "log")]
823
Log1p => "log1p",
824
#[cfg(feature = "log")]
825
Exp => "exp",
826
Unique(stable) => {
827
if *stable {
828
"unique_stable"
829
} else {
830
"unique"
831
}
832
},
833
#[cfg(feature = "round_series")]
834
Round { .. } => "round",
835
#[cfg(feature = "round_series")]
836
RoundSF { .. } => "round_sig_figs",
837
#[cfg(feature = "round_series")]
838
Floor => "floor",
839
#[cfg(feature = "round_series")]
840
Ceil => "ceil",
841
#[cfg(feature = "fused")]
842
Fused(fused) => return Display::fmt(fused, f),
843
ConcatExpr(_) => "concat_expr",
844
#[cfg(feature = "cov")]
845
Correlation { method, .. } => return Display::fmt(method, f),
846
#[cfg(feature = "peaks")]
847
PeakMin => "peak_min",
848
#[cfg(feature = "peaks")]
849
PeakMax => "peak_max",
850
#[cfg(feature = "cutqcut")]
851
Cut { .. } => "cut",
852
#[cfg(feature = "cutqcut")]
853
QCut { .. } => "qcut",
854
#[cfg(feature = "dtype-array")]
855
Reshape(_) => "reshape",
856
#[cfg(feature = "repeat_by")]
857
RepeatBy => "repeat_by",
858
#[cfg(feature = "rle")]
859
RLE => "rle",
860
#[cfg(feature = "rle")]
861
RLEID => "rle_id",
862
ToPhysical => "to_physical",
863
#[cfg(feature = "random")]
864
Random { method, .. } => method.into(),
865
SetSortedFlag(_) => "set_sorted",
866
#[cfg(feature = "ffi_plugin")]
867
FfiPlugin { lib, symbol, .. } => return write!(f, "{lib}:{symbol}"),
868
869
FoldHorizontal { .. } => "fold",
870
ReduceHorizontal { .. } => "reduce",
871
#[cfg(feature = "dtype-struct")]
872
CumReduceHorizontal { .. } => "cum_reduce",
873
#[cfg(feature = "dtype-struct")]
874
CumFoldHorizontal { .. } => "cum_fold",
875
876
MaxHorizontal => "max_horizontal",
877
MinHorizontal => "min_horizontal",
878
SumHorizontal { .. } => "sum_horizontal",
879
MeanHorizontal { .. } => "mean_horizontal",
880
#[cfg(feature = "ewma")]
881
EwmMean { .. } => "ewm_mean",
882
#[cfg(feature = "ewma_by")]
883
EwmMeanBy { .. } => "ewm_mean_by",
884
#[cfg(feature = "ewma")]
885
EwmStd { .. } => "ewm_std",
886
#[cfg(feature = "ewma")]
887
EwmVar { .. } => "ewm_var",
888
#[cfg(feature = "hist")]
889
Hist { .. } => "hist",
890
#[cfg(feature = "replace")]
891
Replace => "replace",
892
#[cfg(feature = "replace")]
893
ReplaceStrict { .. } => "replace_strict",
894
FillNullWithStrategy(_) => "fill_null_with_strategy",
895
GatherEvery { .. } => "gather_every",
896
#[cfg(feature = "reinterpret")]
897
Reinterpret(_) => "reinterpret",
898
ExtendConstant => "extend_constant",
899
900
RowEncode(..) => "row_encode",
901
#[cfg(feature = "dtype-struct")]
902
RowDecode(..) => "row_decode",
903
};
904
write!(f, "{s}")
905
}
906
}
907
908
/// Build a `SpecialEq<Arc<_>>` UDF from a ready-made callable, or from a
/// function plus extra arguments.
///
/// * `wrap!(expr)` wraps `expr` unchanged: `SpecialEq::new(Arc::new(expr))`.
/// * `wrap!(func, args...)` builds a `move` closure over `&mut [Column]` that
///   calls `func(columns, args...)`. The argument expressions are substituted
///   into the closure body, so they are evaluated on every call.
#[macro_export]
macro_rules! wrap {
    // A single expression is wrapped as-is.
    ($e:expr) => {
        SpecialEq::new(Arc::new($e))
    };

    // Callable plus extra arguments: the input columns are forwarded first.
    ($e:expr, $($args:expr),*) => {{
        let udf = move |columns: &mut [Column]| $e(columns, $($args),*);
        SpecialEq::new(Arc::new(udf))
    }};
}
922
923
/// Wrap a function that receives every input column as one mutable slice into
/// a `SpecialEq<Arc<_>>` UDF.
///
/// Expected signature: `Fn(&mut [Column])` or `Fn(&mut [Column], args...)`.
/// * All expression arguments are in the slice.
/// * The first element is the root expression.
///
/// With extra `args`, the argument expressions are substituted into the
/// closure body, so they are evaluated on every call.
#[macro_export]
macro_rules! map_as_slice {
    ($func:path) => {{
        let f = move |s: &mut [Column]| {
            $func(s)
        };

        SpecialEq::new(Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |s: &mut [Column]| {
            $func(s, $($args),*)
        };

        SpecialEq::new(Arc::new(f))
    }};
}
944
945
/// Wrap a function that takes the first input column by value into a
/// `SpecialEq<Arc<_>>` UDF.
///
/// Expected signature: `FnOnce(Column)` or `FnOnce(Column, args...)`.
///
/// The first column is moved out of the slice with `std::mem::take`, leaving
/// `Column::default()` in its place. The remaining columns are not passed.
/// Indexing `c[0]` panics if the slice is empty. With extra `args`, the
/// argument expressions are evaluated inside the closure on every call.
#[macro_export]
macro_rules! map_owned {
    ($func:path) => {{
        let f = move |c: &mut [Column]| {
            let c = std::mem::take(&mut c[0]);
            $func(c)
        };

        SpecialEq::new(Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |c: &mut [Column]| {
            let c = std::mem::take(&mut c[0]);
            $func(c, $($args),*)
        };

        SpecialEq::new(Arc::new(f))
    }};
}
967
968
/// Wrap a function that borrows the first input column into a
/// `SpecialEq<Arc<_>>` UDF.
///
/// Expected signature: `Fn(&Column)` or `Fn(&Column, args...)`.
///
/// Only `&c[0]` is passed and the slice is left untouched. Indexing panics if
/// the slice is empty. With extra `args`, the argument expressions are
/// evaluated inside the closure on every call.
#[macro_export]
macro_rules! map {
    ($func:path) => {{
        let f = move |c: &mut [Column]| {
            let c = &c[0];
            $func(c)
        };

        SpecialEq::new(Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |c: &mut [Column]| {
            let c = &c[0];
            $func(c, $($args),*)
        };

        SpecialEq::new(Arc::new(f))
    }};
}
989
990
impl IRFunctionExpr {
991
pub fn function_options(&self) -> FunctionOptions {
992
use IRFunctionExpr as F;
993
match self {
994
#[cfg(feature = "dtype-array")]
995
F::ArrayExpr(e) => e.function_options(),
996
F::BinaryExpr(e) => e.function_options(),
997
#[cfg(feature = "dtype-categorical")]
998
F::Categorical(e) => e.function_options(),
999
#[cfg(feature = "dtype-extension")]
1000
F::Extension(e) => e.function_options(),
1001
F::ListExpr(e) => e.function_options(),
1002
#[cfg(feature = "strings")]
1003
F::StringExpr(e) => e.function_options(),
1004
#[cfg(feature = "dtype-struct")]
1005
F::StructExpr(e) => e.function_options(),
1006
#[cfg(feature = "temporal")]
1007
F::TemporalExpr(e) => e.function_options(),
1008
#[cfg(feature = "bitwise")]
1009
F::Bitwise(e) => e.function_options(),
1010
F::Boolean(e) => e.function_options(),
1011
#[cfg(feature = "business")]
1012
F::Business(e) => e.function_options(),
1013
F::Pow(e) => e.function_options(),
1014
#[cfg(feature = "range")]
1015
F::Range(e) => e.function_options(),
1016
#[cfg(feature = "abs")]
1017
F::Abs => FunctionOptions::elementwise(),
1018
F::Negate => FunctionOptions::elementwise(),
1019
#[cfg(feature = "hist")]
1020
F::Hist { .. } => FunctionOptions::groupwise(),
1021
F::NullCount => FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING),
1022
#[cfg(feature = "row_hash")]
1023
F::Hash(_, _, _, _) => FunctionOptions::elementwise(),
1024
#[cfg(feature = "arg_where")]
1025
F::ArgWhere => FunctionOptions::groupwise(),
1026
#[cfg(feature = "index_of")]
1027
F::IndexOf => {
1028
FunctionOptions::aggregation().with_casting_rules(CastingRules::FirstArgLossless)
1029
},
1030
#[cfg(feature = "search_sorted")]
1031
F::SearchSorted { .. } => FunctionOptions::groupwise().with_supertyping(
1032
(SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(),
1033
),
1034
#[cfg(feature = "trigonometry")]
1035
F::Trigonometry(_) => FunctionOptions::elementwise(),
1036
#[cfg(feature = "trigonometry")]
1037
F::Atan2 => FunctionOptions::elementwise(),
1038
#[cfg(feature = "sign")]
1039
F::Sign => FunctionOptions::elementwise(),
1040
F::FillNull => FunctionOptions::elementwise().with_supertyping(Default::default()),
1041
F::FillNullWithStrategy(strategy) if strategy.is_elementwise() => {
1042
FunctionOptions::elementwise()
1043
},
1044
F::FillNullWithStrategy(_) => FunctionOptions::length_preserving(),
1045
#[cfg(feature = "rolling_window")]
1046
F::RollingExpr { .. } => FunctionOptions::length_preserving(),
1047
#[cfg(feature = "rolling_window_by")]
1048
F::RollingExprBy { .. } => FunctionOptions::length_preserving(),
1049
F::Rechunk => FunctionOptions::length_preserving(),
1050
F::Append { .. } => FunctionOptions::groupwise(),
1051
F::ShiftAndFill => FunctionOptions::length_preserving(),
1052
F::Shift => FunctionOptions::length_preserving(),
1053
F::DropNans => {
1054
FunctionOptions::row_separable().flag(FunctionFlags::NON_ORDER_PRODUCING)
1055
},
1056
F::DropNulls => FunctionOptions::row_separable()
1057
.flag(FunctionFlags::ALLOW_EMPTY_INPUTS | FunctionFlags::NON_ORDER_PRODUCING),
1058
#[cfg(feature = "mode")]
1059
F::Mode { maintain_order } => FunctionOptions::groupwise().with_flags(|f| {
1060
let f = f | FunctionFlags::NON_ORDER_PRODUCING;
1061
1062
if !*maintain_order {
1063
f | FunctionFlags::NON_ORDER_OBSERVING | FunctionFlags::TERMINATES_INPUT_ORDER
1064
} else {
1065
f
1066
}
1067
}),
1068
#[cfg(feature = "moment")]
1069
F::Skew(_) => FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING),
1070
#[cfg(feature = "moment")]
1071
F::Kurtosis(_, _) => {
1072
FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING)
1073
},
1074
#[cfg(feature = "dtype-array")]
1075
F::Reshape(dims) => {
1076
if dims.len() == 1 && dims[0] == ReshapeDimension::Infer {
1077
FunctionOptions::row_separable()
1078
} else {
1079
FunctionOptions::groupwise()
1080
}
1081
},
1082
#[cfg(feature = "repeat_by")]
1083
F::RepeatBy => FunctionOptions::elementwise(),
1084
F::ArgUnique => FunctionOptions::groupwise(),
1085
F::ArgMin | F::ArgMax => FunctionOptions::aggregation(),
1086
F::ArgSort { .. } => FunctionOptions::length_preserving(),
1087
F::Product => FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING),
1088
#[cfg(feature = "rank")]
1089
F::Rank { .. } => FunctionOptions::length_preserving(),
1090
F::Repeat => {
1091
FunctionOptions::groupwise().with_flags(|f| f | FunctionFlags::ALLOW_RENAME)
1092
},
1093
#[cfg(feature = "round_series")]
1094
F::Clip { .. } => FunctionOptions::elementwise(),
1095
#[cfg(feature = "dtype-struct")]
1096
F::AsStruct => FunctionOptions::elementwise().with_flags(|f| {
1097
f | FunctionFlags::PASS_NAME_TO_APPLY | FunctionFlags::INPUT_WILDCARD_EXPANSION
1098
}),
1099
#[cfg(feature = "top_k")]
1100
F::TopK { .. } => FunctionOptions::groupwise(),
1101
#[cfg(feature = "top_k")]
1102
F::TopKBy { .. } => FunctionOptions::groupwise(),
1103
#[cfg(feature = "cum_agg")]
1104
F::CumCount { .. }
1105
| F::CumSum { .. }
1106
| F::CumProd { .. }
1107
| F::CumMin { .. }
1108
| F::CumMax { .. } => FunctionOptions::length_preserving(),
1109
F::Reverse => FunctionOptions::length_preserving()
1110
.with_flags(|f| f | FunctionFlags::NON_ORDER_OBSERVING),
1111
#[cfg(feature = "dtype-struct")]
1112
F::ValueCounts { sort, .. } => FunctionOptions::groupwise().with_flags(|mut f| {
1113
if !sort {
1114
f |= FunctionFlags::TERMINATES_INPUT_ORDER | FunctionFlags::NON_ORDER_PRODUCING
1115
}
1116
f | FunctionFlags::PASS_NAME_TO_APPLY | FunctionFlags::NON_ORDER_OBSERVING
1117
}),
1118
#[cfg(feature = "unique_counts")]
1119
F::UniqueCounts => FunctionOptions::groupwise(),
1120
#[cfg(feature = "approx_unique")]
1121
F::ApproxNUnique => {
1122
FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING)
1123
},
1124
F::Coalesce => FunctionOptions::elementwise()
1125
.with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION)
1126
.with_supertyping(Default::default()),
1127
#[cfg(feature = "diff")]
1128
F::Diff(NullBehavior::Drop) => FunctionOptions::groupwise(),
1129
#[cfg(feature = "diff")]
1130
F::Diff(NullBehavior::Ignore) => FunctionOptions::length_preserving(),
1131
#[cfg(feature = "pct_change")]
1132
F::PctChange => FunctionOptions::length_preserving(),
1133
#[cfg(feature = "interpolate")]
1134
F::Interpolate(_) => FunctionOptions::length_preserving(),
1135
#[cfg(feature = "interpolate_by")]
1136
F::InterpolateBy => FunctionOptions::length_preserving(),
1137
#[cfg(feature = "log")]
1138
F::Log | F::Log1p | F::Exp => FunctionOptions::elementwise(),
1139
#[cfg(feature = "log")]
1140
F::Entropy { .. } => {
1141
FunctionOptions::aggregation().flag(FunctionFlags::NON_ORDER_OBSERVING)
1142
},
1143
F::Unique(maintain_order) => FunctionOptions::groupwise().with_flags(|f| {
1144
let f = f | FunctionFlags::NON_ORDER_PRODUCING;
1145
1146
if !*maintain_order {
1147
f | FunctionFlags::NON_ORDER_OBSERVING | FunctionFlags::TERMINATES_INPUT_ORDER
1148
} else {
1149
f
1150
}
1151
}),
1152
#[cfg(feature = "round_series")]
1153
F::Round { .. } | F::RoundSF { .. } | F::Floor | F::Ceil => {
1154
FunctionOptions::elementwise()
1155
},
1156
#[cfg(feature = "fused")]
1157
F::Fused(_) => FunctionOptions::elementwise(),
1158
F::ConcatExpr(_) => FunctionOptions::groupwise()
1159
.with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION)
1160
.with_supertyping(Default::default()),
1161
#[cfg(feature = "cov")]
1162
F::Correlation { .. } => {
1163
FunctionOptions::aggregation().with_supertyping(Default::default())
1164
},
1165
#[cfg(feature = "peaks")]
1166
F::PeakMin | F::PeakMax => FunctionOptions::length_preserving(),
1167
#[cfg(feature = "cutqcut")]
1168
F::Cut { .. } | F::QCut { .. } => FunctionOptions::length_preserving()
1169
.with_flags(|f| f | FunctionFlags::PASS_NAME_TO_APPLY),
1170
#[cfg(feature = "rle")]
1171
F::RLE => FunctionOptions::groupwise(),
1172
#[cfg(feature = "rle")]
1173
F::RLEID => FunctionOptions::length_preserving(),
1174
F::ToPhysical => FunctionOptions::elementwise(),
1175
#[cfg(feature = "random")]
1176
F::Random {
1177
method: IRRandomMethod::Sample { .. },
1178
..
1179
} => FunctionOptions::groupwise(),
1180
#[cfg(feature = "random")]
1181
F::Random {
1182
method: IRRandomMethod::Shuffle,
1183
..
1184
} => FunctionOptions::length_preserving(),
1185
F::SetSortedFlag(_) => FunctionOptions::elementwise(),
1186
#[cfg(feature = "ffi_plugin")]
1187
F::FfiPlugin { flags, .. } => *flags,
1188
F::MaxHorizontal | F::MinHorizontal => FunctionOptions::elementwise().with_flags(|f| {
1189
f | FunctionFlags::INPUT_WILDCARD_EXPANSION | FunctionFlags::ALLOW_RENAME
1190
}),
1191
F::MeanHorizontal { .. } | F::SumHorizontal { .. } => FunctionOptions::elementwise()
1192
.with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION),
1193
1194
F::FoldHorizontal { returns_scalar, .. }
1195
| F::ReduceHorizontal { returns_scalar, .. } => FunctionOptions::groupwise()
1196
.with_flags(|mut f| {
1197
f |= FunctionFlags::INPUT_WILDCARD_EXPANSION;
1198
if *returns_scalar {
1199
f |= FunctionFlags::RETURNS_SCALAR;
1200
}
1201
f
1202
}),
1203
#[cfg(feature = "dtype-struct")]
1204
F::CumFoldHorizontal { returns_scalar, .. }
1205
| F::CumReduceHorizontal { returns_scalar, .. } => FunctionOptions::groupwise()
1206
.with_flags(|mut f| {
1207
f |= FunctionFlags::INPUT_WILDCARD_EXPANSION;
1208
if *returns_scalar {
1209
f |= FunctionFlags::RETURNS_SCALAR;
1210
}
1211
f
1212
}),
1213
#[cfg(feature = "ewma")]
1214
F::EwmMean { .. } | F::EwmStd { .. } | F::EwmVar { .. } => {
1215
FunctionOptions::length_preserving()
1216
},
1217
#[cfg(feature = "ewma_by")]
1218
F::EwmMeanBy { .. } => FunctionOptions::length_preserving(),
1219
#[cfg(feature = "replace")]
1220
F::Replace => FunctionOptions::elementwise(),
1221
#[cfg(feature = "replace")]
1222
F::ReplaceStrict { .. } => FunctionOptions::elementwise(),
1223
F::GatherEvery { .. } => FunctionOptions::groupwise(),
1224
#[cfg(feature = "reinterpret")]
1225
F::Reinterpret(_) => FunctionOptions::elementwise(),
1226
F::ExtendConstant => FunctionOptions::groupwise(),
1227
1228
F::RowEncode(..) => FunctionOptions::elementwise(),
1229
#[cfg(feature = "dtype-struct")]
1230
F::RowDecode(..) => FunctionOptions::elementwise(),
1231
}
1232
}
1233
}
1234
1235