// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-expr/src/expressions/mod.rs
1
mod aggregation;
2
mod alias;
3
mod apply;
4
mod binary;
5
mod cast;
6
mod column;
7
mod count;
8
mod eval;
9
mod filter;
10
mod gather;
11
mod group_iter;
12
mod literal;
13
#[cfg(feature = "dynamic_group_by")]
14
mod rolling;
15
mod slice;
16
mod sort;
17
mod sortby;
18
mod ternary;
19
mod window;
20
21
use std::borrow::Cow;
22
use std::fmt::{Display, Formatter};
23
24
pub(crate) use aggregation::*;
25
pub(crate) use alias::*;
26
pub(crate) use apply::*;
27
use arrow::array::ArrayRef;
28
use arrow::legacy::utils::CustomIterTools;
29
pub(crate) use binary::*;
30
pub(crate) use cast::*;
31
pub(crate) use column::*;
32
pub(crate) use count::*;
33
pub(crate) use eval::*;
34
pub(crate) use filter::*;
35
pub(crate) use gather::*;
36
pub(crate) use literal::*;
37
use polars_core::prelude::*;
38
use polars_io::predicates::PhysicalIoExpr;
39
use polars_plan::prelude::*;
40
#[cfg(feature = "dynamic_group_by")]
41
pub(crate) use rolling::RollingExpr;
42
pub(crate) use slice::*;
43
pub(crate) use sort::*;
44
pub(crate) use sortby::*;
45
pub(crate) use ternary::*;
46
pub use window::window_function_format_order_by;
47
pub(crate) use window::*;
48
49
use crate::state::ExecutionState;
50
51
#[derive(Clone, Debug)]
52
pub enum AggState {
53
/// Already aggregated: `.agg_list(group_tuples)` is called
54
/// and produced a `Series` of dtype `List`
55
AggregatedList(Column),
56
/// Already aggregated: `.agg` is called on an aggregation
57
/// that produces a scalar.
58
/// think of `sum`, `mean`, `variance` like aggregations.
59
AggregatedScalar(Column),
60
/// Not yet aggregated: `agg_list` still has to be called.
61
NotAggregated(Column),
62
/// A literal scalar value.
63
LiteralScalar(Column),
64
}
65
66
impl AggState {
67
fn try_map<F>(&self, func: F) -> PolarsResult<Self>
68
where
69
F: FnOnce(&Column) -> PolarsResult<Column>,
70
{
71
Ok(match self {
72
AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
73
AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
74
AggState::LiteralScalar(c) => AggState::LiteralScalar(func(c)?),
75
AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
76
})
77
}
78
79
fn is_scalar(&self) -> bool {
80
matches!(self, Self::AggregatedScalar(_))
81
}
82
}
83
84
/// Lazy update strategy for the groups of an aggregation context.
#[cfg_attr(debug_assertions, derive(Debug))]
#[derive(PartialEq, Clone, Copy)]
pub(crate) enum UpdateGroups {
    /// Don't update the groups.
    No,
    /// Use the length of the current groups to determine the new sorted
    /// indexes. Preferred for performance.
    WithGroupsLen,
    /// Use the list offsets of the series to determine the new group lengths.
    /// This one should be used when the length has changed. Note that the
    /// series should be in an aggregated state or else it will panic.
    WithSeriesLen,
}
98
99
#[cfg_attr(debug_assertions, derive(Debug))]
100
pub struct AggregationContext<'a> {
101
/// Can be in one of two states
102
/// 1. already aggregated as list
103
/// 2. flat (still needs the grouptuples to aggregate)
104
state: AggState,
105
/// group tuples for AggState
106
groups: Cow<'a, GroupPositions>,
107
/// This is used to determined if we need to update the groups
108
/// into a sorted groups. We do this lazily, so that this work only is
109
/// done when the groups are needed
110
update_groups: UpdateGroups,
111
/// This is true when the Series and Groups still have all
112
/// their original values. Not the case when filtered
113
original_len: bool,
114
}
115
116
impl<'a> AggregationContext<'a> {
117
pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
118
match self.update_groups {
119
UpdateGroups::No => {},
120
UpdateGroups::WithGroupsLen => {
121
// the groups are unordered
122
// and the series is aggregated with this groups
123
// so we need to recreate new grouptuples that
124
// match the exploded Series
125
let mut offset = 0 as IdxSize;
126
127
match self.groups.as_ref().as_ref() {
128
GroupsType::Idx(groups) => {
129
let groups = groups
130
.iter()
131
.map(|g| {
132
let len = g.1.len() as IdxSize;
133
let new_offset = offset + len;
134
let out = [offset, len];
135
offset = new_offset;
136
out
137
})
138
.collect();
139
self.groups = Cow::Owned(
140
GroupsType::Slice {
141
groups,
142
rolling: false,
143
}
144
.into_sliceable(),
145
)
146
},
147
// sliced groups are already in correct order
148
GroupsType::Slice { .. } => {},
149
}
150
self.update_groups = UpdateGroups::No;
151
},
152
UpdateGroups::WithSeriesLen => {
153
let s = self.get_values().clone();
154
self.det_groups_from_list(s.as_materialized_series());
155
},
156
}
157
&self.groups
158
}
159
160
pub(crate) fn get_values(&self) -> &Column {
161
match &self.state {
162
AggState::NotAggregated(s)
163
| AggState::AggregatedScalar(s)
164
| AggState::AggregatedList(s) => s,
165
AggState::LiteralScalar(s) => s,
166
}
167
}
168
169
/// Borrow the current aggregation state.
pub fn agg_state(&self) -> &AggState {
    let state = &self.state;
    state
}
172
173
pub(crate) fn is_not_aggregated(&self) -> bool {
174
matches!(
175
&self.state,
176
AggState::NotAggregated(_) | AggState::LiteralScalar(_)
177
)
178
}
179
180
pub(crate) fn is_aggregated(&self) -> bool {
181
!self.is_not_aggregated()
182
}
183
184
pub(crate) fn is_literal(&self) -> bool {
185
matches!(self.state, AggState::LiteralScalar(_))
186
}
187
188
/// # Arguments
189
/// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
190
/// the columns dtype)
191
fn new(
192
column: Column,
193
groups: Cow<'a, GroupPositions>,
194
aggregated: bool,
195
) -> AggregationContext<'a> {
196
let series = if aggregated {
197
assert_eq!(column.len(), groups.len());
198
AggState::AggregatedScalar(column)
199
} else {
200
AggState::NotAggregated(column)
201
};
202
203
Self {
204
state: series,
205
groups,
206
update_groups: UpdateGroups::No,
207
original_len: true,
208
}
209
}
210
211
fn with_agg_state(&mut self, agg_state: AggState) {
212
self.state = agg_state;
213
}
214
215
fn from_agg_state(
216
agg_state: AggState,
217
groups: Cow<'a, GroupPositions>,
218
) -> AggregationContext<'a> {
219
Self {
220
state: agg_state,
221
groups,
222
update_groups: UpdateGroups::No,
223
original_len: true,
224
}
225
}
226
227
/// Set the `original_len` flag and return `self` for chaining.
pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
    let this = self;
    this.original_len = original_len;
    this
}
231
232
/// Set the lazy group-update strategy and return `self` for chaining.
pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
    let this = self;
    this.update_groups = update;
    this
}
236
237
fn det_groups_from_list(&mut self, s: &Series) {
238
let mut offset = 0 as IdxSize;
239
let list = s
240
.list()
241
.expect("impl error, should be a list at this point");
242
243
match list.chunks().len() {
244
1 => {
245
let arr = list.downcast_iter().next().unwrap();
246
let offsets = arr.offsets().as_slice();
247
248
let mut previous = 0i64;
249
let groups = offsets[1..]
250
.iter()
251
.map(|&o| {
252
let len = (o - previous) as IdxSize;
253
let new_offset = offset + len;
254
255
previous = o;
256
let out = [offset, len];
257
offset = new_offset;
258
out
259
})
260
.collect_trusted();
261
self.groups = Cow::Owned(
262
GroupsType::Slice {
263
groups,
264
rolling: false,
265
}
266
.into_sliceable(),
267
);
268
},
269
_ => {
270
let groups = {
271
self.get_values()
272
.list()
273
.expect("impl error, should be a list at this point")
274
.amortized_iter()
275
.map(|s| {
276
if let Some(s) = s {
277
let len = s.as_ref().len() as IdxSize;
278
let new_offset = offset + len;
279
let out = [offset, len];
280
offset = new_offset;
281
out
282
} else {
283
[offset, 0]
284
}
285
})
286
.collect_trusted()
287
};
288
self.groups = Cow::Owned(
289
GroupsType::Slice {
290
groups,
291
rolling: false,
292
}
293
.into_sliceable(),
294
);
295
},
296
}
297
self.update_groups = UpdateGroups::No;
298
}
299
300
/// # Arguments
301
/// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
302
/// the columns dtype)
303
pub(crate) fn with_values(
304
&mut self,
305
column: Column,
306
aggregated: bool,
307
expr: Option<&Expr>,
308
) -> PolarsResult<&mut Self> {
309
self.with_values_and_args(
310
column,
311
aggregated,
312
expr,
313
false,
314
self.agg_state().is_scalar(),
315
)
316
}
317
318
pub(crate) fn with_values_and_args(
319
&mut self,
320
column: Column,
321
aggregated: bool,
322
expr: Option<&Expr>,
323
// if the applied function was a `map` instead of an `apply`
324
// this will keep functions applied over literals as literals: F(lit) = lit
325
mapped: bool,
326
returns_scalar: bool,
327
) -> PolarsResult<&mut Self> {
328
self.state = match (aggregated, column.dtype()) {
329
(true, &DataType::List(_)) if !returns_scalar => {
330
if column.len() != self.groups.len() {
331
let fmt_expr = if let Some(e) = expr {
332
format!("'{e:?}' ")
333
} else {
334
String::new()
335
};
336
polars_bail!(
337
ComputeError:
338
"aggregation expression '{}' produced a different number of elements: {} \
339
than the number of groups: {} (this is likely invalid)",
340
fmt_expr, column.len(), self.groups.len(),
341
);
342
}
343
AggState::AggregatedList(column)
344
},
345
(true, _) => AggState::AggregatedScalar(column),
346
_ => {
347
match self.state {
348
// already aggregated to sum, min even this series was flattened it never could
349
// retrieve the length before grouping, so it stays in this state.
350
AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
351
// applying a function on a literal, keeps the literal state
352
AggState::LiteralScalar(_) if column.len() == 1 && mapped => {
353
AggState::LiteralScalar(column)
354
},
355
_ => AggState::NotAggregated(column.into_column()),
356
}
357
},
358
};
359
Ok(self)
360
}
361
362
/// Set the state to `LiteralScalar(column)` and return `self` for chaining.
pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
    let literal = AggState::LiteralScalar(column);
    self.state = literal;
    self
}
366
367
/// Update the group tuples
368
pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
369
if let AggState::AggregatedList(_) = self.agg_state() {
370
// In case of new groups, a series always needs to be flattened
371
self.with_values(self.flat_naive().into_owned(), false, None)
372
.unwrap();
373
}
374
self.groups = Cow::Owned(groups);
375
// make sure that previous setting is not used
376
self.update_groups = UpdateGroups::No;
377
self
378
}
379
380
pub(crate) fn _implode_no_agg(&mut self) {
381
match self.state.clone() {
382
AggState::NotAggregated(_) => {
383
let _ = self.aggregated();
384
let AggState::AggregatedList(s) = self.state.clone() else {
385
unreachable!()
386
};
387
self.state = AggState::AggregatedScalar(s);
388
},
389
AggState::AggregatedList(s) => {
390
self.state = AggState::AggregatedScalar(s);
391
},
392
_ => unreachable!("should only be called in non-agg/list-agg state by aggregation.rs"),
393
}
394
}
395
396
/// Aggregate into `ListChunked`.
397
pub fn aggregated_as_list<'b>(&'b mut self) -> Cow<'b, ListChunked> {
398
self.aggregated();
399
let out = self.get_values();
400
match self.agg_state() {
401
AggState::AggregatedScalar(_) => Cow::Owned(out.as_list()),
402
_ => Cow::Borrowed(out.list().unwrap()),
403
}
404
}
405
406
/// Get the aggregated version of the series.
407
pub fn aggregated(&mut self) -> Column {
408
// we clone, because we only want to call `self.groups()` if needed.
409
// self groups may instantiate new groups and thus can be expensive.
410
match self.state.clone() {
411
AggState::NotAggregated(s) => {
412
// The groups are determined lazily and in case of a flat/non-aggregated
413
// series we use the groups to aggregate the list
414
// because this is lazy, we first must to update the groups
415
// by calling .groups()
416
self.groups();
417
#[cfg(debug_assertions)]
418
{
419
if self.groups.len() > s.len() {
420
polars_warn!(
421
"groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"
422
)
423
}
424
}
425
426
// SAFETY:
427
// groups are in bounds
428
let out = unsafe { s.agg_list(&self.groups) };
429
self.state = AggState::AggregatedList(out.clone());
430
431
self.update_groups = UpdateGroups::WithGroupsLen;
432
out
433
},
434
AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
435
AggState::LiteralScalar(s) => {
436
self.groups();
437
let rows = self.groups.len();
438
let s = s.implode().unwrap();
439
let s = s.new_from_index(0, rows);
440
let s = s.into_column();
441
self.state = AggState::AggregatedList(s.clone());
442
self.with_update_groups(UpdateGroups::WithSeriesLen);
443
s.clone()
444
},
445
}
446
}
447
448
/// Get the final aggregated version of the series.
449
pub fn finalize(&mut self) -> Column {
450
// we clone, because we only want to call `self.groups()` if needed.
451
// self groups may instantiate new groups and thus can be expensive.
452
match &self.state {
453
AggState::LiteralScalar(c) => {
454
let c = c.clone();
455
self.groups();
456
let rows = self.groups.len();
457
c.new_from_index(0, rows)
458
},
459
_ => self.aggregated(),
460
}
461
}
462
463
// If a binary or ternary function has both of these branches true, it should
464
// flatten the list
465
fn arity_should_explode(&self) -> bool {
466
use AggState::*;
467
match self.agg_state() {
468
LiteralScalar(s) => s.len() == 1,
469
AggregatedScalar(_) => true,
470
_ => false,
471
}
472
}
473
474
pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
475
let _ = self.groups();
476
let groups = self.groups;
477
match self.state {
478
AggState::NotAggregated(c) => (c, groups),
479
AggState::AggregatedScalar(c) => (c, groups),
480
AggState::LiteralScalar(c) => (c, groups),
481
AggState::AggregatedList(c) => {
482
let flattened = c.explode(true).unwrap();
483
let groups = groups.into_owned();
484
// unroll the possible flattened state
485
// say we have groups with overlapping windows:
486
//
487
// offset, len
488
// 0, 1
489
// 0, 2
490
// 0, 4
491
//
492
// gets aggregation
493
//
494
// [0]
495
// [0, 1],
496
// [0, 1, 2, 3]
497
//
498
// before aggregation the column was
499
// [0, 1, 2, 3]
500
// but explode on this list yields
501
// [0, 0, 1, 0, 1, 2, 3]
502
//
503
// so we unroll the groups as
504
//
505
// [0, 1]
506
// [1, 2]
507
// [3, 4]
508
let groups = groups.unroll();
509
(flattened, Cow::Owned(groups))
510
},
511
}
512
}
513
514
/// Get the not-aggregated version of the series.
515
/// Note that we call it naive, because if a previous expr
516
/// has filtered or sorted this, this information is in the
517
/// group tuples not the flattened series.
518
pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
519
match &self.state {
520
AggState::NotAggregated(c) => Cow::Borrowed(c),
521
AggState::AggregatedList(c) => {
522
#[cfg(debug_assertions)]
523
{
524
// panic so we find cases where we accidentally explode overlapping groups
525
// we don't want this as this can create a lot of data
526
if let GroupsType::Slice { rolling: true, .. } = self.groups.as_ref().as_ref() {
527
panic!(
528
"implementation error, polars should not hit this branch for overlapping groups"
529
)
530
}
531
}
532
533
// We should not insert nulls, otherwise the offsets in the groups will not be correct.
534
Cow::Owned(c.explode(true).unwrap())
535
},
536
AggState::AggregatedScalar(c) => Cow::Borrowed(c),
537
AggState::LiteralScalar(c) => Cow::Borrowed(c),
538
}
539
}
540
541
/// Take the series.
542
pub(crate) fn take(&mut self) -> Column {
543
let c = match &mut self.state {
544
AggState::NotAggregated(c)
545
| AggState::AggregatedScalar(c)
546
| AggState::AggregatedList(c) => c,
547
AggState::LiteralScalar(c) => c,
548
};
549
std::mem::take(c)
550
}
551
}
552
553
/// Take a DataFrame and evaluate the expressions.
554
/// Implement this for Column, lt, eq, etc
555
pub trait PhysicalExpr: Send + Sync {
556
fn as_expression(&self) -> Option<&Expr> {
557
None
558
}
559
560
/// Take a DataFrame and evaluate the expression.
561
fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;
562
563
/// Some expression that are not aggregations can be done per group
564
/// Think of sort, slice, filter, shift, etc.
565
/// defaults to ignoring the group
566
///
567
/// This method is called by an aggregation function.
568
///
569
/// In case of a simple expr, like 'column', the groups are ignored and the column is returned.
570
/// In case of an expr where group behavior makes sense, this method is called.
571
/// For a filter operation for instance, a Series is created per groups and filtered.
572
///
573
/// An implementation of this method may apply an aggregation on the groups only. For instance
574
/// on a shift, the groups are first aggregated to a `ListChunked` and the shift is applied per
575
/// group. The implementation then has to return the `Series` exploded (because a later aggregation
576
/// will use the group tuples to aggregate). The group tuples also have to be updated, because
577
/// aggregation to a list sorts the exploded `Series` by group.
578
///
579
/// This has some gotcha's. An implementation may also change the group tuples instead of
580
/// the `Series`.
581
///
582
// we allow this because we pass the vec to the Cow
583
// Note to self: Don't be smart and dispatch to evaluate as default implementation
584
// this means filters will be incorrect and lead to invalid results down the line
585
#[allow(clippy::ptr_arg)]
586
fn evaluate_on_groups<'a>(
587
&self,
588
df: &DataFrame,
589
groups: &'a GroupPositions,
590
state: &ExecutionState,
591
) -> PolarsResult<AggregationContext<'a>>;
592
593
/// Get the output field of this expr
594
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;
595
596
/// Convert to a partitioned aggregator.
597
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
598
None
599
}
600
601
fn is_literal(&self) -> bool {
602
false
603
}
604
fn is_scalar(&self) -> bool;
605
}
606
607
impl Display for &dyn PhysicalExpr {
608
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
609
match self.as_expression() {
610
None => Ok(()),
611
Some(e) => write!(f, "{e:?}"),
612
}
613
}
614
}
615
616
/// Wrapper struct that allow us to use a PhysicalExpr in polars-io.
617
///
618
/// This is used to filter rows during the scan of file.
619
pub struct PhysicalIoHelper {
620
pub expr: Arc<dyn PhysicalExpr>,
621
pub has_window_function: bool,
622
}
623
624
impl PhysicalIoExpr for PhysicalIoHelper {
625
fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
626
let mut state: ExecutionState = Default::default();
627
if self.has_window_function {
628
state.insert_has_window_function_flag();
629
}
630
self.expr.evaluate(df, &state).map(|c| {
631
// IO expression result should be boolean-typed.
632
debug_assert_eq!(c.dtype(), &DataType::Boolean);
633
(if c.len() == 1 && df.height() != 1 {
634
// filter(lit(True)) will hit here.
635
c.new_from_index(0, df.height())
636
} else {
637
c
638
})
639
.take_materialized_series()
640
})
641
}
642
}
643
644
pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
645
let has_window_function = if let Some(expr) = expr.as_expression() {
646
expr.into_iter()
647
.any(|expr| matches!(expr, Expr::Window { .. }))
648
} else {
649
false
650
};
651
Arc::new(PhysicalIoHelper {
652
expr,
653
has_window_function,
654
}) as Arc<dyn PhysicalIoExpr>
655
}
656
657
pub trait PartitionedAggregation: Send + Sync + PhysicalExpr {
658
/// This is called in partitioned aggregation.
659
/// Partitioned results may differ from aggregation results.
660
/// For instance, for a `mean` operation a partitioned result
661
/// needs to return the `sum` and the `valid_count` (length - null count).
662
///
663
/// A final aggregation can then take the sum of sums and sum of valid_counts
664
/// to produce a final mean.
665
#[allow(clippy::ptr_arg)]
666
fn evaluate_partitioned(
667
&self,
668
df: &DataFrame,
669
groups: &GroupPositions,
670
state: &ExecutionState,
671
) -> PolarsResult<Column>;
672
673
/// Called to merge all the partitioned results in a final aggregate.
674
#[allow(clippy::ptr_arg)]
675
fn finalize(
676
&self,
677
partitioned: Column,
678
groups: &GroupPositions,
679
state: &ExecutionState,
680
) -> PolarsResult<Column>;
681
}
682
683