Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/structeval.rs
8406 views
1
use std::sync::Arc;
2
3
use polars_core::POOL;
4
use polars_core::error::{PolarsResult, polars_ensure};
5
use polars_core::frame::DataFrame;
6
use polars_core::prelude::*;
7
use polars_core::schema::Schema;
8
use polars_plan::dsl::Expr;
9
use rayon::prelude::*;
10
11
use super::PhysicalExpr;
12
#[cfg(feature = "dtype-struct")]
13
use crate::dispatch::struct_::with_fields;
14
use crate::prelude::{AggState, AggregationContext, UpdateGroups};
15
use crate::state::ExecutionState;
16
17
#[derive(Clone)]
18
pub struct StructEvalExpr {
19
input: Arc<dyn PhysicalExpr>,
20
evaluation: Vec<Arc<dyn PhysicalExpr>>,
21
expr: Expr,
22
output_field: Field,
23
operates_on_scalar: bool,
24
allow_threading: bool,
25
}
26
27
impl StructEvalExpr {
28
pub(crate) fn new(
29
input: Arc<dyn PhysicalExpr>,
30
evaluation: Vec<Arc<dyn PhysicalExpr>>,
31
expr: Expr,
32
output_field: Field,
33
operates_on_scalar: bool,
34
allow_threading: bool,
35
) -> Self {
36
Self {
37
input,
38
evaluation,
39
expr,
40
output_field,
41
operates_on_scalar,
42
allow_threading,
43
}
44
}
45
}
46
47
impl StructEvalExpr {
48
fn apply_all_literal_elementwise<'a>(
49
&self,
50
mut acs: Vec<AggregationContext<'a>>,
51
) -> PolarsResult<AggregationContext<'a>> {
52
let cols = acs
53
.iter()
54
.map(|ac| ac.get_values().clone())
55
.collect::<Vec<_>>();
56
let out = with_fields(&cols)?;
57
polars_ensure!(
58
out.len() == 1,
59
ComputeError: "elementwise expression {:?} must return exactly 1 value on literals, got {}",
60
&self.expr, out.len()
61
);
62
let mut ac = acs.pop().unwrap();
63
ac.with_literal(out);
64
Ok(ac)
65
}
66
67
fn apply_elementwise<'a>(
68
&self,
69
mut acs: Vec<AggregationContext<'a>>,
70
must_aggregate: bool,
71
) -> PolarsResult<AggregationContext<'a>> {
72
// At this stage, we either have (with or without LiteralScalars):
73
// - one or more AggregatedList or NotAggregated ACs
74
// - one or more AggregatedScalar ACs
75
76
let mut previous = None;
77
for ac in acs.iter_mut() {
78
if matches!(
79
ac.state,
80
AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)
81
) {
82
continue;
83
}
84
85
if must_aggregate {
86
ac.aggregated();
87
}
88
89
if matches!(ac.state, AggState::AggregatedList(_)) {
90
if let Some(p) = previous {
91
ac.groups().check_lengths(p)?;
92
}
93
previous = Some(ac.groups());
94
}
95
}
96
97
// At this stage, we do not have both AggregatedList and NotAggregated ACs
98
99
// The first AC represents the `input` and will be used as the base AC.
100
let base_ac_idx = 0;
101
102
match acs[base_ac_idx].agg_state() {
103
AggState::AggregatedList(s) => {
104
let aggregated = acs.iter().any(|ac| ac.is_aggregated());
105
let ca = s.list().unwrap();
106
let input_len = s.len();
107
108
let out = ca.apply_to_inner(&|_| {
109
let cols = acs
110
.iter()
111
.map(|ac| ac.flat_naive().into_owned())
112
.collect::<Vec<_>>();
113
Ok(with_fields(&cols)?.as_materialized_series().clone())
114
})?;
115
116
let out = out.into_column();
117
assert!(input_len == out.len());
118
119
let mut ac = acs.swap_remove(base_ac_idx);
120
ac.with_values_and_args(
121
out,
122
aggregated,
123
Some(&self.expr),
124
false,
125
self.is_scalar(),
126
)?;
127
Ok(ac)
128
},
129
_ => {
130
let aggregated = acs.iter().any(|ac| ac.is_aggregated());
131
assert!(aggregated == self.is_scalar());
132
133
let cols = acs
134
.iter()
135
.map(|ac| ac.flat_naive().into_owned())
136
.collect::<Vec<_>>();
137
138
let input_len = cols[base_ac_idx].len();
139
let out = with_fields(&cols)?;
140
assert!(input_len == out.len());
141
142
let mut ac = acs.swap_remove(base_ac_idx);
143
ac.with_values_and_args(
144
out,
145
aggregated,
146
Some(&self.expr),
147
false,
148
self.is_scalar(),
149
)?;
150
Ok(ac)
151
},
152
}
153
}
154
155
fn apply_group_aware<'a>(
156
&self,
157
mut acs: Vec<AggregationContext<'a>>,
158
) -> PolarsResult<AggregationContext<'a>> {
159
let len = acs[0].groups.len();
160
let mut iters = acs
161
.iter_mut()
162
.map(|ac| ac.iter_groups(true))
163
.collect::<Vec<_>>();
164
let ca = (0..len)
165
.map(|_| {
166
let mut cols = Vec::with_capacity(iters.len());
167
for i in &mut iters {
168
match i.next().unwrap() {
169
None => return Ok(None),
170
Some(s) => cols.push(s.as_ref().clone().into_column()),
171
}
172
}
173
let out = with_fields(&cols)?;
174
Ok(Some(out))
175
})
176
.collect::<PolarsResult<ListChunked>>()?;
177
drop(iters);
178
179
// Finish apply groups; see also ApplyExpr for the reference solution.
180
let ac = acs.swap_remove(0);
181
self.finish_apply_groups(ac, ca)
182
}
183
184
fn finish_apply_groups<'a>(
185
&self,
186
mut ac: AggregationContext<'a>,
187
ca: ListChunked,
188
) -> PolarsResult<AggregationContext<'a>> {
189
let col = if self.is_scalar() {
190
let out = ca
191
.explode(ExplodeOptions {
192
empty_as_null: true,
193
keep_nulls: true,
194
})
195
.unwrap();
196
// if the explode doesn't return the same len, it wasn't scalar.
197
polars_ensure!(out.len() == ca.len(), InvalidOperation: "expected scalar for expr: {}, got {}", self.expr, &out);
198
ac.update_groups = UpdateGroups::No;
199
out.into_column()
200
} else {
201
ac.with_update_groups(UpdateGroups::WithSeriesLen);
202
ca.into_series().into()
203
};
204
205
ac.with_values_and_args(col, true, self.as_expression(), false, self.is_scalar())?;
206
207
Ok(ac)
208
}
209
}
210
211
impl PhysicalExpr for StructEvalExpr {
212
/// Returns the logical expression this physical expression was built from; always `Some`.
fn as_expression(&self) -> Option<&Expr> {
    Some(&self.expr)
}
215
216
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
217
let input = self.input.evaluate(df, state)?;
218
219
// Set ExecutionState.
220
let mut state = state.clone();
221
let mut eval = Vec::with_capacity(self.evaluation.len() + 1);
222
let input_len = input.len();
223
224
state.with_fields = Some(Arc::new(input.struct_()?.clone()));
225
226
// Collect evaluation fields; input goes first.
227
eval.push(input);
228
229
let f = |e: &Arc<dyn PhysicalExpr>| {
230
let result = e.evaluate(df, &state)?;
231
polars_ensure!(
232
result.len() == input_len || result.len() == 1,
233
ShapeMismatch: "struct.with_fields expressions must have matching or unit length"
234
);
235
Ok(result)
236
};
237
let cols = if self.allow_threading {
238
POOL.install(|| {
239
self.evaluation
240
.par_iter()
241
.map(f)
242
.collect::<PolarsResult<Vec<_>>>()
243
})
244
} else {
245
self.evaluation
246
.iter()
247
.map(f)
248
.collect::<PolarsResult<Vec<_>>>()
249
};
250
for col in cols? {
251
eval.push(col);
252
}
253
254
// Apply with_fields.
255
with_fields(&eval)
256
}
257
258
fn evaluate_on_groups<'a>(
259
&self,
260
df: &DataFrame,
261
groups: &'a GroupPositions,
262
state: &ExecutionState,
263
) -> PolarsResult<AggregationContext<'a>> {
264
// The evaluation is similar to a regular Function, with the modification that the input
265
// is evaluated first, and retained for future use in the ExecutionState.
266
267
// Evaluate input.
268
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
269
270
ac.groups();
271
ac.set_groups_for_undefined_agg_states();
272
273
// Snap the AC into the ExecutionState for re-use when Field is evaluated.
274
let mut state = state.clone();
275
state.with_fields_ac = Some(Arc::new(ac.into_static()));
276
277
// Collect evaluation fields.
278
let mut acs = Vec::with_capacity(self.evaluation.len() + 1);
279
acs.push(ac);
280
281
let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate_on_groups(df, groups, &state);
282
let acs_eval = if self.allow_threading {
283
POOL.install(|| {
284
self.evaluation
285
.par_iter()
286
.map(f)
287
.collect::<PolarsResult<Vec<_>>>()
288
})
289
} else {
290
self.evaluation
291
.iter()
292
.map(f)
293
.collect::<PolarsResult<Vec<_>>>()
294
};
295
for ac in acs_eval? {
296
acs.push(ac)
297
}
298
299
// Revert ExecutionState.
300
state.with_fields_ac = None;
301
302
// Merge the `evaluation` back into the `input` struct.
303
// @NOTE. From this point on, we are dealing with a regular Function `with_fields`, which is
304
// elementwise top-level and not fallible. We leverage the reference dispatch for ApplyExpr,
305
// but simplified.
306
307
// Collect statistics on input aggstates
308
let mut has_agg_list = false;
309
let mut has_agg_scalar = false;
310
let mut has_not_agg = false;
311
let mut has_not_agg_with_overlapping_groups = false;
312
let mut not_agg_groups_may_diverge = false;
313
314
let mut previous: Option<&AggregationContext<'_>> = None;
315
for ac in &acs {
316
match ac.state {
317
AggState::AggregatedList(_) => {
318
has_agg_list = true;
319
},
320
AggState::AggregatedScalar(_) => has_agg_scalar = true,
321
AggState::NotAggregated(_) => {
322
has_not_agg = true;
323
if let Some(p) = previous {
324
not_agg_groups_may_diverge |= !p.groups.is_same(&ac.groups)
325
}
326
previous = Some(ac);
327
if ac.groups.is_overlapping() {
328
has_not_agg_with_overlapping_groups = true;
329
}
330
},
331
AggState::LiteralScalar(_) => {},
332
}
333
}
334
335
let all_literal = !(has_agg_list || has_agg_scalar || has_not_agg);
336
let elementwise_must_aggregate =
337
has_not_agg && (has_agg_list || not_agg_groups_may_diverge);
338
339
if all_literal {
340
// Fast path
341
self.apply_all_literal_elementwise(acs)
342
} else if has_agg_scalar && (has_agg_list || has_not_agg) {
343
// Not compatible
344
self.apply_group_aware(acs)
345
} else if elementwise_must_aggregate && has_not_agg_with_overlapping_groups {
346
// Compatible but calling aggregated() is too expensive
347
self.apply_group_aware(acs)
348
} else {
349
// Broadcast in NotAgg or AggList requires group_aware
350
acs.iter_mut().filter(|ac| !ac.is_literal()).for_each(|ac| {
351
ac.groups();
352
});
353
let has_broadcast =
354
if let Some(base_ac_idx) = acs.iter().position(|ac| !ac.is_literal()) {
355
acs.iter()
356
.enumerate()
357
.filter(|(i, ac)| *i != base_ac_idx && !ac.is_literal())
358
.any(|(_, ac)| {
359
acs[base_ac_idx]
360
.groups
361
.iter()
362
.zip(ac.groups.iter())
363
.any(|(l, r)| l.len() != r.len() && (l.len() == 1 || r.len() == 1))
364
})
365
} else {
366
false
367
};
368
if has_broadcast {
369
// Broadcast fall-back.
370
self.apply_group_aware(acs)
371
} else {
372
self.apply_elementwise(acs, elementwise_must_aggregate)
373
}
374
}
375
}
376
377
/// Returns a clone of the stored output field; the input schema is not consulted.
fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
    Ok(self.output_field.clone())
}
380
381
fn is_scalar(&self) -> bool {
382
self.operates_on_scalar
383
}
384
}
385
386