Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/mod.rs
6940 views
1
mod builder;
2
mod equality;
3
mod evaluate;
4
mod function_expr;
5
#[cfg(feature = "cse")]
6
mod hash;
7
mod minterm_iter;
8
pub mod predicates;
9
mod scalar;
10
mod schema;
11
mod traverse;
12
13
use std::hash::{Hash, Hasher};
14
15
pub use function_expr::*;
16
#[cfg(feature = "cse")]
17
pub(super) use hash::traverse_and_hash_aexpr;
18
pub use minterm_iter::MintermIter;
19
use polars_compute::rolling::QuantileMethod;
20
use polars_core::chunked_array::cast::CastOptions;
21
use polars_core::prelude::*;
22
use polars_core::utils::{get_time_units, try_get_supertype};
23
use polars_utils::arena::{Arena, Node};
24
pub use scalar::is_scalar_ae;
25
use strum_macros::IntoStaticStr;
26
pub use traverse::*;
27
mod properties;
28
pub use aexpr::function_expr::schema::FieldsMapper;
29
pub use builder::AExprBuilder;
30
pub use properties::*;
31
32
use crate::constants::LEN;
33
use crate::plans::Context;
34
use crate::prelude::*;
35
36
#[derive(Clone, Debug, IntoStaticStr)]
37
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
38
pub enum IRAggExpr {
39
Min {
40
input: Node,
41
propagate_nans: bool,
42
},
43
Max {
44
input: Node,
45
propagate_nans: bool,
46
},
47
Median(Node),
48
NUnique(Node),
49
First(Node),
50
Last(Node),
51
Mean(Node),
52
Implode(Node),
53
Quantile {
54
expr: Node,
55
quantile: Node,
56
method: QuantileMethod,
57
},
58
Sum(Node),
59
Count {
60
input: Node,
61
include_nulls: bool,
62
},
63
Std(Node, u8),
64
Var(Node, u8),
65
AggGroups(Node),
66
}
67
68
impl Hash for IRAggExpr {
69
fn hash<H: Hasher>(&self, state: &mut H) {
70
std::mem::discriminant(self).hash(state);
71
match self {
72
Self::Min {
73
input: _,
74
propagate_nans,
75
}
76
| Self::Max {
77
input: _,
78
propagate_nans,
79
} => propagate_nans.hash(state),
80
Self::Quantile {
81
method: interpol, ..
82
} => interpol.hash(state),
83
Self::Std(_, v) | Self::Var(_, v) => v.hash(state),
84
Self::Count {
85
input: _,
86
include_nulls,
87
} => include_nulls.hash(state),
88
_ => {},
89
}
90
}
91
}
92
93
impl IRAggExpr {
94
pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool {
95
use IRAggExpr::*;
96
match (self, other) {
97
(
98
Min {
99
propagate_nans: l, ..
100
},
101
Min {
102
propagate_nans: r, ..
103
},
104
) => l == r,
105
(
106
Max {
107
propagate_nans: l, ..
108
},
109
Max {
110
propagate_nans: r, ..
111
},
112
) => l == r,
113
(Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r,
114
(Std(_, l), Std(_, r)) => l == r,
115
(Var(_, l), Var(_, r)) => l == r,
116
_ => std::mem::discriminant(self) == std::mem::discriminant(other),
117
}
118
}
119
}
120
121
impl From<IRAggExpr> for GroupByMethod {
122
fn from(value: IRAggExpr) -> Self {
123
use IRAggExpr::*;
124
match value {
125
Min {
126
input: _,
127
propagate_nans,
128
} => {
129
if propagate_nans {
130
GroupByMethod::NanMin
131
} else {
132
GroupByMethod::Min
133
}
134
},
135
Max {
136
input: _,
137
propagate_nans,
138
} => {
139
if propagate_nans {
140
GroupByMethod::NanMax
141
} else {
142
GroupByMethod::Max
143
}
144
},
145
Median(_) => GroupByMethod::Median,
146
NUnique(_) => GroupByMethod::NUnique,
147
First(_) => GroupByMethod::First,
148
Last(_) => GroupByMethod::Last,
149
Mean(_) => GroupByMethod::Mean,
150
Implode(_) => GroupByMethod::Implode,
151
Sum(_) => GroupByMethod::Sum,
152
Count {
153
input: _,
154
include_nulls,
155
} => GroupByMethod::Count { include_nulls },
156
Std(_, ddof) => GroupByMethod::Std(ddof),
157
Var(_, ddof) => GroupByMethod::Var(ddof),
158
AggGroups(_) => GroupByMethod::Groups,
159
Quantile { .. } => unreachable!(),
160
}
161
}
162
}
163
164
/// IR expression node that is allocated in an [`Arena`][polars_utils::arena::Arena].
165
#[derive(Clone, Debug, Default)]
166
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
167
pub enum AExpr {
168
Explode {
169
expr: Node,
170
skip_empty: bool,
171
},
172
Column(PlSmallStr),
173
Literal(LiteralValue),
174
BinaryExpr {
175
left: Node,
176
op: Operator,
177
right: Node,
178
},
179
Cast {
180
expr: Node,
181
dtype: DataType,
182
options: CastOptions,
183
},
184
Sort {
185
expr: Node,
186
options: SortOptions,
187
},
188
Gather {
189
expr: Node,
190
idx: Node,
191
returns_scalar: bool,
192
},
193
SortBy {
194
expr: Node,
195
by: Vec<Node>,
196
sort_options: SortMultipleOptions,
197
},
198
Filter {
199
input: Node,
200
by: Node,
201
},
202
Agg(IRAggExpr),
203
Ternary {
204
predicate: Node,
205
truthy: Node,
206
falsy: Node,
207
},
208
AnonymousFunction {
209
input: Vec<ExprIR>,
210
function: OpaqueColumnUdf,
211
options: FunctionOptions,
212
fmt_str: Box<PlSmallStr>,
213
},
214
/// Evaluates the `evaluation` expression on the output of the `expr`.
215
///
216
/// Consequently, `expr` is an input and `evaluation` is not and needs a different schema.
217
Eval {
218
expr: Node,
219
220
/// An expression that is guaranteed to not contain any column reference beyond
221
/// `pl.element()` which refers to `pl.col("")`.
222
evaluation: Node,
223
224
variant: EvalVariant,
225
},
226
Function {
227
/// Function arguments
228
/// Some functions rely on aliases,
229
/// for instance assignment of struct fields.
230
/// Therefor we need [`ExprIr`].
231
input: Vec<ExprIR>,
232
/// function to apply
233
function: IRFunctionExpr,
234
options: FunctionOptions,
235
},
236
Window {
237
function: Node,
238
partition_by: Vec<Node>,
239
order_by: Option<(Node, SortOptions)>,
240
options: WindowType,
241
},
242
Slice {
243
input: Node,
244
offset: Node,
245
length: Node,
246
},
247
#[default]
248
Len,
249
}
250
251
impl AExpr {
252
#[cfg(feature = "cse")]
253
pub(crate) fn col(name: PlSmallStr) -> Self {
254
AExpr::Column(name)
255
}
256
257
/// This should be a 1 on 1 copy of the get_type method of Expr until Expr is completely phased out.
258
pub fn get_dtype(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<DataType> {
259
self.to_field(schema, arena).map(|f| f.dtype().clone())
260
}
261
262
#[recursive::recursive]
263
pub fn is_scalar(&self, arena: &Arena<AExpr>) -> bool {
264
match self {
265
AExpr::Literal(lv) => lv.is_scalar(),
266
AExpr::Function { options, input, .. }
267
| AExpr::AnonymousFunction { options, input, .. } => {
268
if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
269
true
270
} else if options.is_elementwise()
271
|| options.flags.contains(FunctionFlags::LENGTH_PRESERVING)
272
{
273
input.iter().all(|e| e.is_scalar(arena))
274
} else {
275
false
276
}
277
},
278
AExpr::BinaryExpr { left, right, .. } => {
279
is_scalar_ae(*left, arena) && is_scalar_ae(*right, arena)
280
},
281
AExpr::Ternary {
282
predicate,
283
truthy,
284
falsy,
285
} => {
286
is_scalar_ae(*predicate, arena)
287
&& is_scalar_ae(*truthy, arena)
288
&& is_scalar_ae(*falsy, arena)
289
},
290
AExpr::Agg(_) | AExpr::Len => true,
291
AExpr::Cast { expr, .. } => is_scalar_ae(*expr, arena),
292
AExpr::Eval { expr, variant, .. } => match variant {
293
EvalVariant::List => is_scalar_ae(*expr, arena),
294
EvalVariant::Cumulative { .. } => is_scalar_ae(*expr, arena),
295
},
296
AExpr::Sort { expr, .. } => is_scalar_ae(*expr, arena),
297
AExpr::Gather { returns_scalar, .. } => *returns_scalar,
298
AExpr::SortBy { expr, .. } => is_scalar_ae(*expr, arena),
299
AExpr::Window { function, .. } => is_scalar_ae(*function, arena),
300
AExpr::Explode { .. }
301
| AExpr::Column(_)
302
| AExpr::Filter { .. }
303
| AExpr::Slice { .. } => false,
304
}
305
}
306
}
307
308