Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/mod.rs
8421 views
1
mod aggregation;
2
mod alias;
3
mod apply;
4
mod binary;
5
mod cast;
6
mod column;
7
mod count;
8
mod element;
9
mod eval;
10
#[cfg(feature = "dtype-struct")]
11
mod field;
12
mod filter;
13
mod gather;
14
mod group_iter;
15
mod literal;
16
#[cfg(feature = "dynamic_group_by")]
17
mod rolling;
18
mod slice;
19
mod sort;
20
mod sortby;
21
#[cfg(feature = "dtype-struct")]
22
mod structeval;
23
mod ternary;
24
mod window;
25
26
use std::borrow::Cow;
27
use std::fmt::{Display, Formatter};
28
29
pub(crate) use aggregation::*;
30
pub(crate) use alias::*;
31
pub(crate) use apply::*;
32
use arrow::array::ArrayRef;
33
use arrow::bitmap::MutableBitmap;
34
use arrow::legacy::utils::CustomIterTools;
35
pub(crate) use binary::*;
36
pub(crate) use cast::*;
37
pub(crate) use column::*;
38
pub(crate) use count::*;
39
pub(crate) use element::*;
40
pub(crate) use eval::*;
41
#[cfg(feature = "dtype-struct")]
42
pub(crate) use field::*;
43
pub(crate) use filter::*;
44
pub(crate) use gather::*;
45
pub(crate) use literal::*;
46
use polars_core::prelude::*;
47
use polars_io::predicates::PhysicalIoExpr;
48
use polars_plan::prelude::*;
49
#[cfg(feature = "dynamic_group_by")]
50
pub(crate) use rolling::RollingExpr;
51
pub(crate) use slice::*;
52
pub(crate) use sort::*;
53
pub(crate) use sortby::*;
54
#[cfg(feature = "dtype-struct")]
55
pub(crate) use structeval::*;
56
pub(crate) use ternary::*;
57
pub use window::window_function_format_order_by;
58
pub(crate) use window::*;
59
60
use crate::state::ExecutionState;
61
62
#[derive(Clone, Debug)]
63
pub enum AggState {
64
/// Already aggregated: `.agg_list(group_tuples)` is called
65
/// and produced a `Series` of dtype `List`
66
AggregatedList(Column),
67
/// Already aggregated: `.agg` is called on an aggregation
68
/// that produces a scalar.
69
/// think of `sum`, `mean`, `variance` like aggregations.
70
AggregatedScalar(Column),
71
/// Not yet aggregated: `agg_list` still has to be called.
72
NotAggregated(Column),
73
/// A literal scalar value.
74
LiteralScalar(Column),
75
}
76
77
impl AggState {
78
fn try_map<F>(&self, func: F) -> PolarsResult<Self>
79
where
80
F: FnOnce(&Column) -> PolarsResult<Column>,
81
{
82
Ok(match self {
83
AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
84
AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
85
AggState::LiteralScalar(c) => AggState::LiteralScalar(func(c)?),
86
AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
87
})
88
}
89
90
fn is_scalar(&self) -> bool {
91
matches!(self, Self::AggregatedScalar(_))
92
}
93
94
pub fn name(&self) -> &PlSmallStr {
95
match self {
96
AggState::AggregatedList(s)
97
| AggState::NotAggregated(s)
98
| AggState::LiteralScalar(s)
99
| AggState::AggregatedScalar(s) => s.name(),
100
}
101
}
102
103
pub fn flat_dtype(&self) -> &DataType {
104
match self {
105
AggState::AggregatedList(s) => s.dtype().inner_dtype().unwrap(),
106
AggState::NotAggregated(s)
107
| AggState::LiteralScalar(s)
108
| AggState::AggregatedScalar(s) => s.dtype(),
109
}
110
}
111
}
112
113
/// Lazy strategy for recomputing the groups of an [`AggregationContext`]; applied the next
/// time the groups are requested.
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) enum UpdateGroups {
    /// Keep the current groups as they are.
    No,
    /// Rebuild sorted groups from the lengths of the current groups. Preferred, because it
    /// is the cheaper option.
    WithGroupsLen,
    /// Rebuild the groups from the list offsets of the values. Use this when the group
    /// lengths have changed. The values must be in an aggregated (list) state, otherwise
    /// this panics.
    WithSeriesLen,
}
126
127
#[cfg_attr(debug_assertions, derive(Debug))]
128
pub struct AggregationContext<'a> {
129
/// Can be in one of two states
130
/// 1. already aggregated as list
131
/// 2. flat (still needs the grouptuples to aggregate)
132
///
133
/// When aggregation state is LiteralScalar or AggregatedScalar, the group values are not
134
/// related to the state data anymore. The number of groups is still accurate.
135
pub(crate) state: AggState,
136
/// group tuples for AggState
137
pub(crate) groups: Cow<'a, GroupPositions>,
138
/// This is used to determined if we need to update the groups
139
/// into a sorted groups. We do this lazily, so that this work only is
140
/// done when the groups are needed
141
pub(crate) update_groups: UpdateGroups,
142
/// This is true when the Series and Groups still have all
143
/// their original values. Not the case when filtered
144
pub(crate) original_len: bool,
145
}
146
147
impl<'a> AggregationContext<'a> {
148
pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
149
match self.update_groups {
150
UpdateGroups::No => {},
151
UpdateGroups::WithGroupsLen => {
152
// the groups are unordered
153
// and the series is aggregated with this groups
154
// so we need to recreate new grouptuples that
155
// match the exploded Series
156
let mut offset = 0 as IdxSize;
157
158
match self.groups.as_ref().as_ref() {
159
GroupsType::Idx(groups) => {
160
let groups = groups
161
.iter()
162
.map(|g| {
163
let len = g.1.len() as IdxSize;
164
let new_offset = offset + len;
165
let out = [offset, len];
166
offset = new_offset;
167
out
168
})
169
.collect();
170
self.groups =
171
Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable())
172
},
173
// sliced groups are already in correct order,
174
// Update offsets in the case of overlapping groups
175
// e.g. [0,2], [1,3], [2,4] becomes [0,2], [2,3], [5,4]
176
GroupsType::Slice { groups, .. } => {
177
// unroll
178
let groups = groups
179
.iter()
180
.map(|g| {
181
let len = g[1];
182
let new = [offset, g[1]];
183
offset += len;
184
new
185
})
186
.collect();
187
self.groups =
188
Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable())
189
},
190
}
191
self.update_groups = UpdateGroups::No;
192
},
193
UpdateGroups::WithSeriesLen => {
194
let s = self.get_values().clone();
195
self.det_groups_from_list(s.as_materialized_series());
196
},
197
}
198
&self.groups
199
}
200
201
pub(crate) fn get_values(&self) -> &Column {
202
match &self.state {
203
AggState::NotAggregated(s)
204
| AggState::AggregatedScalar(s)
205
| AggState::AggregatedList(s) => s,
206
AggState::LiteralScalar(s) => s,
207
}
208
}
209
210
pub fn agg_state(&self) -> &AggState {
211
&self.state
212
}
213
214
pub(crate) fn is_not_aggregated(&self) -> bool {
215
matches!(
216
&self.state,
217
AggState::NotAggregated(_) | AggState::LiteralScalar(_)
218
)
219
}
220
221
pub(crate) fn is_aggregated(&self) -> bool {
222
!self.is_not_aggregated()
223
}
224
225
pub(crate) fn is_literal(&self) -> bool {
226
matches!(self.state, AggState::LiteralScalar(_))
227
}
228
229
/// # Arguments
230
/// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
231
/// the columns dtype)
232
fn new(
233
column: Column,
234
groups: Cow<'a, GroupPositions>,
235
aggregated: bool,
236
) -> AggregationContext<'a> {
237
let series = if aggregated {
238
assert_eq!(column.len(), groups.len());
239
AggState::AggregatedScalar(column)
240
} else {
241
AggState::NotAggregated(column)
242
};
243
244
Self {
245
state: series,
246
groups,
247
update_groups: UpdateGroups::No,
248
original_len: true,
249
}
250
}
251
252
fn with_agg_state(&mut self, agg_state: AggState) {
253
self.state = agg_state;
254
}
255
256
fn from_agg_state(
257
agg_state: AggState,
258
groups: Cow<'a, GroupPositions>,
259
) -> AggregationContext<'a> {
260
Self {
261
state: agg_state,
262
groups,
263
update_groups: UpdateGroups::No,
264
original_len: true,
265
}
266
}
267
268
pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
269
self.original_len = original_len;
270
self
271
}
272
273
pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
274
self.update_groups = update;
275
self
276
}
277
278
fn det_groups_from_list(&mut self, s: &Series) {
279
let mut offset = 0 as IdxSize;
280
let list = s
281
.list()
282
.expect("impl error, should be a list at this point");
283
284
match list.chunks().len() {
285
1 => {
286
let arr = list.downcast_iter().next().unwrap();
287
let offsets = arr.offsets().as_slice();
288
289
let mut previous = 0i64;
290
let groups = offsets[1..]
291
.iter()
292
.map(|&o| {
293
let len = (o - previous) as IdxSize;
294
let new_offset = offset + len;
295
296
previous = o;
297
let out = [offset, len];
298
offset = new_offset;
299
out
300
})
301
.collect_trusted();
302
self.groups =
303
Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
304
},
305
_ => {
306
let groups = {
307
self.get_values()
308
.list()
309
.expect("impl error, should be a list at this point")
310
.amortized_iter()
311
.map(|s| {
312
if let Some(s) = s {
313
let len = s.as_ref().len() as IdxSize;
314
let new_offset = offset + len;
315
let out = [offset, len];
316
offset = new_offset;
317
out
318
} else {
319
[offset, 0]
320
}
321
})
322
.collect_trusted()
323
};
324
self.groups =
325
Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
326
},
327
}
328
self.update_groups = UpdateGroups::No;
329
}
330
331
/// # Arguments
332
/// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
333
/// the columns dtype)
334
pub(crate) fn with_values(
335
&mut self,
336
column: Column,
337
aggregated: bool,
338
expr: Option<&Expr>,
339
) -> PolarsResult<&mut Self> {
340
self.with_values_and_args(
341
column,
342
aggregated,
343
expr,
344
false,
345
self.agg_state().is_scalar(),
346
)
347
}
348
349
pub(crate) fn with_values_and_args(
350
&mut self,
351
column: Column,
352
aggregated: bool,
353
expr: Option<&Expr>,
354
// if the applied function was a `map` instead of an `apply`
355
// this will keep functions applied over literals as literals: F(lit) = lit
356
preserve_literal: bool,
357
returns_scalar: bool,
358
) -> PolarsResult<&mut Self> {
359
self.state = match (aggregated, column.dtype()) {
360
(true, &DataType::List(_)) if !returns_scalar => {
361
if column.len() != self.groups.len() {
362
let fmt_expr = if let Some(e) = expr {
363
format!("'{e:?}' ")
364
} else {
365
String::new()
366
};
367
polars_bail!(
368
ComputeError:
369
"aggregation expression '{}' produced a different number of elements: {} \
370
than the number of groups: {} (this is likely invalid)",
371
fmt_expr, column.len(), self.groups.len(),
372
);
373
}
374
AggState::AggregatedList(column)
375
},
376
(true, _) => AggState::AggregatedScalar(column),
377
_ => {
378
match self.state {
379
// already aggregated to sum, min even this series was flattened it never could
380
// retrieve the length before grouping, so it stays in this state.
381
AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
382
// applying a function on a literal, keeps the literal state
383
AggState::LiteralScalar(_) if column.len() == 1 && preserve_literal => {
384
AggState::LiteralScalar(column)
385
},
386
_ => AggState::NotAggregated(column.into_column()),
387
}
388
},
389
};
390
Ok(self)
391
}
392
393
pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
394
self.state = AggState::LiteralScalar(column);
395
self
396
}
397
398
/// Update the group tuples
399
pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
400
if let AggState::AggregatedList(_) = self.agg_state() {
401
// In case of new groups, a series always needs to be flattened
402
self.with_values(self.flat_naive().into_owned(), false, None)
403
.unwrap();
404
}
405
self.groups = Cow::Owned(groups);
406
// make sure that previous setting is not used
407
self.update_groups = UpdateGroups::No;
408
self
409
}
410
411
/// Ensure that each group is represented by contiguous values in memory.
412
pub fn normalize_values(&mut self) {
413
self.set_original_len(false);
414
self.groups();
415
let values = self.flat_naive();
416
let values = unsafe { values.agg_list(&self.groups) };
417
self.state = AggState::AggregatedList(values);
418
self.with_update_groups(UpdateGroups::WithGroupsLen);
419
}
420
421
/// Aggregate into `ListChunked`.
422
pub fn aggregated_as_list<'b>(&'b mut self) -> Cow<'b, ListChunked> {
423
self.aggregated();
424
let out = self.get_values();
425
match self.agg_state() {
426
AggState::AggregatedScalar(_) => Cow::Owned(out.as_list()),
427
_ => Cow::Borrowed(out.list().unwrap()),
428
}
429
}
430
431
/// Get the aggregated version of the series.
432
pub fn aggregated(&mut self) -> Column {
433
// we clone, because we only want to call `self.groups()` if needed.
434
// self groups may instantiate new groups and thus can be expensive.
435
match self.state.clone() {
436
AggState::NotAggregated(s) => {
437
// The groups are determined lazily and in case of a flat/non-aggregated
438
// series we use the groups to aggregate the list
439
// because this is lazy, we first must to update the groups
440
// by calling .groups()
441
self.groups();
442
#[cfg(debug_assertions)]
443
{
444
if self.groups.len() > s.len() {
445
polars_warn!(
446
"groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"
447
)
448
}
449
}
450
451
// SAFETY:
452
// groups are in bounds
453
let out = unsafe { s.agg_list(&self.groups) };
454
self.state = AggState::AggregatedList(out.clone());
455
456
self.update_groups = UpdateGroups::WithGroupsLen;
457
out
458
},
459
AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
460
AggState::LiteralScalar(s) => {
461
let rows = self.groups.len();
462
let s = s.implode().unwrap();
463
let s = s.new_from_index(0, rows);
464
let s = s.into_column();
465
self.state = AggState::AggregatedList(s.clone());
466
self.with_update_groups(UpdateGroups::WithSeriesLen);
467
s.clone()
468
},
469
}
470
}
471
472
/// Get the final aggregated version of the series.
473
pub fn finalize(&mut self) -> Column {
474
// we clone, because we only want to call `self.groups()` if needed.
475
// self groups may instantiate new groups and thus can be expensive.
476
match &self.state {
477
AggState::LiteralScalar(c) => {
478
let c = c.clone();
479
self.groups();
480
let rows = self.groups.len();
481
c.new_from_index(0, rows)
482
},
483
_ => self.aggregated(),
484
}
485
}
486
487
// If a binary or ternary function has both of these branches true, it should
488
// flatten the list
489
fn arity_should_explode(&self) -> bool {
490
use AggState::*;
491
match self.agg_state() {
492
LiteralScalar(s) => s.len() == 1,
493
AggregatedScalar(_) => true,
494
_ => false,
495
}
496
}
497
498
pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
499
let _ = self.groups();
500
let groups = self.groups;
501
match self.state {
502
AggState::NotAggregated(c) => (c, groups),
503
AggState::AggregatedScalar(c) => (c, groups),
504
AggState::LiteralScalar(c) => (c, groups),
505
AggState::AggregatedList(c) => {
506
let flattened = c
507
.explode(ExplodeOptions {
508
empty_as_null: false,
509
keep_nulls: true,
510
})
511
.unwrap();
512
let groups = groups.into_owned();
513
// unroll the possible flattened state
514
// say we have groups with overlapping windows:
515
//
516
// offset, len
517
// 0, 1
518
// 0, 2
519
// 0, 4
520
//
521
// gets aggregation
522
//
523
// [0]
524
// [0, 1],
525
// [0, 1, 2, 3]
526
//
527
// before aggregation the column was
528
// [0, 1, 2, 3]
529
// but explode on this list yields
530
// [0, 0, 1, 0, 1, 2, 3]
531
//
532
// so we unroll the groups as
533
//
534
// [0, 1]
535
// [1, 2]
536
// [3, 4]
537
let groups = groups.unroll();
538
(flattened, Cow::Owned(groups))
539
},
540
}
541
}
542
543
/// Get the not-aggregated version of the series.
544
/// Note that we call it naive, because if a previous expr
545
/// has filtered or sorted this, this information is in the
546
/// group tuples not the flattened series.
547
pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
548
match &self.state {
549
AggState::NotAggregated(c) => Cow::Borrowed(c),
550
AggState::AggregatedList(c) => {
551
if cfg!(debug_assertions) {
552
// Warning, so we find cases where we accidentally explode overlapping groups
553
// We don't want this as this can create a lot of data
554
if self.groups.is_overlapping() {
555
polars_warn!(
556
"performance - an aggregated list with overlapping groups may consume excessive memory"
557
)
558
}
559
}
560
561
// We should not insert nulls, otherwise the offsets in the groups will not be correct.
562
Cow::Owned(
563
c.explode(ExplodeOptions {
564
empty_as_null: false,
565
keep_nulls: true,
566
})
567
.unwrap(),
568
)
569
},
570
AggState::AggregatedScalar(c) => Cow::Borrowed(c),
571
AggState::LiteralScalar(c) => Cow::Borrowed(c),
572
}
573
}
574
575
fn flat_naive_length(&self) -> usize {
576
match &self.state {
577
AggState::NotAggregated(c) => c.len(),
578
AggState::AggregatedList(c) => c.list().unwrap().inner_length(),
579
AggState::AggregatedScalar(c) => c.len(),
580
AggState::LiteralScalar(_) => 1,
581
}
582
}
583
584
/// Take the series.
585
pub(crate) fn take(&mut self) -> Column {
586
let c = match &mut self.state {
587
AggState::NotAggregated(c)
588
| AggState::AggregatedScalar(c)
589
| AggState::AggregatedList(c) => c,
590
AggState::LiteralScalar(c) => c,
591
};
592
std::mem::take(c)
593
}
594
595
/// Do the group indices reference all values in the aggregation state.
596
fn groups_cover_all_values(&mut self) -> bool {
597
if matches!(
598
self.state,
599
AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)
600
) {
601
return true;
602
}
603
604
let num_values = self.flat_naive_length();
605
match self.groups().as_ref().as_ref() {
606
GroupsType::Idx(groups) => {
607
let mut seen = MutableBitmap::from_len_zeroed(num_values);
608
for (_, g) in groups {
609
for i in g.iter() {
610
unsafe { seen.set_unchecked(*i as usize, true) };
611
}
612
}
613
seen.unset_bits() == 0
614
},
615
GroupsType::Slice {
616
groups,
617
overlapping: true,
618
monotonic: _,
619
} => {
620
// @NOTE: Slice groups are sorted by their `start` value.
621
let mut offset = 0;
622
let mut covers_all = true;
623
for [start, length] in groups {
624
covers_all &= *start <= offset;
625
offset = start + length;
626
}
627
covers_all && offset == num_values as IdxSize
628
},
629
630
// If we don't have overlapping data, we can just do a count.
631
GroupsType::Slice {
632
groups,
633
overlapping: false,
634
monotonic: _,
635
} => groups.iter().map(|[_, l]| *l as usize).sum::<usize>() == num_values,
636
}
637
}
638
639
/// Fixes groups for `AggregatedScalar` and `LiteralScalar` so that they point to valid
640
/// data elements in the `AggState` values.
641
fn set_groups_for_undefined_agg_states(&mut self) {
642
match &self.state {
643
AggState::AggregatedList(_) | AggState::NotAggregated(_) => {},
644
AggState::AggregatedScalar(c) => {
645
assert_eq!(self.update_groups, UpdateGroups::No);
646
self.groups = Cow::Owned({
647
let groups = (0..c.len() as IdxSize).map(|i| [i, 1]).collect();
648
GroupsType::new_slice(groups, false, true).into_sliceable()
649
});
650
},
651
AggState::LiteralScalar(c) => {
652
assert_eq!(c.len(), 1);
653
assert_eq!(self.update_groups, UpdateGroups::No);
654
self.groups = Cow::Owned({
655
let groups = vec![[0, 1]; self.groups.len()];
656
GroupsType::new_slice(groups, true, true).into_sliceable()
657
});
658
},
659
}
660
}
661
662
pub fn into_static(&self) -> AggregationContext<'static> {
663
let groups: GroupPositions = GroupPositions::to_owned(&self.groups);
664
let groups: Cow<'static, GroupPositions> = Cow::Owned(groups);
665
AggregationContext {
666
state: self.state.clone(),
667
groups,
668
update_groups: self.update_groups,
669
original_len: self.original_len,
670
}
671
}
672
}
673
674
/// Take a DataFrame and evaluate the expressions.
675
/// Implement this for Column, lt, eq, etc
676
pub trait PhysicalExpr: Send + Sync {
677
fn as_expression(&self) -> Option<&Expr> {
678
None
679
}
680
681
fn as_column(&self) -> Option<PlSmallStr> {
682
None
683
}
684
685
/// Take a DataFrame and evaluate the expression.
686
fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;
687
688
/// Some expression that are not aggregations can be done per group
689
/// Think of sort, slice, filter, shift, etc.
690
/// defaults to ignoring the group
691
///
692
/// This method is called by an aggregation function.
693
///
694
/// In case of a simple expr, like 'column', the groups are ignored and the column is returned.
695
/// In case of an expr where group behavior makes sense, this method is called.
696
/// For a filter operation for instance, a Series is created per groups and filtered.
697
///
698
/// An implementation of this method may apply an aggregation on the groups only. For instance
699
/// on a shift, the groups are first aggregated to a `ListChunked` and the shift is applied per
700
/// group. The implementation then has to return the `Series` exploded (because a later aggregation
701
/// will use the group tuples to aggregate). The group tuples also have to be updated, because
702
/// aggregation to a list sorts the exploded `Series` by group.
703
///
704
/// This has some gotcha's. An implementation may also change the group tuples instead of
705
/// the `Series`.
706
///
707
// we allow this because we pass the vec to the Cow
708
// Note to self: Don't be smart and dispatch to evaluate as default implementation
709
// this means filters will be incorrect and lead to invalid results down the line
710
#[allow(clippy::ptr_arg)]
711
fn evaluate_on_groups<'a>(
712
&self,
713
df: &DataFrame,
714
groups: &'a GroupPositions,
715
state: &ExecutionState,
716
) -> PolarsResult<AggregationContext<'a>>;
717
718
/// Get the output field of this expr
719
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;
720
721
fn is_literal(&self) -> bool {
722
false
723
}
724
fn is_scalar(&self) -> bool;
725
}
726
727
impl Display for &dyn PhysicalExpr {
728
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
729
match self.as_expression() {
730
None => Ok(()),
731
Some(e) => write!(f, "{e:?}"),
732
}
733
}
734
}
735
736
/// Wrapper struct that allow us to use a PhysicalExpr in polars-io.
737
///
738
/// This is used to filter rows during the scan of file.
739
pub struct PhysicalIoHelper {
740
pub expr: Arc<dyn PhysicalExpr>,
741
pub has_window_function: bool,
742
}
743
744
impl PhysicalIoExpr for PhysicalIoHelper {
745
fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
746
let mut state: ExecutionState = Default::default();
747
if self.has_window_function {
748
state.insert_has_window_function_flag();
749
}
750
self.expr.evaluate(df, &state).map(|c| {
751
// IO expression result should be boolean-typed.
752
debug_assert_eq!(c.dtype(), &DataType::Boolean);
753
(if c.len() == 1 && df.height() != 1 {
754
// filter(lit(True)) will hit here.
755
c.new_from_index(0, df.height())
756
} else {
757
c
758
})
759
.take_materialized_series()
760
})
761
}
762
}
763
764
pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
765
let has_window_function = if let Some(expr) = expr.as_expression() {
766
expr.into_iter().any(|expr| {
767
#[cfg(feature = "dynamic_group_by")]
768
if matches!(expr, Expr::Rolling { .. }) {
769
return true;
770
}
771
772
matches!(expr, Expr::Over { .. })
773
})
774
} else {
775
false
776
};
777
Arc::new(PhysicalIoHelper {
778
expr,
779
has_window_function,
780
}) as Arc<dyn PhysicalIoExpr>
781
}
782
783