Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/reduce/convert.rs
6940 views
1
// use polars_core::error::feature_gated;
2
use polars_plan::prelude::*;
3
use polars_utils::arena::{Arena, Node};
4
5
use super::*;
6
use crate::reduce::any_all::{new_all_reduction, new_any_reduction};
7
#[cfg(feature = "bitwise")]
8
use crate::reduce::bitwise::{
9
new_bitwise_and_reduction, new_bitwise_or_reduction, new_bitwise_xor_reduction,
10
};
11
use crate::reduce::count::CountReduce;
12
use crate::reduce::first_last::{new_first_reduction, new_last_reduction};
13
use crate::reduce::len::LenReduce;
14
use crate::reduce::mean::new_mean_reduction;
15
use crate::reduce::min_max::{new_max_reduction, new_min_reduction};
16
use crate::reduce::sum::new_sum_reduction;
17
use crate::reduce::var_std::new_var_std_reduction;
18
19
/// Converts a node into a reduction + its associated selector expression.
20
pub fn into_reduction(
21
node: Node,
22
expr_arena: &mut Arena<AExpr>,
23
schema: &Schema,
24
) -> PolarsResult<(Box<dyn GroupedReduction>, Node)> {
25
let get_dt = |node| {
26
expr_arena
27
.get(node)
28
.to_dtype(schema, expr_arena)?
29
.materialize_unknown(false)
30
};
31
let out = match expr_arena.get(node) {
32
AExpr::Agg(agg) => match agg {
33
IRAggExpr::Sum(input) => (new_sum_reduction(get_dt(*input)?), *input),
34
IRAggExpr::Mean(input) => (new_mean_reduction(get_dt(*input)?), *input),
35
IRAggExpr::Min {
36
propagate_nans,
37
input,
38
} => (new_min_reduction(get_dt(*input)?, *propagate_nans), *input),
39
IRAggExpr::Max {
40
propagate_nans,
41
input,
42
} => (new_max_reduction(get_dt(*input)?, *propagate_nans), *input),
43
IRAggExpr::Var(input, ddof) => {
44
(new_var_std_reduction(get_dt(*input)?, false, *ddof), *input)
45
},
46
IRAggExpr::Std(input, ddof) => {
47
(new_var_std_reduction(get_dt(*input)?, true, *ddof), *input)
48
},
49
IRAggExpr::First(input) => (new_first_reduction(get_dt(*input)?), *input),
50
IRAggExpr::Last(input) => (new_last_reduction(get_dt(*input)?), *input),
51
IRAggExpr::Count {
52
input,
53
include_nulls,
54
} => {
55
let count = Box::new(CountReduce::new(*include_nulls)) as Box<_>;
56
(count, *input)
57
},
58
IRAggExpr::Quantile { .. } => todo!(),
59
IRAggExpr::Median(_) => todo!(),
60
IRAggExpr::NUnique(_) => todo!(),
61
IRAggExpr::Implode(_) => todo!(),
62
IRAggExpr::AggGroups(_) => todo!(),
63
},
64
AExpr::Len => {
65
if let Some(first_column) = schema.iter_names().next() {
66
let out: Box<dyn GroupedReduction> = Box::new(LenReduce::default());
67
let expr = expr_arena.add(AExpr::Column(first_column.as_str().into()));
68
69
(out, expr)
70
} else {
71
// Support len aggregation on 0-width morsels.
72
// Notes:
73
// * We do this instead of projecting a scalar, because scalar literals don't
74
// project to the height of the DataFrame (in the PhysicalExpr impl).
75
// * This approach is not sound for `update_groups()`, but currently that case is
76
// not hit (it would need group-by -> len on empty morsels).
77
let out: Box<dyn GroupedReduction> = new_sum_reduction(DataType::IDX_DTYPE);
78
let expr = expr_arena.add(AExpr::Len);
79
80
(out, expr)
81
}
82
},
83
#[cfg(feature = "bitwise")]
84
AExpr::Function {
85
input: inner_exprs,
86
function: IRFunctionExpr::Bitwise(inner_fn),
87
options: _,
88
} => {
89
assert!(inner_exprs.len() == 1);
90
let input = inner_exprs[0].node();
91
match inner_fn {
92
IRBitwiseFunction::And => (new_bitwise_and_reduction(get_dt(input)?), input),
93
IRBitwiseFunction::Or => (new_bitwise_or_reduction(get_dt(input)?), input),
94
IRBitwiseFunction::Xor => (new_bitwise_xor_reduction(get_dt(input)?), input),
95
_ => unreachable!(),
96
}
97
},
98
99
AExpr::Function {
100
input: inner_exprs,
101
function: IRFunctionExpr::Boolean(inner_fn),
102
options: _,
103
} => {
104
assert!(inner_exprs.len() == 1);
105
let input = inner_exprs[0].node();
106
match inner_fn {
107
IRBooleanFunction::Any { ignore_nulls } => {
108
(new_any_reduction(*ignore_nulls), input)
109
},
110
IRBooleanFunction::All { ignore_nulls } => {
111
(new_all_reduction(*ignore_nulls), input)
112
},
113
_ => unreachable!(),
114
}
115
},
116
_ => unreachable!(),
117
};
118
Ok(out)
119
}
120
121