Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/dispatch/mod.rs
7884 views
1
use std::sync::Arc;
2
3
use polars_core::error::PolarsResult;
4
use polars_core::frame::DataFrame;
5
use polars_core::prelude::{Column, GroupPositions};
6
use polars_plan::dsl::{ColumnsUdf, SpecialEq};
7
use polars_plan::plans::{IRBooleanFunction, IRFunctionExpr, IRPowFunction};
8
use polars_utils::IdxSize;
9
10
use crate::prelude::{AggregationContext, PhysicalExpr};
11
use crate::state::ExecutionState;
12
13
/// Wraps a callable as `SpecialEq::new(Arc::new(..))`.
///
/// * `wrap!(e)` wraps the expression `e` unchanged.
/// * `wrap!(e, args...)` first builds a `move` closure over `&mut [Column]` that
///   calls `e(columns, args...)`, then wraps that closure.
///
/// `SpecialEq` and `Arc` are written unqualified, so both names must be in scope
/// wherever the macro is invoked.
#[macro_export]
macro_rules! wrap {
    ($e:expr) => {
        SpecialEq::new(Arc::new($e))
    };

    ($e:expr, $($args:expr),*) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            $e(columns, $($args),*)
        };

        SpecialEq::new(Arc::new(udf))
    }};
}
27
28
/// Adapts a function that receives the whole input slice.
///
/// The generated closure takes `&mut [Column]` and forwards it as
/// `$func(columns)` or `$func(columns, args...)`:
/// * all expression arguments are in the slice.
/// * the first element is the root expression.
///
/// `SpecialEq` and `Arc` are written unqualified, so both names must be in scope
/// wherever the macro is invoked.
#[macro_export]
macro_rules! map_as_slice {
    ($func:path) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            $func(columns)
        };

        SpecialEq::new(Arc::new(udf))
    }};

    ($func:path, $($args:expr),*) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            $func(columns, $($args),*)
        };

        SpecialEq::new(Arc::new(udf))
    }};
}
49
50
/// Adapts a function that takes the root column by value.
///
/// * `$func(Column)`
/// * `$func(Column, args...)`
///
/// The generated closure takes `&mut [Column]` and moves the first element out
/// with `mem::take`, which leaves that element's `Default` value behind in the
/// slice. Any further elements of the slice are ignored.
///
/// Indexing element `0` panics if the slice is empty.
///
/// `SpecialEq` and `Arc` are written unqualified, so both names must be in scope
/// wherever the macro is invoked.
#[macro_export]
macro_rules! map_owned {
    ($func:path) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            // Fully qualified so the expansion does not depend on what `std`
            // resolves to at the call site.
            let root = ::std::mem::take(&mut columns[0]);
            $func(root)
        };

        SpecialEq::new(Arc::new(udf))
    }};

    ($func:path, $($args:expr),*) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            // Fully qualified so the expansion does not depend on what `std`
            // resolves to at the call site.
            let root = ::std::mem::take(&mut columns[0]);
            $func(root, $($args),*)
        };

        SpecialEq::new(Arc::new(udf))
    }};
}
72
73
/// Adapts a function that only borrows the root column.
///
/// * `$func(&Column)`
/// * `$func(&Column, args...)`
///
/// The generated closure takes `&mut [Column]` but passes on only a shared
/// reference to its first element; indexing element `0` panics if the slice is
/// empty.
///
/// `SpecialEq` and `Arc` are written unqualified, so both names must be in scope
/// wherever the macro is invoked.
#[macro_export]
macro_rules! map {
    ($func:path) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            $func(&columns[0])
        };

        SpecialEq::new(Arc::new(udf))
    }};

    ($func:path, $($args:expr),*) => {{
        let udf = move |columns: &mut [::polars_core::prelude::Column]| {
            $func(&columns[0], $($args),*)
        };

        SpecialEq::new(Arc::new(udf))
    }};
}
94
95
#[cfg(feature = "dtype-array")]
96
mod array;
97
mod binary;
98
#[cfg(feature = "bitwise")]
99
mod bitwise;
100
mod boolean;
101
#[cfg(feature = "business")]
102
mod business;
103
#[cfg(feature = "dtype-categorical")]
104
mod cat;
105
#[cfg(feature = "cum_agg")]
106
mod cum;
107
#[cfg(feature = "temporal")]
108
mod datetime;
109
#[cfg(feature = "dtype-extension")]
110
mod extension;
111
mod groups_dispatch;
112
mod horizontal;
113
mod list;
114
mod misc;
115
mod pow;
116
#[cfg(feature = "random")]
117
mod random;
118
#[cfg(feature = "range")]
119
mod range;
120
#[cfg(feature = "rolling_window")]
121
mod rolling;
122
#[cfg(feature = "rolling_window_by")]
123
mod rolling_by;
124
#[cfg(feature = "round_series")]
125
mod round;
126
mod shift_and_fill;
127
#[cfg(feature = "strings")]
128
mod strings;
129
#[cfg(feature = "dtype-struct")]
130
mod struct_;
131
#[cfg(feature = "temporal")]
132
mod temporal;
133
#[cfg(feature = "trigonometry")]
134
mod trigonometry;
135
136
pub use groups_dispatch::drop_items;
137
138
pub fn function_expr_to_udf(func: IRFunctionExpr) -> SpecialEq<Arc<dyn ColumnsUdf>> {
139
use IRFunctionExpr as F;
140
match func {
141
// Namespaces
142
#[cfg(feature = "dtype-array")]
143
F::ArrayExpr(func) => array::function_expr_to_udf(func),
144
F::BinaryExpr(func) => binary::function_expr_to_udf(func),
145
#[cfg(feature = "dtype-categorical")]
146
F::Categorical(func) => cat::function_expr_to_udf(func),
147
#[cfg(feature = "dtype-extension")]
148
F::Extension(func) => extension::function_expr_to_udf(func),
149
F::ListExpr(func) => list::function_expr_to_udf(func),
150
#[cfg(feature = "strings")]
151
F::StringExpr(func) => strings::function_expr_to_udf(func),
152
#[cfg(feature = "dtype-struct")]
153
F::StructExpr(func) => struct_::function_expr_to_udf(func),
154
#[cfg(feature = "temporal")]
155
F::TemporalExpr(func) => temporal::temporal_func_to_udf(func),
156
#[cfg(feature = "bitwise")]
157
F::Bitwise(func) => bitwise::function_expr_to_udf(func),
158
159
// Other expressions
160
F::Boolean(func) => boolean::function_expr_to_udf(func),
161
#[cfg(feature = "business")]
162
F::Business(func) => business::function_expr_to_udf(func),
163
#[cfg(feature = "abs")]
164
F::Abs => map!(misc::abs),
165
F::Negate => map!(misc::negate),
166
F::NullCount => {
167
let f = |s: &mut [Column]| {
168
let s = &s[0];
169
Ok(Column::new(s.name().clone(), [s.null_count() as IdxSize]))
170
};
171
wrap!(f)
172
},
173
F::Pow(func) => match func {
174
IRPowFunction::Generic => wrap!(pow::pow),
175
IRPowFunction::Sqrt => map!(pow::sqrt),
176
IRPowFunction::Cbrt => map!(pow::cbrt),
177
},
178
#[cfg(feature = "row_hash")]
179
F::Hash(k0, k1, k2, k3) => {
180
map!(misc::row_hash, k0, k1, k2, k3)
181
},
182
#[cfg(feature = "arg_where")]
183
F::ArgWhere => {
184
wrap!(misc::arg_where)
185
},
186
#[cfg(feature = "index_of")]
187
F::IndexOf => {
188
map_as_slice!(misc::index_of)
189
},
190
#[cfg(feature = "search_sorted")]
191
F::SearchSorted { side, descending } => {
192
map_as_slice!(misc::search_sorted_impl, side, descending)
193
},
194
#[cfg(feature = "range")]
195
F::Range(func) => range::function_expr_to_udf(func),
196
197
#[cfg(feature = "trigonometry")]
198
F::Trigonometry(trig_function) => {
199
map!(trigonometry::apply_trigonometric_function, trig_function)
200
},
201
#[cfg(feature = "trigonometry")]
202
F::Atan2 => {
203
wrap!(trigonometry::apply_arctan2)
204
},
205
206
#[cfg(feature = "sign")]
207
F::Sign => {
208
map!(misc::sign)
209
},
210
F::FillNull => {
211
map_as_slice!(misc::fill_null)
212
},
213
#[cfg(feature = "rolling_window")]
214
F::RollingExpr { function, options } => {
215
use IRRollingFunction::*;
216
use polars_plan::plans::IRRollingFunction;
217
match function {
218
Min => map!(rolling::rolling_min, options.clone()),
219
Max => map!(rolling::rolling_max, options.clone()),
220
Mean => map!(rolling::rolling_mean, options.clone()),
221
Sum => map!(rolling::rolling_sum, options.clone()),
222
Quantile => map!(rolling::rolling_quantile, options.clone()),
223
Var => map!(rolling::rolling_var, options.clone()),
224
Std => map!(rolling::rolling_std, options.clone()),
225
Rank => map!(rolling::rolling_rank, options.clone()),
226
#[cfg(feature = "moment")]
227
Skew => map!(rolling::rolling_skew, options.clone()),
228
#[cfg(feature = "moment")]
229
Kurtosis => map!(rolling::rolling_kurtosis, options.clone()),
230
#[cfg(feature = "cov")]
231
CorrCov {
232
corr_cov_options,
233
is_corr,
234
} => {
235
map_as_slice!(
236
rolling::rolling_corr_cov,
237
options.clone(),
238
corr_cov_options,
239
is_corr
240
)
241
},
242
Map(f) => {
243
map!(rolling::rolling_map, options.clone(), f.clone())
244
},
245
}
246
},
247
#[cfg(feature = "rolling_window_by")]
248
F::RollingExprBy {
249
function_by,
250
options,
251
} => {
252
use IRRollingFunctionBy::*;
253
use polars_plan::plans::IRRollingFunctionBy;
254
match function_by {
255
MinBy => map_as_slice!(rolling_by::rolling_min_by, options.clone()),
256
MaxBy => map_as_slice!(rolling_by::rolling_max_by, options.clone()),
257
MeanBy => map_as_slice!(rolling_by::rolling_mean_by, options.clone()),
258
SumBy => map_as_slice!(rolling_by::rolling_sum_by, options.clone()),
259
QuantileBy => {
260
map_as_slice!(rolling_by::rolling_quantile_by, options.clone())
261
},
262
VarBy => map_as_slice!(rolling_by::rolling_var_by, options.clone()),
263
StdBy => map_as_slice!(rolling_by::rolling_std_by, options.clone()),
264
RankBy => map_as_slice!(rolling_by::rolling_rank_by, options.clone()),
265
}
266
},
267
#[cfg(feature = "hist")]
268
F::Hist {
269
bin_count,
270
include_category,
271
include_breakpoint,
272
} => {
273
map_as_slice!(misc::hist, bin_count, include_category, include_breakpoint)
274
},
275
F::Rechunk => map!(misc::rechunk),
276
F::Append { upcast } => map_as_slice!(misc::append, upcast),
277
F::ShiftAndFill => {
278
map_as_slice!(shift_and_fill::shift_and_fill)
279
},
280
F::DropNans => map_owned!(misc::drop_nans),
281
F::DropNulls => map!(misc::drop_nulls),
282
#[cfg(feature = "round_series")]
283
F::Clip { has_min, has_max } => {
284
map_as_slice!(misc::clip, has_min, has_max)
285
},
286
#[cfg(feature = "mode")]
287
F::Mode { maintain_order } => map!(misc::mode, maintain_order),
288
#[cfg(feature = "moment")]
289
F::Skew(bias) => map!(misc::skew, bias),
290
#[cfg(feature = "moment")]
291
F::Kurtosis(fisher, bias) => map!(misc::kurtosis, fisher, bias),
292
F::ArgUnique => map!(misc::arg_unique),
293
F::ArgMin => map!(misc::arg_min),
294
F::ArgMax => map!(misc::arg_max),
295
F::ArgSort {
296
descending,
297
nulls_last,
298
} => map!(misc::arg_sort, descending, nulls_last),
299
F::Product => map!(misc::product),
300
F::Repeat => map_as_slice!(misc::repeat),
301
#[cfg(feature = "rank")]
302
F::Rank { options, seed } => map!(misc::rank, options, seed),
303
#[cfg(feature = "dtype-struct")]
304
F::AsStruct => {
305
map_as_slice!(misc::as_struct)
306
},
307
#[cfg(feature = "top_k")]
308
F::TopK { descending } => {
309
map_as_slice!(polars_ops::prelude::top_k, descending)
310
},
311
#[cfg(feature = "top_k")]
312
F::TopKBy { descending } => {
313
map_as_slice!(polars_ops::prelude::top_k_by, descending.clone())
314
},
315
F::Shift => map_as_slice!(shift_and_fill::shift),
316
#[cfg(feature = "cum_agg")]
317
F::CumCount { reverse } => map!(cum::cum_count, reverse),
318
#[cfg(feature = "cum_agg")]
319
F::CumSum { reverse } => map!(cum::cum_sum, reverse),
320
#[cfg(feature = "cum_agg")]
321
F::CumProd { reverse } => map!(cum::cum_prod, reverse),
322
#[cfg(feature = "cum_agg")]
323
F::CumMin { reverse } => map!(cum::cum_min, reverse),
324
#[cfg(feature = "cum_agg")]
325
F::CumMax { reverse } => map!(cum::cum_max, reverse),
326
#[cfg(feature = "dtype-struct")]
327
F::ValueCounts {
328
sort,
329
parallel,
330
name,
331
normalize,
332
} => map!(misc::value_counts, sort, parallel, name.clone(), normalize),
333
#[cfg(feature = "unique_counts")]
334
F::UniqueCounts => map!(misc::unique_counts),
335
F::Reverse => map!(misc::reverse),
336
#[cfg(feature = "approx_unique")]
337
F::ApproxNUnique => map!(misc::approx_n_unique),
338
F::Coalesce => map_as_slice!(misc::coalesce),
339
#[cfg(feature = "diff")]
340
F::Diff(null_behavior) => map_as_slice!(misc::diff, null_behavior),
341
#[cfg(feature = "pct_change")]
342
F::PctChange => map_as_slice!(misc::pct_change),
343
#[cfg(feature = "interpolate")]
344
F::Interpolate(method) => {
345
map!(misc::interpolate, method)
346
},
347
#[cfg(feature = "interpolate_by")]
348
F::InterpolateBy => {
349
map_as_slice!(misc::interpolate_by)
350
},
351
#[cfg(feature = "log")]
352
F::Entropy { base, normalize } => map!(misc::entropy, base, normalize),
353
#[cfg(feature = "log")]
354
F::Log => map_as_slice!(misc::log),
355
#[cfg(feature = "log")]
356
F::Log1p => map!(misc::log1p),
357
#[cfg(feature = "log")]
358
F::Exp => map!(misc::exp),
359
F::Unique(stable) => map!(misc::unique, stable),
360
#[cfg(feature = "round_series")]
361
F::Round { decimals, mode } => map!(round::round, decimals, mode),
362
#[cfg(feature = "round_series")]
363
F::RoundSF { digits } => map!(round::round_sig_figs, digits),
364
#[cfg(feature = "round_series")]
365
F::Floor => map!(round::floor),
366
#[cfg(feature = "round_series")]
367
F::Ceil => map!(round::ceil),
368
#[cfg(feature = "fused")]
369
F::Fused(op) => map_as_slice!(misc::fused, op),
370
F::ConcatExpr(rechunk) => map_as_slice!(misc::concat_expr, rechunk),
371
#[cfg(feature = "cov")]
372
F::Correlation { method } => map_as_slice!(misc::corr, method),
373
#[cfg(feature = "peaks")]
374
F::PeakMin => map!(misc::peak_min),
375
#[cfg(feature = "peaks")]
376
F::PeakMax => map!(misc::peak_max),
377
#[cfg(feature = "repeat_by")]
378
F::RepeatBy => map_as_slice!(misc::repeat_by),
379
#[cfg(feature = "dtype-array")]
380
F::Reshape(dims) => map!(misc::reshape, &dims),
381
#[cfg(feature = "cutqcut")]
382
F::Cut {
383
breaks,
384
labels,
385
left_closed,
386
include_breaks,
387
} => map!(
388
misc::cut,
389
breaks.clone(),
390
labels.clone(),
391
left_closed,
392
include_breaks
393
),
394
#[cfg(feature = "cutqcut")]
395
F::QCut {
396
probs,
397
labels,
398
left_closed,
399
allow_duplicates,
400
include_breaks,
401
} => map!(
402
misc::qcut,
403
probs.clone(),
404
labels.clone(),
405
left_closed,
406
allow_duplicates,
407
include_breaks
408
),
409
#[cfg(feature = "rle")]
410
F::RLE => map!(polars_ops::series::rle),
411
#[cfg(feature = "rle")]
412
F::RLEID => map!(polars_ops::series::rle_id),
413
F::ToPhysical => map!(misc::to_physical),
414
#[cfg(feature = "random")]
415
F::Random { method, seed } => {
416
use IRRandomMethod::*;
417
use polars_plan::plans::IRRandomMethod;
418
match method {
419
Shuffle => map!(random::shuffle, seed),
420
Sample {
421
is_fraction,
422
with_replacement,
423
shuffle,
424
} => {
425
if is_fraction {
426
map_as_slice!(random::sample_frac, with_replacement, shuffle, seed)
427
} else {
428
map_as_slice!(random::sample_n, with_replacement, shuffle, seed)
429
}
430
},
431
}
432
},
433
F::SetSortedFlag(sorted) => map!(misc::set_sorted_flag, sorted),
434
#[cfg(feature = "ffi_plugin")]
435
F::FfiPlugin {
436
flags: _,
437
lib,
438
symbol,
439
kwargs,
440
} => unsafe {
441
map_as_slice!(
442
polars_plan::plans::plugin::call_plugin,
443
lib.as_ref(),
444
symbol.as_ref(),
445
kwargs.as_ref()
446
)
447
},
448
449
F::FoldHorizontal {
450
callback,
451
returns_scalar,
452
return_dtype,
453
} => map_as_slice!(
454
horizontal::fold,
455
&callback,
456
returns_scalar,
457
return_dtype.as_ref()
458
),
459
F::ReduceHorizontal {
460
callback,
461
returns_scalar,
462
return_dtype,
463
} => map_as_slice!(
464
horizontal::reduce,
465
&callback,
466
returns_scalar,
467
return_dtype.as_ref()
468
),
469
#[cfg(feature = "dtype-struct")]
470
F::CumReduceHorizontal {
471
callback,
472
returns_scalar,
473
return_dtype,
474
} => map_as_slice!(
475
horizontal::cum_reduce,
476
&callback,
477
returns_scalar,
478
return_dtype.as_ref()
479
),
480
#[cfg(feature = "dtype-struct")]
481
F::CumFoldHorizontal {
482
callback,
483
returns_scalar,
484
return_dtype,
485
include_init,
486
} => map_as_slice!(
487
horizontal::cum_fold,
488
&callback,
489
returns_scalar,
490
return_dtype.as_ref(),
491
include_init
492
),
493
494
F::MaxHorizontal => wrap!(misc::max_horizontal),
495
F::MinHorizontal => wrap!(misc::min_horizontal),
496
F::SumHorizontal { ignore_nulls } => wrap!(misc::sum_horizontal, ignore_nulls),
497
F::MeanHorizontal { ignore_nulls } => wrap!(misc::mean_horizontal, ignore_nulls),
498
#[cfg(feature = "ewma")]
499
F::EwmMean { options } => map!(misc::ewm_mean, options),
500
#[cfg(feature = "ewma_by")]
501
F::EwmMeanBy { half_life } => map_as_slice!(misc::ewm_mean_by, half_life),
502
#[cfg(feature = "ewma")]
503
F::EwmStd { options } => map!(misc::ewm_std, options),
504
#[cfg(feature = "ewma")]
505
F::EwmVar { options } => map!(misc::ewm_var, options),
506
#[cfg(feature = "replace")]
507
F::Replace => {
508
map_as_slice!(misc::replace)
509
},
510
#[cfg(feature = "replace")]
511
F::ReplaceStrict { return_dtype } => {
512
map_as_slice!(misc::replace_strict, return_dtype.clone())
513
},
514
515
F::FillNullWithStrategy(strategy) => map!(misc::fill_null_with_strategy, strategy),
516
F::GatherEvery { n, offset } => map!(misc::gather_every, n, offset),
517
#[cfg(feature = "reinterpret")]
518
F::Reinterpret(signed) => map!(misc::reinterpret, signed),
519
F::ExtendConstant => map_as_slice!(misc::extend_constant),
520
521
F::RowEncode(dts, variants) => {
522
map_as_slice!(misc::row_encode, dts.clone(), variants.clone())
523
},
524
#[cfg(feature = "dtype-struct")]
525
F::RowDecode(fs, variants) => {
526
map_as_slice!(misc::row_decode, fs.clone(), variants.clone())
527
},
528
}
529
}
530
531
/// A function that is evaluated directly on grouped data.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait GroupsUdf: Send + Sync + 'static {
532
/// Evaluates the function for `groups` and returns the resulting
/// [`AggregationContext`], which may borrow from `groups` for `'a`.
///
/// * `inputs` - physical input expressions (presumably the arguments of the
///   function call; confirm against the implementors).
/// * `df` - the data frame passed along to the implementation.
/// * `groups` - the group positions; its lifetime `'a` bounds the result.
/// * `state` - the execution state passed along to the implementation.
fn evaluate_on_groups<'a>(
533
&self,
534
inputs: &[Arc<dyn PhysicalExpr>],
535
df: &DataFrame,
536
groups: &'a GroupPositions,
537
state: &ExecutionState,
538
) -> PolarsResult<AggregationContext<'a>>;
539
}
540
541
pub fn function_expr_to_groups_udf(func: &IRFunctionExpr) -> Option<SpecialEq<Arc<dyn GroupsUdf>>> {
542
macro_rules! wrap_groups {
543
($f:expr$(, ($arg:expr, $n:ident:$ty:ty))*) => {{
544
struct Wrap($($ty),*);
545
impl GroupsUdf for Wrap {
546
fn evaluate_on_groups<'a>(
547
&self,
548
inputs: &[Arc<dyn PhysicalExpr>],
549
df: &DataFrame,
550
groups: &'a GroupPositions,
551
state: &ExecutionState,
552
) -> PolarsResult<AggregationContext<'a>> {
553
let Wrap($($n),*) = self;
554
$f(inputs, df, groups, state$(, *$n)*)
555
}
556
}
557
558
SpecialEq::new(Arc::new(Wrap($($arg),*)) as Arc<dyn GroupsUdf>)
559
}};
560
}
561
use IRFunctionExpr as F;
562
Some(match func {
563
F::NullCount => wrap_groups!(groups_dispatch::null_count),
564
F::Reverse => wrap_groups!(groups_dispatch::reverse),
565
F::Boolean(IRBooleanFunction::Any { ignore_nulls }) => {
566
let ignore_nulls = *ignore_nulls;
567
wrap_groups!(groups_dispatch::any, (ignore_nulls, v: bool))
568
},
569
F::Boolean(IRBooleanFunction::All { ignore_nulls }) => {
570
let ignore_nulls = *ignore_nulls;
571
wrap_groups!(groups_dispatch::all, (ignore_nulls, v: bool))
572
},
573
#[cfg(feature = "bitwise")]
574
F::Bitwise(f) => {
575
use polars_plan::plans::IRBitwiseFunction as B;
576
match f {
577
B::And => wrap_groups!(groups_dispatch::bitwise_and),
578
B::Or => wrap_groups!(groups_dispatch::bitwise_or),
579
B::Xor => wrap_groups!(groups_dispatch::bitwise_xor),
580
_ => return None,
581
}
582
},
583
F::DropNans => wrap_groups!(groups_dispatch::drop_nans),
584
F::DropNulls => wrap_groups!(groups_dispatch::drop_nulls),
585
586
#[cfg(feature = "moment")]
587
F::Skew(bias) => wrap_groups!(groups_dispatch::skew, (*bias, v: bool)),
588
#[cfg(feature = "moment")]
589
F::Kurtosis(fisher, bias) => {
590
wrap_groups!(groups_dispatch::kurtosis, (*fisher, v1: bool), (*bias, v2: bool))
591
},
592
593
F::Unique(stable) => wrap_groups!(groups_dispatch::unique, (*stable, v: bool)),
594
F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Forward(limit)) => {
595
wrap_groups!(groups_dispatch::forward_fill_null, (*limit, v: Option<IdxSize>))
596
},
597
F::FillNullWithStrategy(polars_core::prelude::FillNullStrategy::Backward(limit)) => {
598
wrap_groups!(groups_dispatch::backward_fill_null, (*limit, v: Option<IdxSize>))
599
},
600
601
_ => return None,
602
})
603
}
604
605