Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/mod.rs
8424 views
1
mod builder;
2
mod equality;
3
mod evaluate;
4
mod function_expr;
5
#[cfg(feature = "cse")]
6
mod hash;
7
mod minterm_iter;
8
pub mod predicates;
9
mod scalar;
10
mod schema;
11
mod traverse;
12
13
use std::hash::{Hash, Hasher};
14
15
pub use function_expr::*;
16
#[cfg(feature = "cse")]
17
pub(super) use hash::traverse_and_hash_aexpr;
18
pub use minterm_iter::MintermIter;
19
use polars_compute::rolling::QuantileMethod;
20
use polars_core::chunked_array::cast::CastOptions;
21
use polars_core::prelude::*;
22
use polars_core::utils::{get_time_units, try_get_supertype};
23
use polars_utils::arena::{Arena, Node};
24
pub use scalar::{is_length_preserving_ae, is_scalar_ae};
25
use strum_macros::IntoStaticStr;
26
pub use traverse::*;
27
mod properties;
28
pub use aexpr::function_expr::schema::FieldsMapper;
29
pub use builder::AExprBuilder;
30
pub use evaluate::{constant_evaluate, into_column};
31
pub use properties::*;
32
pub use schema::ToFieldContext;
33
34
use crate::constants::LEN;
35
use crate::prelude::*;
36
37
#[derive(Clone, Debug, IntoStaticStr)]
38
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
39
pub enum IRAggExpr {
40
Min {
41
input: Node,
42
propagate_nans: bool,
43
},
44
Max {
45
input: Node,
46
propagate_nans: bool,
47
},
48
Median(Node),
49
NUnique(Node),
50
Item {
51
input: Node,
52
/// Return a missing value if there are no values.
53
allow_empty: bool,
54
},
55
First(Node),
56
FirstNonNull(Node),
57
Last(Node),
58
LastNonNull(Node),
59
Mean(Node),
60
Implode(Node),
61
Quantile {
62
expr: Node,
63
quantile: Node,
64
method: QuantileMethod,
65
},
66
Sum(Node),
67
Count {
68
input: Node,
69
include_nulls: bool,
70
},
71
Std(Node, u8),
72
Var(Node, u8),
73
AggGroups(Node),
74
}
75
76
impl Hash for IRAggExpr {
77
fn hash<H: Hasher>(&self, state: &mut H) {
78
std::mem::discriminant(self).hash(state);
79
match self {
80
Self::Min {
81
input: _,
82
propagate_nans,
83
}
84
| Self::Max {
85
input: _,
86
propagate_nans,
87
} => propagate_nans.hash(state),
88
Self::Quantile {
89
method: interpol, ..
90
} => interpol.hash(state),
91
Self::Std(_, v) | Self::Var(_, v) => v.hash(state),
92
Self::Count {
93
input: _,
94
include_nulls,
95
} => include_nulls.hash(state),
96
_ => {},
97
}
98
}
99
}
100
101
impl IRAggExpr {
102
pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool {
103
use IRAggExpr::*;
104
match (self, other) {
105
(
106
Min {
107
propagate_nans: l, ..
108
},
109
Min {
110
propagate_nans: r, ..
111
},
112
) => l == r,
113
(
114
Max {
115
propagate_nans: l, ..
116
},
117
Max {
118
propagate_nans: r, ..
119
},
120
) => l == r,
121
(Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r,
122
(Std(_, l), Std(_, r)) => l == r,
123
(Var(_, l), Var(_, r)) => l == r,
124
_ => std::mem::discriminant(self) == std::mem::discriminant(other),
125
}
126
}
127
}
128
129
impl From<IRAggExpr> for GroupByMethod {
130
fn from(value: IRAggExpr) -> Self {
131
use IRAggExpr::*;
132
match value {
133
Min {
134
input: _,
135
propagate_nans,
136
} => {
137
if propagate_nans {
138
GroupByMethod::NanMin
139
} else {
140
GroupByMethod::Min
141
}
142
},
143
Max {
144
input: _,
145
propagate_nans,
146
} => {
147
if propagate_nans {
148
GroupByMethod::NanMax
149
} else {
150
GroupByMethod::Max
151
}
152
},
153
Median(_) => GroupByMethod::Median,
154
NUnique(_) => GroupByMethod::NUnique,
155
First(_) => GroupByMethod::First,
156
FirstNonNull(_) => GroupByMethod::FirstNonNull,
157
Last(_) => GroupByMethod::Last,
158
LastNonNull(_) => GroupByMethod::LastNonNull,
159
Item { allow_empty, .. } => GroupByMethod::Item { allow_empty },
160
Mean(_) => GroupByMethod::Mean,
161
Implode(_) => GroupByMethod::Implode,
162
Sum(_) => GroupByMethod::Sum,
163
Count {
164
input: _,
165
include_nulls,
166
} => GroupByMethod::Count { include_nulls },
167
Std(_, ddof) => GroupByMethod::Std(ddof),
168
Var(_, ddof) => GroupByMethod::Var(ddof),
169
AggGroups(_) => GroupByMethod::Groups,
170
// Multi-input aggregations.
171
Quantile { .. } => unreachable!(),
172
}
173
}
174
}
175
176
/// IR expression node that is allocated in an [`Arena`][polars_utils::arena::Arena].
177
#[derive(Clone, Debug, Default)]
178
#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]
179
pub enum AExpr {
180
/// Values in a `eval` context.
181
///
182
/// Equivalent of `pl.element()`.
183
Element,
184
Explode {
185
expr: Node,
186
options: ExplodeOptions,
187
},
188
Column(PlSmallStr),
189
/// Struct field value in a `struct.with_fields` context.
190
///
191
/// Equivalent of `pl.field(name)`.
192
#[cfg(feature = "dtype-struct")]
193
StructField(PlSmallStr),
194
Literal(LiteralValue),
195
BinaryExpr {
196
left: Node,
197
op: Operator,
198
right: Node,
199
},
200
Cast {
201
expr: Node,
202
dtype: DataType,
203
options: CastOptions,
204
},
205
Sort {
206
expr: Node,
207
options: SortOptions,
208
},
209
Gather {
210
expr: Node,
211
idx: Node,
212
returns_scalar: bool,
213
null_on_oob: bool,
214
},
215
SortBy {
216
expr: Node,
217
by: Vec<Node>,
218
sort_options: SortMultipleOptions,
219
},
220
Filter {
221
input: Node,
222
by: Node,
223
},
224
Agg(IRAggExpr),
225
Ternary {
226
predicate: Node,
227
truthy: Node,
228
falsy: Node,
229
},
230
AnonymousAgg {
231
input: Vec<ExprIR>,
232
fmt_str: Box<PlSmallStr>,
233
function: OpaqueStreamingAgg,
234
},
235
AnonymousFunction {
236
input: Vec<ExprIR>,
237
function: OpaqueColumnUdf,
238
options: FunctionOptions,
239
fmt_str: Box<PlSmallStr>,
240
},
241
/// Evaluates the `evaluation` expression on the output of the `expr`.
242
///
243
/// Consequently, `expr` is an input and `evaluation` is not and needs a different schema.
244
Eval {
245
expr: Node,
246
247
/// An expression that is guaranteed to not contain any column reference beyond
248
/// `pl.element()` which refers to `pl.col("")`.
249
evaluation: Node,
250
251
variant: EvalVariant,
252
},
253
#[cfg(feature = "dtype-struct")]
254
StructEval {
255
expr: Node,
256
evaluation: Vec<ExprIR>,
257
},
258
Function {
259
/// Function arguments
260
/// Some functions rely on aliases,
261
/// for instance assignment of struct fields.
262
/// Therefor we need [`ExprIr`].
263
input: Vec<ExprIR>,
264
/// function to apply
265
function: IRFunctionExpr,
266
options: FunctionOptions,
267
},
268
Over {
269
function: Node,
270
partition_by: Vec<Node>,
271
order_by: Option<(Node, SortOptions)>,
272
mapping: WindowMapping,
273
},
274
#[cfg(feature = "dynamic_group_by")]
275
Rolling {
276
function: Node,
277
index_column: Node,
278
period: Duration,
279
offset: Duration,
280
closed_window: ClosedWindow,
281
},
282
Slice {
283
input: Node,
284
offset: Node,
285
length: Node,
286
},
287
#[default]
288
Len,
289
}
290
291
impl AExpr {
292
#[cfg(feature = "cse")]
293
pub(crate) fn col(name: PlSmallStr) -> Self {
294
AExpr::Column(name)
295
}
296
297
#[recursive::recursive]
298
pub fn is_scalar(&self, arena: &Arena<AExpr>) -> bool {
299
match self {
300
AExpr::Element => false,
301
AExpr::Literal(lv) => lv.is_scalar(),
302
AExpr::Function { options, input, .. }
303
| AExpr::AnonymousFunction { options, input, .. } => {
304
if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
305
true
306
} else if options.is_elementwise()
307
|| options.flags.contains(FunctionFlags::LENGTH_PRESERVING)
308
{
309
input.iter().all(|e| e.is_scalar(arena))
310
} else {
311
false
312
}
313
},
314
AExpr::BinaryExpr { left, right, .. } => {
315
is_scalar_ae(*left, arena) && is_scalar_ae(*right, arena)
316
},
317
AExpr::Ternary {
318
predicate,
319
truthy,
320
falsy,
321
} => {
322
is_scalar_ae(*predicate, arena)
323
&& is_scalar_ae(*truthy, arena)
324
&& is_scalar_ae(*falsy, arena)
325
},
326
AExpr::Agg(_) | AExpr::AnonymousAgg { .. } | AExpr::Len => true,
327
AExpr::Cast { expr, .. } => is_scalar_ae(*expr, arena),
328
AExpr::Eval { expr, variant, .. } => {
329
variant.is_length_preserving() && is_scalar_ae(*expr, arena)
330
},
331
#[cfg(feature = "dtype-struct")]
332
AExpr::StructEval { expr, .. } => is_scalar_ae(*expr, arena),
333
AExpr::Sort { expr, .. } => is_scalar_ae(*expr, arena),
334
AExpr::Gather { returns_scalar, .. } => *returns_scalar,
335
AExpr::SortBy { expr, .. } => is_scalar_ae(*expr, arena),
336
337
// Over and Rolling implicitly zip with the context and thus are never scalars
338
AExpr::Over { .. } => false,
339
#[cfg(feature = "dynamic_group_by")]
340
AExpr::Rolling { .. } => false,
341
342
AExpr::Explode { .. }
343
| AExpr::Column(_)
344
| AExpr::Filter { .. }
345
| AExpr::Slice { .. } => false,
346
#[cfg(feature = "dtype-struct")]
347
AExpr::StructField(_) => false,
348
}
349
}
350
351
#[recursive::recursive]
352
pub fn is_length_preserving(&self, arena: &Arena<AExpr>) -> bool {
353
fn broadcasting_input_length_preserving(
354
n: impl IntoIterator<Item = Node>,
355
arena: &Arena<AExpr>,
356
) -> bool {
357
let mut num_items = 0;
358
let mut num_length_preserving = 0;
359
let mut num_scalar_or_length_preserving = 0;
360
361
for n in n {
362
num_items += 1;
363
364
if is_length_preserving_ae(n, arena) {
365
num_length_preserving += 1;
366
num_scalar_or_length_preserving += 1;
367
} else if is_scalar_ae(n, arena) {
368
num_scalar_or_length_preserving += 1;
369
}
370
}
371
372
num_length_preserving > 0 && num_scalar_or_length_preserving == num_items
373
}
374
375
match self {
376
AExpr::Element => true,
377
AExpr::Column(_) => true,
378
#[cfg(feature = "dtype-struct")]
379
AExpr::StructField(_) => true,
380
381
// Over and Rolling implicitly zip with the context and thus should always be length
382
// preserving
383
AExpr::Over { mapping, .. } => !matches!(mapping, WindowMapping::Explode),
384
#[cfg(feature = "dynamic_group_by")]
385
AExpr::Rolling { .. } => true,
386
387
AExpr::AnonymousAgg { .. } | AExpr::Literal(_) | AExpr::Agg(_) | AExpr::Len => false,
388
AExpr::Function { options, input, .. }
389
| AExpr::AnonymousFunction { options, input, .. } => {
390
if options.flags.is_elementwise() {
391
broadcasting_input_length_preserving(input.iter().map(|e| e.node()), arena)
392
} else if options.flags.is_length_preserving() {
393
input.iter().all(|e| e.is_length_preserving(arena))
394
} else {
395
false
396
}
397
},
398
AExpr::BinaryExpr { left, right, .. } => {
399
broadcasting_input_length_preserving([*left, *right], arena)
400
},
401
AExpr::Ternary {
402
predicate,
403
truthy,
404
falsy,
405
} => broadcasting_input_length_preserving([*predicate, *truthy, *falsy], arena),
406
AExpr::Cast { expr, .. } => is_length_preserving_ae(*expr, arena),
407
AExpr::Eval { expr, variant, .. } => {
408
variant.is_length_preserving() && is_length_preserving_ae(*expr, arena)
409
},
410
#[cfg(feature = "dtype-struct")]
411
AExpr::StructEval { expr, .. } => is_length_preserving_ae(*expr, arena),
412
AExpr::Sort { expr, .. } => is_length_preserving_ae(*expr, arena),
413
AExpr::Gather {
414
expr: _,
415
idx,
416
returns_scalar,
417
null_on_oob: _,
418
} => !returns_scalar && is_length_preserving_ae(*idx, arena),
419
AExpr::SortBy { expr, by, .. } => broadcasting_input_length_preserving(
420
std::iter::once(*expr).chain(by.iter().copied()),
421
arena,
422
),
423
424
AExpr::Explode { .. } | AExpr::Filter { .. } | AExpr::Slice { .. } => false,
425
}
426
}
427
428
/// Is the top-level expression fallible based on the data values.
429
pub fn is_fallible_top_level(&self, arena: &Arena<AExpr>) -> bool {
430
#[allow(clippy::collapsible_match, clippy::match_like_matches_macro)]
431
match self {
432
AExpr::Function {
433
input, function, ..
434
} => match function {
435
IRFunctionExpr::ListExpr(f) => match f {
436
IRListFunction::Get(false) => true,
437
#[cfg(feature = "list_gather")]
438
IRListFunction::Gather(false) => true,
439
_ => false,
440
},
441
#[cfg(feature = "dtype-array")]
442
IRFunctionExpr::ArrayExpr(f) => match f {
443
IRArrayFunction::Get(false) => true,
444
_ => false,
445
},
446
#[cfg(feature = "replace")]
447
IRFunctionExpr::ReplaceStrict { .. } => true,
448
#[cfg(all(feature = "strings", feature = "temporal"))]
449
IRFunctionExpr::StringExpr(f) => match f {
450
IRStringFunction::Strptime(_, strptime_options) => {
451
debug_assert!(input.len() <= 2);
452
453
let ambiguous_arg_is_infallible_scalar = input
454
.get(1)
455
.map(|x| arena.get(x.node()))
456
.is_some_and(|ae| match ae {
457
AExpr::Literal(lv) => {
458
lv.extract_str().is_some_and(|ambiguous| match ambiguous {
459
"earliest" | "latest" | "null" => true,
460
"raise" => false,
461
v => {
462
if cfg!(debug_assertions) {
463
panic!("unhandled parameter to ambiguous: {v}")
464
}
465
false
466
},
467
})
468
},
469
_ => false,
470
});
471
472
let ambiguous_is_fallible = !ambiguous_arg_is_infallible_scalar;
473
474
!matches!(arena.get(input[0].node()), AExpr::Literal(_))
475
&& (strptime_options.strict || ambiguous_is_fallible)
476
},
477
_ => false,
478
},
479
_ => false,
480
},
481
AExpr::Cast {
482
expr,
483
dtype: _,
484
options: CastOptions::Strict,
485
} => !matches!(arena.get(*expr), AExpr::Literal(_)),
486
_ => false,
487
}
488
}
489
}
490
491
#[recursive::recursive]
492
pub fn deep_clone_ae(ae: Node, arena: &mut Arena<AExpr>) -> Node {
493
let slf = arena.get(ae).clone();
494
495
let mut children = vec![];
496
slf.children_rev(&mut children);
497
for child in &mut children {
498
*child = deep_clone_ae(*child, arena);
499
}
500
children.reverse();
501
502
arena.add(slf.replace_children(&children))
503
}
504
505