GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/apply.rs
use std::borrow::Cow;

use polars_core::POOL;
use polars_core::chunked_array::builder::get_list_builder;
use polars_core::chunked_array::from_iterator_par::{
    ChunkedCollectParIterExt, try_list_from_par_iter,
};
use polars_core::prelude::*;
use rayon::prelude::*;

use super::*;
use crate::dispatch::GroupsUdf;
use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};

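/// Physical expression that applies a user-defined function (`ColumnsUdf`) to the
/// columns produced by its input expressions, optionally via a group-specialized
/// `GroupsUdf` implementation.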
#[derive(Clone)]
pub struct ApplyExpr {
    inputs: Vec<Arc<dyn PhysicalExpr>>,
    function: SpecialEq<Arc<dyn ColumnsUdf>>,
    groups_function: Option<SpecialEq<Arc<dyn GroupsUdf>>>,
    expr: Expr,
    flags: FunctionFlags,
    function_operates_on_scalar: bool,
    input_schema: SchemaRef,
    allow_threading: bool,
    check_lengths: bool,
    is_fallible: bool,

    /// Output field of the expression, excluding potential aggregation.
    output_field: Field,
}

impl ApplyExpr {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        inputs: Vec<Arc<dyn PhysicalExpr>>,
        function: SpecialEq<Arc<dyn ColumnsUdf>>,
        groups_function: Option<SpecialEq<Arc<dyn GroupsUdf>>>,
        expr: Expr,
        options: FunctionOptions,
        allow_threading: bool,
        input_schema: SchemaRef,
        non_aggregated_output_field: Field,
        function_operates_on_scalar: bool,
        is_fallible: bool,
    ) -> Self {
        debug_assert!(
            !options.is_length_preserving()
                || !options.flags.contains(FunctionFlags::RETURNS_SCALAR),
            "expr {expr:?} is not implemented correctly. 'returns_scalar' and 'length-preserving' are mutually exclusive",
        );

        Self {
            inputs,
            function,
            groups_function,
            expr,
            flags: options.flags,
            function_operates_on_scalar,
            input_schema,
            allow_threading,
            check_lengths: options.check_lengths(),
            output_field: non_aggregated_output_field,
            is_fallible,
        }
    }

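    /// Evaluates every input expression on the given groups, in parallel on the
    /// rayon pool when `allow_threading` is set.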
    #[allow(clippy::ptr_arg)]
    fn prepare_multiple_inputs<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Vec<AggregationContext<'a>>> {
        let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate_on_groups(df, groups, state);
        if self.allow_threading {
            POOL.install(|| self.inputs.par_iter().map(f).collect())
        } else {
            self.inputs.iter().map(f).collect()
        }
    }

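    /// Folds the per-group results in `ca` back into the aggregation context: for
    /// scalar-returning functions the list is exploded and checked to yield one
    /// value per group; otherwise the group layout is re-derived from the series
    /// lengths.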
    fn finish_apply_groups<'a>(
        &self,
        mut ac: AggregationContext<'a>,
        ca: ListChunked,
    ) -> PolarsResult<AggregationContext<'a>> {
        let c = if self.is_scalar() {
            let out = ca
                .explode(ExplodeOptions {
                    empty_as_null: true,
                    keep_nulls: true,
                })
                .unwrap();
            // If the explode doesn't return the same length, it wasn't scalar.
            polars_ensure!(out.len() == ca.len(), InvalidOperation: "expected scalar for expr: {}, got {}", self.expr, &out);
            ac.update_groups = UpdateGroups::No;
            out.into_column()
        } else {
            ac.with_update_groups(UpdateGroups::WithSeriesLen);
            ca.into_series().into()
        };

        ac.with_values_and_args(c, true, None, false, self.is_scalar())?;

        Ok(ac)
    }

    fn get_input_schema(&self, _df: &DataFrame) -> Cow<'_, Schema> {
        Cow::Borrowed(self.input_schema.as_ref())
    }

    /// Calls the UDF on the materialized inputs and returns the resulting `Column`.
    fn eval_and_flatten(&self, inputs: &mut [Column]) -> PolarsResult<Column> {
        self.function.call_udf(inputs)
    }

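    /// Applies the UDF once per group. Overlapping (e.g. rolling) groups are
    /// iterated lazily so the aggregated list is never fully materialized.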
    fn apply_single_group_aware<'a>(
        &self,
        mut ac: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Fix up groups for AggregatedScalar, so that we can pretend they are just normal groups.
        ac.set_groups_for_undefined_agg_states();

        let name = ac.get_values().name().clone();
        let f = |opt_s: Option<Series>| match opt_s {
            None => Ok(None),
            Some(mut s) => {
                if self.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY) {
                    s.rename(name.clone());
                }
                Ok(Some(
                    self.function
                        .call_udf(&mut [Column::from(s)])?
                        .take_materialized_series(),
                ))
            },
        };

        // In case of overlapping (rolling) groups, we build groups in a lazy manner to avoid
        // memory explosion.
        // TODO: support Idx GroupsType.
        if matches!(ac.agg_state(), AggState::NotAggregated(_)) && ac.groups.is_overlapping() {
            let ca: ChunkedArray<_> = if self.allow_threading {
                ac.par_iter_groups_lazy()
                    .map(f)
                    .collect::<PolarsResult<_>>()?
            } else {
                ac.iter_groups_lazy().map(f).collect::<PolarsResult<_>>()?
            };
            return self.finish_apply_groups(ac, ca.with_name(name));
        }

        // At this point, calling aggregated() will not lead to memory explosion.
        let agg = match ac.agg_state() {
            AggState::AggregatedScalar(s) => s.as_list().into_column(),
            _ => ac.aggregated(),
        };

        // Collecting an empty list leads to a null dtype. See: #3687.
        if agg.is_empty() {
            // Create input for the function to determine the output dtype, see #3946.
            let agg = agg.list().unwrap();
            let input_dtype = agg.inner_dtype();
            let input = Column::full_null(name.clone(), 0, input_dtype);

            let output = self.eval_and_flatten(&mut [input])?;
            let ca = ListChunked::full(name, output.as_materialized_series(), 0);
            return self.finish_apply_groups(ac, ca);
        }

        let ca: ListChunked = if self.allow_threading {
            let lst = agg.list().unwrap();
            let iter = lst.par_iter().map(f);

            if self.output_field.dtype.is_known() {
                let dtype = self.output_field.dtype.clone();
                let dtype = dtype.implode();
                POOL.install(|| {
                    iter.collect_ca_with_dtype::<PolarsResult<_>>(PlSmallStr::EMPTY, dtype)
                })?
            } else {
                POOL.install(|| try_list_from_par_iter(iter, PlSmallStr::EMPTY))?
            }
        } else {
            agg.list()
                .unwrap()
                .into_iter()
                .map(f)
                .collect::<PolarsResult<_>>()?
        };

        self.finish_apply_groups(ac, ca.with_name(name))
    }

    /// Apply elementwise, i.e. ignore the group/list indices.
    fn apply_single_elementwise<'a>(
        &self,
        mut ac: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let (c, aggregated) = match ac.agg_state() {
            AggState::AggregatedList(c) => {
                let ca = c.list().unwrap();
                let out = ca.apply_to_inner(&|s| {
                    Ok(self
                        .eval_and_flatten(&mut [s.into_column()])?
                        .take_materialized_series())
                })?;
                (out.into_column(), true)
            },
            AggState::NotAggregated(c) => {
                let (out, aggregated) = (self.eval_and_flatten(&mut [c.clone()])?, false);
                check_map_output_len(c.len(), out.len(), &self.expr)?;
                (out, aggregated)
            },
            agg_state => {
                ac.with_agg_state(agg_state.try_map(|s| self.eval_and_flatten(&mut [s.clone()]))?);
                return Ok(ac);
            },
        };

        ac.with_values_and_args(c, aggregated, Some(&self.expr), true, self.is_scalar())?;
        Ok(ac)
    }

    // Fast path when every AggState is a LiteralScalar. This path avoids calling aggregated() or
    // groups(), and returns a LiteralScalar, on the implicit condition that the function is pure.
    fn apply_all_literal_elementwise<'a>(
        &self,
        mut acs: Vec<AggregationContext<'a>>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut cols = acs
            .iter()
            .map(|ac| ac.get_values().clone())
            .collect::<Vec<_>>();
        let out = self.function.call_udf(&mut cols)?;
        polars_ensure!(
            out.len() == 1,
            ComputeError: "elementwise expression {:?} must return exactly 1 value on literals, got {}",
            &self.expr, out.len()
        );
        let mut ac = acs.pop().unwrap();
        ac.with_literal(out);
        Ok(ac)
    }

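    /// Applies the UDF over the flattened values of multiple layout-compatible
    /// inputs, aggregating `NotAggregated` inputs first when `must_aggregate` is
    /// set. See the dispatch notes in `evaluate_on_groups`.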
    fn apply_multiple_elementwise<'a>(
        &self,
        mut acs: Vec<AggregationContext<'a>>,
        must_aggregate: bool,
    ) -> PolarsResult<AggregationContext<'a>> {
        // At this stage, we either have (with or without LiteralScalars):
        // - one or more AggregatedList or NotAggregated ACs
        // - one or more AggregatedScalar ACs

        let mut previous = None;
        for ac in acs.iter_mut() {
            // TBD: If we want to be strict, we would check all groups.
            if matches!(
                ac.state,
                AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)
            ) {
                continue;
            }

            if must_aggregate {
                ac.aggregated();
            }

            if matches!(ac.state, AggState::AggregatedList(_)) {
                if let Some(p) = previous {
                    ac.groups().check_lengths(p)?;
                }
                previous = Some(ac.groups());
            }
        }

        // At this stage, we do not have both AggregatedList and NotAggregated ACs.

        // The first non-LiteralScalar AC will be used as the base AC to retain the context.
        let base_ac_idx = acs.iter().position(|ac| !ac.is_literal()).unwrap();

        match acs[base_ac_idx].agg_state() {
            AggState::AggregatedList(s) => {
                let aggregated = acs.iter().any(|ac| ac.is_aggregated());
                let ca = s.list().unwrap();
                let input_len = s.len();

                let out = ca.apply_to_inner(&|_| {
                    let mut cols = acs
                        .iter()
                        .map(|ac| ac.flat_naive().into_owned())
                        .collect::<Vec<_>>();
                    Ok(self
                        .function
                        .call_udf(&mut cols)?
                        .as_materialized_series()
                        .clone())
                })?;

                let out = out.into_column();
                if self.check_lengths {
                    check_map_output_len(input_len, out.len(), &self.expr)?;
                }

                let mut ac = acs.swap_remove(base_ac_idx);
                ac.with_values_and_args(
                    out,
                    aggregated,
                    Some(&self.expr),
                    false,
                    self.is_scalar(),
                )?;
                Ok(ac)
            },
            _ => {
                let aggregated = acs.iter().any(|ac| ac.is_aggregated());
                debug_assert!(aggregated == self.is_scalar());

                let mut cols = acs
                    .iter()
                    .map(|ac| ac.flat_naive().into_owned())
                    .collect::<Vec<_>>();

                let input_len = cols[base_ac_idx].len();
                let out = self.function.call_udf(&mut cols)?;
                if self.check_lengths {
                    check_map_output_len(input_len, out.len(), &self.expr)?;
                }

                let mut ac = acs.swap_remove(base_ac_idx);
                ac.with_values_and_args(
                    out,
                    aggregated,
                    Some(&self.expr),
                    false,
                    self.is_scalar(),
                )?;
                Ok(ac)
            },
        }
    }

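    /// Applies the UDF group by group over multiple inputs: one value per input is
    /// zipped into `container` for each group, and the results are collected into
    /// a `ListChunked`.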
    fn apply_multiple_group_aware<'a>(
        &self,
        mut acs: Vec<AggregationContext<'a>>,
        df: &DataFrame,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut container = vec![Default::default(); acs.len()];
        let schema = self.get_input_schema(df);
        let field = self.to_field(&schema)?;

        // Aggregate representation of the aggregation contexts,
        // then unpack the lists and finally create iterators from these list chunked arrays.
        let mut iters = acs
            .iter_mut()
            .map(|ac| ac.iter_groups(self.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY)))
            .collect::<Vec<_>>();

        // Length of the items to iterate over.
        let len = iters[0].size_hint().0;

        let ca = if field.dtype().is_known() {
            let mut builder = get_list_builder(&field.dtype, len * 5, len, field.name);
            for _ in 0..len {
                container.clear();
                for iter in &mut iters {
                    match iter.next().unwrap() {
                        None => {
                            builder.append_null();
                        },
                        Some(s) => container.push(s.deep_clone().into()),
                    }
                }
                let out = self
                    .function
                    .call_udf(&mut container)
                    .map(|c| c.take_materialized_series())?;

                builder.append_series(&out)?
            }
            builder.finish()
        } else {
            // We still need this branch to materialize unknown / data-dependent types in eager. :(
            (0..len)
                .map(|_| {
                    container.clear();
                    for iter in &mut iters {
                        match iter.next().unwrap() {
                            None => return Ok(None),
                            Some(s) => container.push(s.deep_clone().into()),
                        }
                    }
                    Ok(Some(
                        self.function
                            .call_udf(&mut container)?
                            .take_materialized_series(),
                    ))
                })
                .collect::<PolarsResult<ListChunked>>()?
                .with_name(field.name.clone())
        };
        #[cfg(debug_assertions)]
        {
            let inner = ca.dtype().inner_dtype().unwrap();
            if field.dtype.is_known() {
                assert_eq!(inner, &field.dtype);
            }
        }

        drop(iters);

        // Take the first aggregation context, as that is the input series.
        let ac = acs.swap_remove(0);
        self.finish_apply_groups(ac, ca)
    }
}

fn check_map_output_len(input_len: usize, output_len: usize, expr: &Expr) -> PolarsResult<()> {
    polars_ensure!(
        input_len == output_len, expr = expr, InvalidOperation:
        "output length of `map` ({}) must be equal to the input length ({}); \
        consider using `apply` instead", output_len, input_len
    );
    Ok(())
}

impl PhysicalExpr for ApplyExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate(df, state);
        let mut inputs = if self.allow_threading && self.inputs.len() > 1 {
            POOL.install(|| {
                self.inputs
                    .par_iter()
                    .map(f)
                    .collect::<PolarsResult<Vec<_>>>()
            })
        } else {
            self.inputs.iter().map(f).collect::<PolarsResult<Vec<_>>>()
        }?;

        if self.flags.contains(FunctionFlags::ALLOW_RENAME) {
            self.eval_and_flatten(&mut inputs)
        } else {
            let in_name = inputs[0].name().clone();
            Ok(self.eval_and_flatten(&mut inputs)?.with_name(in_name))
        }
    }

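    /// Dispatches between the elementwise and group-aware implementations based on
    /// the aggregation states of the evaluated inputs; see the dispatch notes in
    /// the multi-input branch below.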
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Some functions have a specialized implementation.
        if let Some(groups_function) = self.groups_function.as_ref() {
            return groups_function.evaluate_on_groups(&self.inputs, df, groups, state);
        }

        if self.inputs.len() == 1 {
            let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

            if self.flags.is_elementwise() && (!self.is_fallible || ac.groups_cover_all_values()) {
                self.apply_single_elementwise(ac)
            } else {
                self.apply_single_group_aware(ac)
            }
        } else {
            let mut acs = self.prepare_multiple_inputs(df, groups, state)?;

            match self.flags.is_elementwise() {
                false => self.apply_multiple_group_aware(acs, df),
                true => {
                    // Implementation dispatch:
                    // The current implementation of `apply_multiple_elementwise` requires the
                    // multiple inputs to have a compatible data layout as it invokes `flat_naive()`.
                    // Compatible means matching as-is, or possibly matching after aggregation,
                    // or matching after an implicit broadcast by the function.

                    // The dispatch logic between the implementations depends on the combination of aggstates:
                    // - Any presence of LiteralScalar is immaterial as it gets broadcasted in the UDF.
                    // - Combination of AggregatedScalar and AggregatedList => NOT compatible.
                    // - Combination of AggregatedScalar and NotAggregated => NOT compatible.
                    // - Any other combination => compatible, and therefore allowed for elementwise.
                    //   In this case, aggregated() on NotAggregated may be required; however, it can be
                    //   prohibitively memory expensive when dealing with overlapping (e.g., rolling) groups,
                    //   in which case we fall back to group_aware.

                    // Consequently, these may follow the elementwise path (not exhaustive):
                    // - All AggregatedScalar
                    // - A combination of AggregatedList(s) and NotAggregated(s) without expensive aggregation.
                    // - Either of the above with or without LiteralScalar

                    // Visually, in the case of 2 aggstates:
                    // Legend:
                    // - el = elementwise, no need to aggregate() NotAgg
                    // - el + agg = elementwise, but must aggregate() NotAgg
                    // - ga = group_aware
                    // - alit = all_literal
                    // - * = broadcast falls back to group_aware
                    // - ~ = same as mirror pair (symmetric)
                    //
                    //           | AggList | NotAgg   | AggScalar | LitScalar
                    // --------------------------------------------------------
                    // AggList   | el*     | depends* | ga        | el
                    // NotAgg    | ~       | depends* | ga        | el
                    // AggScalar | ~       | ~        | el        | el
                    // LitScalar | ~       | ~        | ~         | alit
                    //
                    // In case it depends, extending to any combination of multiple aggstates:
                    // (a) Multiple NotAggs, w/o AggList
                    //
                    //                | !has_rolling | has_rolling
                    // -------------------------------------------------
                    // groups match   | el           | el
                    // groups diverge | el+agg       | ga
                    //
                    // (b) Multiple NotAggs, with at least 1 AggList
                    //
                    //                | !has_rolling | has_rolling
                    // -------------------------------------------------
                    // groups match   | el+agg       | ga
                    // groups diverge | el+agg       | ga
                    //
                    // * Finally, when broadcast is required in a non-scalar context we switch to group_aware.

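                    // Illustrative reading of the tables above: combining a per-group
                    // scalar (e.g. the result of a sum over groups) with a plain,
                    // not-aggregated column lands in the AggScalar/NotAgg cell and is
                    // dispatched group_aware, while two not-aggregated columns that
                    // share identical groups remain elementwise without aggregating.
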
                    // Collect statistics on input aggstates.
                    let mut has_agg_list = false;
                    let mut has_agg_scalar = false;
                    let mut has_not_agg = false;
                    let mut has_not_agg_with_overlapping_groups = false;
                    let mut not_agg_groups_may_diverge = false;

                    let mut previous: Option<&AggregationContext<'_>> = None;
                    for ac in &acs {
                        match ac.state {
                            AggState::AggregatedList(_) => {
                                has_agg_list = true;
                            },
                            AggState::AggregatedScalar(_) => has_agg_scalar = true,
                            AggState::NotAggregated(_) => {
                                has_not_agg = true;
                                if let Some(p) = previous {
                                    not_agg_groups_may_diverge |=
                                        !std::ptr::eq(p.groups.as_ref(), ac.groups.as_ref());
                                }
                                previous = Some(ac);
                                if ac.groups.is_overlapping() {
                                    has_not_agg_with_overlapping_groups = true;
                                }
                            },
                            _ => {},
                        }
                    }

                    let all_literal = !(has_agg_list || has_agg_scalar || has_not_agg);
                    let elementwise_must_aggregate =
                        has_not_agg && (has_agg_list || not_agg_groups_may_diverge);

                    if all_literal {
                        // Fast path.
                        self.apply_all_literal_elementwise(acs)
                    } else if has_agg_scalar && (has_agg_list || has_not_agg) {
                        // Not compatible.
                        self.apply_multiple_group_aware(acs, df)
                    } else if elementwise_must_aggregate && has_not_agg_with_overlapping_groups {
                        // Compatible, but calling aggregated() is too expensive.
                        self.apply_multiple_group_aware(acs, df)
                    } else if self.is_fallible
                        && acs.iter_mut().any(|ac| !ac.groups_cover_all_values())
                    {
                        // Fallible expression and there are elements that are masked out.
                        self.apply_multiple_group_aware(acs, df)
                    } else {
                        // Broadcast in NotAgg or AggList requires group_aware.
                        acs.iter_mut().filter(|ac| !ac.is_literal()).for_each(|ac| {
                            ac.groups();
                        });
                        let has_broadcast =
                            if let Some(base_ac_idx) = acs.iter().position(|ac| !ac.is_literal()) {
                                acs.iter()
                                    .enumerate()
                                    .filter(|(i, ac)| *i != base_ac_idx && !ac.is_literal())
                                    .any(|(_, ac)| {
                                        acs[base_ac_idx].groups.iter().zip(ac.groups.iter()).any(
                                            |(l, r)| {
                                                l.len() != r.len() && (l.len() == 1 || r.len() == 1)
                                            },
                                        )
                                    })
                            } else {
                                false
                            };
                        if has_broadcast {
                            // Broadcast fall-back.
                            self.apply_multiple_group_aware(acs, df)
                        } else {
                            self.apply_multiple_elementwise(acs, elementwise_must_aggregate)
                        }
                    }
                },
            }
        }
    }

    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
        Ok(self.output_field.clone())
    }
    fn is_scalar(&self) -> bool {
        self.flags.returns_scalar()
            || (self.function_operates_on_scalar && self.flags.is_length_preserving())
    }
}