// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-expr/src/dispatch/mod.rs
1
use std::sync::Arc;
2
3
use polars_core::error::PolarsResult;
4
use polars_core::frame::DataFrame;
5
use polars_core::prelude::{Column, GroupPositions};
6
use polars_plan::dsl::{ColumnsUdf, SpecialEq};
7
use polars_plan::plans::{IRBooleanFunction, IRFunctionExpr, IRPowFunction};
8
use polars_utils::IdxSize;
9
10
use crate::prelude::{AggregationContext, PhysicalExpr};
11
use crate::state::ExecutionState;
12
13
/// Wraps a callable in `SpecialEq::new(Arc::new(..))`.
///
/// * `wrap!(e)` wraps the expression `e` as-is.
/// * `wrap!(e, args...)` builds a `move` closure over `&mut [Column]` that
///   calls `e(s, args...)` and wraps that closure instead. The argument
///   expressions are placed inside the closure body, so they are evaluated on
///   every invocation of the closure.
///
/// `SpecialEq` and `Arc` are referenced by their bare names and therefore must
/// be in scope at the call site.
#[macro_export]
macro_rules! wrap {
    ($e:expr) => {
        SpecialEq::new(Arc::new($e))
    };

    ($e:expr, $($args:expr),*) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            $e(s, $($args),*)
        };

        SpecialEq::new(Arc::new(f))
    }};
}
27
28
/// `Fn(&[Column], args)`
/// * all expression arguments are in the slice.
/// * the first element is the root expression.
///
/// The generated closure receives `&mut [Column]` and forwards the whole slice
/// to `$func`, followed by any extra `args`. The argument expressions live
/// inside the closure body, so they are evaluated on every invocation.
///
/// `SpecialEq` and `Arc` are referenced by their bare names and therefore must
/// be in scope at the call site.
#[macro_export]
macro_rules! map_as_slice {
    ($func:path) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            $func(s)
        };

        SpecialEq::new(Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |s: &mut [::polars_core::prelude::Column]| {
            $func(s, $($args),*)
        };

        SpecialEq::new(Arc::new(f))
    }};
}
49
50
/// * `FnOnce(Column)`
/// * `FnOnce(Column, args)`
///
/// The generated closure moves the first column out of the input slice with
/// `std::mem::take` (leaving a default `Column` in its place) and passes it to
/// `$func` by value, followed by any extra `args`.
///
/// Indexing `c[0]` panics if the slice is empty.
///
/// `SpecialEq` and `Arc` are referenced by their bare names and therefore must
/// be in scope at the call site.
#[macro_export]
macro_rules! map_owned {
    ($func:path) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = std::mem::take(&mut c[0]);
            $func(c)
        };

        SpecialEq::new(Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = std::mem::take(&mut c[0]);
            $func(c, $($args),*)
        };

        SpecialEq::new(Arc::new(f))
    }};
}
72
73
/// `Fn(&Column, args)`
///
/// The generated closure borrows the first column of the input slice and
/// passes that reference to `$func`, followed by any extra `args`. The
/// argument expressions live inside the closure body, so they are evaluated on
/// every invocation.
///
/// Indexing `c[0]` panics if the slice is empty.
///
/// `SpecialEq` and `Arc` are referenced by their bare names and therefore must
/// be in scope at the call site.
#[macro_export]
macro_rules! map {
    ($func:path) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = &c[0];
            $func(c)
        };

        SpecialEq::new(Arc::new(f))
    }};

    ($func:path, $($args:expr),*) => {{
        let f = move |c: &mut [::polars_core::prelude::Column]| {
            let c = &c[0];
            $func(c, $($args),*)
        };

        SpecialEq::new(Arc::new(f))
    }};
}
94
95
#[cfg(feature = "dtype-array")]
96
mod array;
97
mod binary;
98
#[cfg(feature = "bitwise")]
99
mod bitwise;
100
mod boolean;
101
#[cfg(feature = "business")]
102
mod business;
103
#[cfg(feature = "dtype-categorical")]
104
mod cat;
105
#[cfg(feature = "cum_agg")]
106
mod cum;
107
#[cfg(feature = "temporal")]
108
mod datetime;
109
#[cfg(feature = "dtype-extension")]
110
mod extension;
111
mod groups_dispatch;
112
mod horizontal;
113
mod list;
114
mod misc;
115
mod pow;
116
#[cfg(feature = "random")]
117
mod random;
118
#[cfg(feature = "range")]
119
mod range;
120
#[cfg(feature = "rolling_window")]
121
mod rolling;
122
#[cfg(feature = "rolling_window_by")]
123
mod rolling_by;
124
#[cfg(feature = "round_series")]
125
mod round;
126
mod shift_and_fill;
127
#[cfg(feature = "strings")]
128
mod strings;
129
#[cfg(feature = "dtype-struct")]
130
pub(crate) mod struct_;
131
#[cfg(feature = "temporal")]
132
mod temporal;
133
#[cfg(feature = "trigonometry")]
134
mod trigonometry;
135
136
pub use groups_dispatch::drop_items;
137
138
pub fn function_expr_to_udf(func: IRFunctionExpr) -> SpecialEq<Arc<dyn ColumnsUdf>> {
139
use IRFunctionExpr as F;
140
match func {
141
// Namespaces
142
#[cfg(feature = "dtype-array")]
143
F::ArrayExpr(func) => array::function_expr_to_udf(func),
144
F::BinaryExpr(func) => binary::function_expr_to_udf(func),
145
#[cfg(feature = "dtype-categorical")]
146
F::Categorical(func) => cat::function_expr_to_udf(func),
147
#[cfg(feature = "dtype-extension")]
148
F::Extension(func) => extension::function_expr_to_udf(func),
149
F::ListExpr(func) => list::function_expr_to_udf(func),
150
#[cfg(feature = "strings")]
151
F::StringExpr(func) => strings::function_expr_to_udf(func),
152
#[cfg(feature = "dtype-struct")]
153
F::StructExpr(func) => struct_::function_expr_to_udf(func),
154
#[cfg(feature = "temporal")]
155
F::TemporalExpr(func) => temporal::temporal_func_to_udf(func),
156
#[cfg(feature = "bitwise")]
157
F::Bitwise(func) => bitwise::function_expr_to_udf(func),
158
159
// Other expressions
160
F::Boolean(func) => boolean::function_expr_to_udf(func),
161
#[cfg(feature = "business")]
162
F::Business(func) => business::function_expr_to_udf(func),
163
#[cfg(feature = "abs")]
164
F::Abs => map!(misc::abs),
165
F::Negate => map!(misc::negate),
166
F::NullCount => {
167
let f = |s: &mut [Column]| {
168
let s = &s[0];
169
Ok(Column::new(s.name().clone(), [s.null_count() as IdxSize]))
170
};
171
wrap!(f)
172
},
173
F::Pow(func) => match func {
174
IRPowFunction::Generic => wrap!(pow::pow),
175
IRPowFunction::Sqrt => map!(pow::sqrt),
176
IRPowFunction::Cbrt => map!(pow::cbrt),
177
},
178
#[cfg(feature = "row_hash")]
179
F::Hash(k0, k1, k2, k3) => {
180
map!(misc::row_hash, k0, k1, k2, k3)
181
},
182
#[cfg(feature = "arg_where")]
183
F::ArgWhere => {
184
wrap!(misc::arg_where)
185
},
186
#[cfg(feature = "index_of")]
187
F::IndexOf => {
188
map_as_slice!(misc::index_of)
189
},
190
#[cfg(feature = "search_sorted")]
191
F::SearchSorted { side, descending } => {
192
map_as_slice!(misc::search_sorted_impl, side, descending)
193
},
194
#[cfg(feature = "range")]
195
F::Range(func) => range::function_expr_to_udf(func),
196
197
#[cfg(feature = "trigonometry")]
198
F::Trigonometry(trig_function) => {
199
map!(trigonometry::apply_trigonometric_function, trig_function)
200
},
201
#[cfg(feature = "trigonometry")]
202
F::Atan2 => {
203
wrap!(trigonometry::apply_arctan2)
204
},
205
206
#[cfg(feature = "sign")]
207
F::Sign => {
208
map!(misc::sign)
209
},
210
F::FillNull => {
211
map_as_slice!(misc::fill_null)
212
},
213
#[cfg(feature = "rolling_window")]
214
F::RollingExpr { function, options } => {
215
use IRRollingFunction::*;
216
use polars_plan::plans::IRRollingFunction;
217
match function {
218
Min => map!(rolling::rolling_min, options.clone()),
219
Max => map!(rolling::rolling_max, options.clone()),
220
Mean => map!(rolling::rolling_mean, options.clone()),
221
Sum => map!(rolling::rolling_sum, options.clone()),
222
Quantile => map!(rolling::rolling_quantile, options.clone()),
223
Var => map!(rolling::rolling_var, options.clone()),
224
Std => map!(rolling::rolling_std, options.clone()),
225
Rank => map!(rolling::rolling_rank, options.clone()),
226
#[cfg(feature = "moment")]
227
Skew => map!(rolling::rolling_skew, options.clone()),
228
#[cfg(feature = "moment")]
229
Kurtosis => map!(rolling::rolling_kurtosis, options.clone()),
230
#[cfg(feature = "cov")]
231
CorrCov {
232
corr_cov_options,
233
is_corr,
234
} => {
235
map_as_slice!(
236
rolling::rolling_corr_cov,
237
options.clone(),
238
corr_cov_options,
239
is_corr
240
)
241
},
242
Map(f) => {
243
map!(rolling::rolling_map, options.clone(), f.clone())
244
},
245
}
246
},
247
#[cfg(feature = "rolling_window_by")]
248
F::RollingExprBy {
249
function_by,
250
options,
251
} => {
252
use IRRollingFunctionBy::*;
253
use polars_plan::plans::IRRollingFunctionBy;
254
match function_by {
255
MinBy => map_as_slice!(rolling_by::rolling_min_by, options.clone()),
256
MaxBy => map_as_slice!(rolling_by::rolling_max_by, options.clone()),
257
MeanBy => map_as_slice!(rolling_by::rolling_mean_by, options.clone()),
258
SumBy => map_as_slice!(rolling_by::rolling_sum_by, options.clone()),
259
QuantileBy => {
260
map_as_slice!(rolling_by::rolling_quantile_by, options.clone())
261
},
262
VarBy => map_as_slice!(rolling_by::rolling_var_by, options.clone()),
263
StdBy => map_as_slice!(rolling_by::rolling_std_by, options.clone()),
264
RankBy => map_as_slice!(rolling_by::rolling_rank_by, options.clone()),
265
}
266
},
267
#[cfg(feature = "hist")]
268
F::Hist {
269
bin_count,
270
include_category,
271
include_breakpoint,
272
} => {
273
map_as_slice!(misc::hist, bin_count, include_category, include_breakpoint)
274
},
275
F::Rechunk => map!(misc::rechunk),
276
F::Append { upcast } => map_as_slice!(misc::append, upcast),
277
F::ShiftAndFill => {
278
map_as_slice!(shift_and_fill::shift_and_fill)
279
},
280
F::DropNans => map_owned!(misc::drop_nans),
281
F::DropNulls => map!(misc::drop_nulls),
282
#[cfg(feature = "round_series")]
283
F::Clip { has_min, has_max } => {
284
map_as_slice!(misc::clip, has_min, has_max)
285
},
286
#[cfg(feature = "mode")]
287
F::Mode { maintain_order } => map!(misc::mode, maintain_order),
288
#[cfg(feature = "moment")]
289
F::Skew(bias) => map!(misc::skew, bias),
290
#[cfg(feature = "moment")]
291
F::Kurtosis(fisher, bias) => map!(misc::kurtosis, fisher, bias),
292
F::ArgUnique => map!(misc::arg_unique),
293
F::ArgMin => map!(misc::arg_min),
294
F::ArgMax => map!(misc::arg_max),
295
F::ArgSort {
296
descending,
297
nulls_last,
298
} => map!(misc::arg_sort, descending, nulls_last),
299
F::MinBy => map_as_slice!(misc::min_by),
300
F::MaxBy => map_as_slice!(misc::max_by),
301
F::Product => map!(misc::product),
302
F::Repeat => map_as_slice!(misc::repeat),
303
#[cfg(feature = "rank")]
304
F::Rank { options, seed } => map!(misc::rank, options, seed),
305
#[cfg(feature = "dtype-struct")]
306
F::AsStruct => {
307
map_as_slice!(misc::as_struct)
308
},
309
#[cfg(feature = "top_k")]
310
F::TopK { descending } => {
311
map_as_slice!(polars_ops::prelude::top_k, descending)
312
},
313
#[cfg(feature = "top_k")]
314
F::TopKBy { descending } => {
315
map_as_slice!(polars_ops::prelude::top_k_by, descending.clone())
316
},
317
F::Shift => map_as_slice!(shift_and_fill::shift),
318
#[cfg(feature = "cum_agg")]
319
F::CumCount { reverse } => map!(cum::cum_count, reverse),
320
#[cfg(feature = "cum_agg")]
321
F::CumSum { reverse } => map!(cum::cum_sum, reverse),
322
#[cfg(feature = "cum_agg")]
323
F::CumProd { reverse } => map!(cum::cum_prod, reverse),
324
#[cfg(feature = "cum_agg")]
325
F::CumMin { reverse } => map!(cum::cum_min, reverse),
326
#[cfg(feature = "cum_agg")]
327
F::CumMax { reverse } => map!(cum::cum_max, reverse),
328
#[cfg(feature = "dtype-struct")]
329
F::ValueCounts {
330
sort,
331
parallel,
332
name,
333
normalize,
334
} => map!(misc::value_counts, sort, parallel, name.clone(), normalize),
335
#[cfg(feature = "unique_counts")]
336
F::UniqueCounts => map!(misc::unique_counts),
337
F::Reverse => map!(misc::reverse),
338
#[cfg(feature = "approx_unique")]
339
F::ApproxNUnique => map!(misc::approx_n_unique),
340
F::Coalesce => map_as_slice!(misc::coalesce),
341
#[cfg(feature = "diff")]
342
F::Diff(null_behavior) => map_as_slice!(misc::diff, null_behavior),
343
#[cfg(feature = "pct_change")]
344
F::PctChange => map_as_slice!(misc::pct_change),
345
#[cfg(feature = "interpolate")]
346
F::Interpolate(method) => {
347
map!(misc::interpolate, method)
348
},
349
#[cfg(feature = "interpolate_by")]
350
F::InterpolateBy => {
351
map_as_slice!(misc::interpolate_by)
352
},
353
#[cfg(feature = "log")]
354
F::Entropy { base, normalize } => map!(misc::entropy, base, normalize),
355
#[cfg(feature = "log")]
356
F::Log => map_as_slice!(misc::log),
357
#[cfg(feature = "log")]
358
F::Log1p => map!(misc::log1p),
359
#[cfg(feature = "log")]
360
F::Exp => map!(misc::exp),
361
F::Unique(stable) => map!(misc::unique, stable),
362
#[cfg(feature = "round_series")]
363
F::Round { decimals, mode } => map!(round::round, decimals, mode),
364
#[cfg(feature = "round_series")]
365
F::RoundSF { digits } => map!(round::round_sig_figs, digits),
366
#[cfg(feature = "round_series")]
367
F::Floor => map!(round::floor),
368
#[cfg(feature = "round_series")]
369
F::Ceil => map!(round::ceil),
370
#[cfg(feature = "fused")]
371
F::Fused(op) => map_as_slice!(misc::fused, op),
372
F::ConcatExpr(rechunk) => map_as_slice!(misc::concat_expr, rechunk),
373
#[cfg(feature = "cov")]
374
F::Correlation { method } => map_as_slice!(misc::corr, method),
375
#[cfg(feature = "peaks")]
376
F::PeakMin => map!(misc::peak_min),
377
#[cfg(feature = "peaks")]
378
F::PeakMax => map!(misc::peak_max),
379
#[cfg(feature = "repeat_by")]
380
F::RepeatBy => map_as_slice!(misc::repeat_by),
381
#[cfg(feature = "dtype-array")]
382
F::Reshape(dims) => map!(misc::reshape, &dims),
383
#[cfg(feature = "cutqcut")]
384
F::Cut {
385
breaks,
386
labels,
387
left_closed,
388
include_breaks,
389
} => map!(
390
misc::cut,
391
breaks.clone(),
392
labels.clone(),
393
left_closed,
394
include_breaks
395
),
396
#[cfg(feature = "cutqcut")]
397
F::QCut {
398
probs,
399
labels,
400
left_closed,
401
allow_duplicates,
402
include_breaks,
403
} => map!(
404
misc::qcut,
405
probs.clone(),
406
labels.clone(),
407
left_closed,
408
allow_duplicates,
409
include_breaks
410
),
411
#[cfg(feature = "rle")]
412
F::RLE => map!(polars_ops::series::rle),
413
#[cfg(feature = "rle")]
414
F::RLEID => map!(polars_ops::series::rle_id),
415
F::ToPhysical => map!(misc::to_physical),
416
#[cfg(feature = "random")]
417
F::Random { method, seed } => {
418
use IRRandomMethod::*;
419
use polars_plan::plans::IRRandomMethod;
420
match method {
421
Shuffle => map!(random::shuffle, seed),
422
Sample {
423
is_fraction,
424
with_replacement,
425
shuffle,
426
} => {
427
if is_fraction {
428
map_as_slice!(random::sample_frac, with_replacement, shuffle, seed)
429
} else {
430
map_as_slice!(random::sample_n, with_replacement, shuffle, seed)
431
}
432
},
433
}
434
},
435
F::SetSortedFlag(sorted) => map!(misc::set_sorted_flag, sorted),
436
#[cfg(feature = "ffi_plugin")]
437
F::FfiPlugin {
438
flags: _,
439
lib,
440
symbol,
441
kwargs,
442
} => unsafe {
443
map_as_slice!(
444
polars_plan::plans::plugin::call_plugin,
445
lib.as_ref(),
446
symbol.as_ref(),
447
kwargs.as_ref()
448
)
449
},
450
451
F::FoldHorizontal {
452
callback,
453
returns_scalar,
454
return_dtype,
455
} => map_as_slice!(
456
horizontal::fold,
457
&callback,
458
returns_scalar,
459
return_dtype.as_ref()
460
),
461
F::ReduceHorizontal {
462
callback,
463
returns_scalar,
464
return_dtype,
465
} => map_as_slice!(
466
horizontal::reduce,
467
&callback,
468
returns_scalar,
469
return_dtype.as_ref()
470
),
471
#[cfg(feature = "dtype-struct")]
472
F::CumReduceHorizontal {
473
callback,
474
returns_scalar,
475
return_dtype,
476
} => map_as_slice!(
477
horizontal::cum_reduce,
478
&callback,
479
returns_scalar,
480
return_dtype.as_ref()
481
),
482
#[cfg(feature = "dtype-struct")]
483
F::CumFoldHorizontal {
484
callback,
485
returns_scalar,
486
return_dtype,
487
include_init,
488
} => map_as_slice!(
489
horizontal::cum_fold,
490
&callback,
491
returns_scalar,
492
return_dtype.as_ref(),
493
include_init
494
),
495
496
F::MaxHorizontal => wrap!(misc::max_horizontal),
497
F::MinHorizontal => wrap!(misc::min_horizontal),
498
F::SumHorizontal { ignore_nulls } => wrap!(misc::sum_horizontal, ignore_nulls),
499
F::MeanHorizontal { ignore_nulls } => wrap!(misc::mean_horizontal, ignore_nulls),
500
#[cfg(feature = "ewma")]
501
F::EwmMean { options } => map!(misc::ewm_mean, options),
502
#[cfg(feature = "ewma_by")]
503
F::EwmMeanBy { half_life } => map_as_slice!(misc::ewm_mean_by, half_life),
504
#[cfg(feature = "ewma")]
505
F::EwmStd { options } => map!(misc::ewm_std, options),
506
#[cfg(feature = "ewma")]
507
F::EwmVar { options } => map!(misc::ewm_var, options),
508
#[cfg(feature = "replace")]
509
F::Replace => {
510
map_as_slice!(misc::replace)
511
},
512
#[cfg(feature = "replace")]
513
F::ReplaceStrict { return_dtype } => {
514
map_as_slice!(misc::replace_strict, return_dtype.clone())
515
},
516
517
F::FillNullWithStrategy(strategy) => map!(misc::fill_null_with_strategy, strategy),
518
F::GatherEvery { n, offset } => map!(misc::gather_every, n, offset),
519
#[cfg(feature = "reinterpret")]
520
F::Reinterpret(signed) => map!(misc::reinterpret, signed),
521
F::ExtendConstant => map_as_slice!(misc::extend_constant),
522
523
F::RowEncode(dts, variants) => {
524
map_as_slice!(misc::row_encode, dts.clone(), variants.clone())
525
},
526
#[cfg(feature = "dtype-struct")]
527
F::RowDecode(fs, variants) => {
528
map_as_slice!(misc::row_decode, fs.clone(), variants.clone())
529
},
530
}
531
}
532
533
pub trait GroupsUdf: Send + Sync + 'static {
534
fn evaluate_on_groups<'a>(
535
&self,
536
inputs: &[Arc<dyn PhysicalExpr>],
537
df: &DataFrame,
538
groups: &'a GroupPositions,
539
state: &ExecutionState,
540
) -> PolarsResult<AggregationContext<'a>>;
541
}
542
543
pub fn function_expr_to_groups_udf(func: &IRFunctionExpr) -> Option<SpecialEq<Arc<dyn GroupsUdf>>> {
544
macro_rules! wrap_groups {
545
($f:expr$(, ($arg:expr, $n:ident:$ty:ty))*) => {{
546
struct Wrap($($ty),*);
547
impl GroupsUdf for Wrap {
548
fn evaluate_on_groups<'a>(
549
&self,
550
inputs: &[Arc<dyn PhysicalExpr>],
551
df: &DataFrame,
552
groups: &'a GroupPositions,
553
state: &ExecutionState,
554
) -> PolarsResult<AggregationContext<'a>> {
555
let Wrap($($n),*) = self;
556
$f(inputs, df, groups, state$(, *$n)*)
557
}
558
}
559
560
SpecialEq::new(Arc::new(Wrap($($arg),*)) as Arc<dyn GroupsUdf>)
561
}};
562
}
563
use IRFunctionExpr as F;
564
Some(match func {
565
F::NullCount => wrap_groups!(groups_dispatch::null_count),
566
F::Reverse => wrap_groups!(groups_dispatch::reverse),
567
F::Boolean(IRBooleanFunction::Any { ignore_nulls }) => {
568
let ignore_nulls = *ignore_nulls;
569
wrap_groups!(groups_dispatch::any, (ignore_nulls, v: bool))
570
},
571
F::Boolean(IRBooleanFunction::All { ignore_nulls }) => {
572
let ignore_nulls = *ignore_nulls;
573
wrap_groups!(groups_dispatch::all, (ignore_nulls, v: bool))
574
},
575
#[cfg(feature = "bitwise")]
576
F::Bitwise(f) => {
577
use polars_plan::plans::IRBitwiseFunction as B;
578
match f {
579
B::And => wrap_groups!(groups_dispatch::bitwise_and),
580
B::Or => wrap_groups!(groups_dispatch::bitwise_or),
581
B::Xor => wrap_groups!(groups_dispatch::bitwise_xor),
582
_ => return None,
583
}
584
},
585
F::DropNans => wrap_groups!(groups_dispatch::drop_nans),
586
F::DropNulls => wrap_groups!(groups_dispatch::drop_nulls),
587
588
#[cfg(feature = "moment")]
589
F::Skew(bias) => wrap_groups!(groups_dispatch::skew, (*bias, v: bool)),
590
#[cfg(feature = "moment")]
591
F::Kurtosis(fisher, bias) => {
592
wrap_groups!(groups_dispatch::kurtosis, (*fisher, v1: bool), (*bias, v2: bool))
593
},
594
595
F::Unique(stable) => wrap_groups!(groups_dispatch::unique, (*stable, v: bool)),
596
F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Forward(limit)) => {
597
wrap_groups!(groups_dispatch::forward_fill_null, (*limit, v: Option<IdxSize>))
598
},
599
F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Backward(limit)) => {
600
wrap_groups!(groups_dispatch::backward_fill_null, (*limit, v: Option<IdxSize>))
601
},
602
603
_ => return None,
604
})
605
}
606
607