Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/apply.rs
6940 views
1
use std::borrow::Cow;
2
3
use polars_core::POOL;
4
use polars_core::chunked_array::builder::get_list_builder;
5
use polars_core::chunked_array::from_iterator_par::{
6
ChunkedCollectParIterExt, try_list_from_par_iter,
7
};
8
use polars_core::prelude::*;
9
use rayon::prelude::*;
10
11
use super::*;
12
use crate::expressions::{
13
AggState, AggregationContext, PartitionedAggregation, PhysicalExpr, UpdateGroups,
14
};
15
16
#[derive(Clone)]
/// Physical expression that applies a user-defined function (`ColumnsUdf`)
/// over one or more input expressions.
pub struct ApplyExpr {
    /// Physical expressions producing the UDF's input columns.
    inputs: Vec<Arc<dyn PhysicalExpr>>,
    /// The function to apply to the evaluated inputs.
    function: SpecialEq<Arc<dyn ColumnsUdf>>,
    /// The logical expression this node was lowered from (used for errors
    /// and `to_field` schema resolution).
    expr: Expr,
    /// Function flags (elementwise, returns-scalar, pass-name, etc.),
    /// copied from `FunctionOptions::flags` in `new`.
    flags: FunctionFlags,
    /// Whether the wrapped function operates on scalar values; combined with
    /// length-preservation in `is_scalar`.
    function_operates_on_scalar: bool,
    /// Schema of the input `DataFrame`; returned by `get_input_schema`.
    input_schema: SchemaRef,
    /// If set, inputs/groups may be evaluated on the rayon thread pool.
    allow_threading: bool,
    /// Whether `map` output length must equal input length
    /// (from `FunctionOptions::check_lengths()`).
    check_lengths: bool,
    /// Pre-computed output field; its dtype (when known) is used to collect
    /// results without an extra inference pass.
    output_field: Field,
}
28
29
impl ApplyExpr {
    /// Build an `ApplyExpr` from its parts.
    ///
    /// Debug-asserts that `returns_scalar` and length-preservation are not
    /// both set, since they are mutually exclusive contracts.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        inputs: Vec<Arc<dyn PhysicalExpr>>,
        function: SpecialEq<Arc<dyn ColumnsUdf>>,
        expr: Expr,
        options: FunctionOptions,
        allow_threading: bool,
        input_schema: SchemaRef,
        output_field: Field,
        function_operates_on_scalar: bool,
    ) -> Self {
        debug_assert!(
            !options.is_length_preserving()
                || !options.flags.contains(FunctionFlags::RETURNS_SCALAR),
            "expr {expr:?} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive",
        );

        Self {
            inputs,
            function,
            expr,
            flags: options.flags,
            function_operates_on_scalar,
            input_schema,
            allow_threading,
            check_lengths: options.check_lengths(),
            output_field,
        }
    }

    /// Evaluate all input expressions on the given groups, in parallel on the
    /// rayon pool when `allow_threading` is set.
    #[allow(clippy::ptr_arg)]
    fn prepare_multiple_inputs<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Vec<AggregationContext<'a>>> {
        let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate_on_groups(df, groups, state);
        if self.allow_threading {
            POOL.install(|| self.inputs.par_iter().map(f).collect())
        } else {
            self.inputs.iter().map(f).collect()
        }
    }

    /// Store the per-group results `ca` back into the aggregation context.
    ///
    /// For scalar-returning functions the list is exploded to a flat column
    /// (erroring if any group produced more than one value); otherwise the
    /// list column is kept and group offsets are recomputed from series
    /// lengths.
    fn finish_apply_groups<'a>(
        &self,
        mut ac: AggregationContext<'a>,
        ca: ListChunked,
    ) -> PolarsResult<AggregationContext<'a>> {
        let c = if self.flags.returns_scalar() {
            let out = ca.explode(false).unwrap();
            // if the explode doesn't return the same len, it wasn't scalar.
            polars_ensure!(out.len() == ca.len(), InvalidOperation: "expected scalar for expr: {}, got {}", self.expr, &out);
            ac.update_groups = UpdateGroups::No;
            out.into_column()
        } else {
            ac.with_update_groups(UpdateGroups::WithSeriesLen);
            ca.into_series().into()
        };

        ac.with_values_and_args(c, true, None, false, self.flags.returns_scalar())?;

        Ok(ac)
    }

    /// Return the stored input schema (the `DataFrame` argument is unused).
    fn get_input_schema(&self, _df: &DataFrame) -> Cow<'_, Schema> {
        Cow::Borrowed(self.input_schema.as_ref())
    }

    /// Apply the UDF to `inputs`.
    ///
    /// The name is historical — it now simply forwards to `call_udf`, which
    /// returns a `Column` directly rather than an `Option<Column>`.
    fn eval_and_flatten(&self, inputs: &mut [Column]) -> PolarsResult<Column> {
        self.function.call_udf(inputs)
    }

    /// Group-aware application for a single input: run the UDF once per group
    /// on that group's values.
    fn apply_single_group_aware<'a>(
        &self,
        mut ac: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let s = ac.get_values();

        // Re-aggregating an already aggregated scalar (unless it is a list)
        // is an error.
        #[allow(clippy::nonminimal_bool)]
        {
            polars_ensure!(
                !(matches!(ac.agg_state(), AggState::AggregatedScalar(_)) && !s.dtype().is_list() ) ,
                expr = self.expr,
                ComputeError: "cannot aggregate, the column is already aggregated",
            );
        }

        let name = s.name().clone();
        let agg = ac.aggregated();
        // Collection of empty list leads to a null dtype. See: #3687.
        if agg.is_empty() {
            // Create input for the function to determine the output dtype, see #3946.
            let agg = agg.list().unwrap();
            let input_dtype = agg.inner_dtype();
            let input = Column::full_null(name.clone(), 0, input_dtype);

            let output = self.eval_and_flatten(&mut [input])?;
            let ca = ListChunked::full(name, output.as_materialized_series(), 0);
            return self.finish_apply_groups(ac, ca);
        }

        // Per-group closure: null groups pass through as null; otherwise the
        // UDF is applied to the group's series (optionally renamed first).
        let f = |opt_s: Option<Series>| match opt_s {
            None => Ok(None),
            Some(mut s) => {
                if self.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY) {
                    s.rename(name.clone());
                }
                Ok(Some(
                    self.function
                        .call_udf(&mut [Column::from(s)])?
                        .take_materialized_series(),
                ))
            },
        };

        let ca: ListChunked = if self.allow_threading {
            // Only pin the output dtype when it is known and non-null;
            // otherwise let the collect infer it.
            let dtype = if self.output_field.dtype.is_known() && !self.output_field.dtype.is_null()
            {
                Some(self.output_field.dtype.clone())
            } else {
                None
            };

            let lst = agg.list().unwrap();
            let iter = lst.par_iter().map(f);

            if let Some(dtype) = dtype {
                // @NOTE: Since the output type for scalars does an implicit explode, we need to
                // patch up the type here to also be a list.
                let out_dtype = if self.is_scalar() {
                    DataType::List(Box::new(dtype))
                } else {
                    dtype
                };

                let out: ListChunked = POOL.install(|| {
                    iter.collect_ca_with_dtype::<PolarsResult<_>>(PlSmallStr::EMPTY, out_dtype)
                })?;
                out
            } else {
                POOL.install(|| try_list_from_par_iter(iter, PlSmallStr::EMPTY))?
            }
        } else {
            agg.list()
                .unwrap()
                .into_iter()
                .map(f)
                .collect::<PolarsResult<_>>()?
        };

        self.finish_apply_groups(ac, ca.with_name(name))
    }

    /// Apply elementwise e.g. ignore the group/list indices.
    fn apply_single_elementwise<'a>(
        &self,
        mut ac: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let (c, aggregated) = match ac.agg_state() {
            // Already aggregated to lists: apply the UDF to the inner values,
            // preserving the list structure.
            AggState::AggregatedList(c) => {
                let ca = c.list().unwrap();
                let out = ca.apply_to_inner(&|s| {
                    Ok(self
                        .eval_and_flatten(&mut [s.into_column()])?
                        .take_materialized_series())
                })?;
                (out.into_column(), true)
            },
            // Flat input: apply directly; elementwise functions must preserve
            // length here.
            AggState::NotAggregated(c) => {
                let (out, aggregated) = (self.eval_and_flatten(&mut [c.clone()])?, false);
                check_map_output_len(c.len(), out.len(), &self.expr)?;
                (out, aggregated)
            },
            // Remaining states (scalars/literals): map the state in place.
            agg_state => {
                ac.with_agg_state(agg_state.try_map(|s| self.eval_and_flatten(&mut [s.clone()]))?);
                return Ok(ac);
            },
        };

        ac.with_values_and_args(c, aggregated, Some(&self.expr), true, self.is_scalar())?;
        Ok(ac)
    }

    /// Group-aware application for multiple inputs: zip one value per input
    /// per group and call the UDF once per group.
    fn apply_multiple_group_aware<'a>(
        &self,
        mut acs: Vec<AggregationContext<'a>>,
        df: &DataFrame,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Reusable argument buffer, one slot per input.
        let mut container = vec![Default::default(); acs.len()];
        let schema = self.get_input_schema(df);
        let field = self.to_field(&schema)?;

        // Aggregate representation of the aggregation contexts,
        // then unpack the lists and finally create iterators from this list chunked arrays.
        let mut iters = acs
            .iter_mut()
            .map(|ac| ac.iter_groups(self.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY)))
            .collect::<Vec<_>>();

        // Length of the items to iterate over.
        let len = iters[0].size_hint().0;

        let ca = if len == 0 {
            // NOTE(review): `len == 0` in this branch, so the loop below never
            // runs; the builder only yields an empty list with the resolved
            // output dtype. The loop body looks vestigial — confirm upstream.
            let mut builder = get_list_builder(&field.dtype, len * 5, len, field.name);
            for _ in 0..len {
                container.clear();
                for iter in &mut iters {
                    match iter.next().unwrap() {
                        None => {
                            builder.append_null();
                        },
                        Some(s) => container.push(s.deep_clone().into()),
                    }
                }
                let out = self
                    .function
                    .call_udf(&mut container)
                    .map(|c| c.take_materialized_series())?;

                builder.append_series(&out)?
            }
            builder.finish()
        } else {
            // We still need this branch to materialize unknown/ data dependent types in eager. :(
            (0..len)
                .map(|_| {
                    container.clear();
                    for iter in &mut iters {
                        match iter.next().unwrap() {
                            // Any null input for this group yields a null
                            // output group.
                            None => return Ok(None),
                            Some(s) => container.push(s.deep_clone().into()),
                        }
                    }
                    Ok(Some(
                        self.function
                            .call_udf(&mut container)?
                            .take_materialized_series(),
                    ))
                })
                .collect::<PolarsResult<ListChunked>>()?
                .with_name(field.name.clone())
        };
        // In debug builds, verify the collected inner dtype matches the
        // schema-resolved field dtype (when that dtype is known).
        #[cfg(debug_assertions)]
        {
            let inner = ca.dtype().inner_dtype().unwrap();
            if field.dtype.is_known() {
                assert_eq!(inner, &field.dtype);
            }
        }

        // Release the borrows on `acs` before moving out of it.
        drop(iters);

        // Take the first aggregation context, as that is the input series.
        let ac = acs.swap_remove(0);
        self.finish_apply_groups(ac, ca)
    }
}
288
289
fn check_map_output_len(input_len: usize, output_len: usize, expr: &Expr) -> PolarsResult<()> {
290
polars_ensure!(
291
input_len == output_len, expr = expr, InvalidOperation:
292
"output length of `map` ({}) must be equal to the input length ({}); \
293
consider using `apply` instead", output_len, input_len
294
);
295
Ok(())
296
}
297
298
impl PhysicalExpr for ApplyExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    /// Evaluate all inputs (in parallel when threading is allowed and there
    /// is more than one), then apply the UDF. Unless the function may rename
    /// its output, the result keeps the first input's name.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate(df, state);
        let mut inputs = if self.allow_threading && self.inputs.len() > 1 {
            POOL.install(|| {
                self.inputs
                    .par_iter()
                    .map(f)
                    .collect::<PolarsResult<Vec<_>>>()
            })
        } else {
            self.inputs.iter().map(f).collect::<PolarsResult<Vec<_>>>()
        }?;

        if self.flags.contains(FunctionFlags::ALLOW_RENAME) {
            self.eval_and_flatten(&mut inputs)
        } else {
            let in_name = inputs[0].name().clone();
            Ok(self.eval_and_flatten(&mut inputs)?.with_name(in_name))
        }
    }

    /// Grouped evaluation: dispatch to the elementwise fast path or the
    /// group-aware path depending on the function flags and, for multiple
    /// inputs, on the aggregation states of the evaluated inputs.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        if self.inputs.len() == 1 {
            let ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

            match self.flags.is_elementwise() {
                false => self.apply_single_group_aware(ac),
                true => self.apply_single_elementwise(ac),
            }
        } else {
            let acs = self.prepare_multiple_inputs(df, groups, state)?;

            match self.flags.is_elementwise() {
                false => self.apply_multiple_group_aware(acs, df),
                true => {
                    // Even for elementwise functions, mixed aggregation
                    // states can make broadcasting across inputs unsound;
                    // inspect the states to decide.
                    let mut has_agg_list = false;
                    let mut has_agg_scalar = false;
                    let mut has_not_agg = false;
                    for ac in &acs {
                        match ac.state {
                            AggState::AggregatedList(_) => has_agg_list = true,
                            AggState::AggregatedScalar(_) => has_agg_scalar = true,
                            AggState::NotAggregated(_) => has_not_agg = true,
                            _ => {},
                        }
                    }
                    // Any list-aggregated input, or a mix of scalar-aggregated
                    // and unaggregated inputs, forces the group-aware path.
                    if has_agg_list || (has_agg_scalar && has_not_agg) {
                        self.apply_multiple_group_aware(acs, df)
                    } else {
                        apply_multiple_elementwise(
                            acs,
                            self.function.as_ref(),
                            &self.expr,
                            self.check_lengths,
                            self.is_scalar(),
                        )
                    }
                },
            }
        }
    }

    /// Resolve the output field from the wrapped logical expression.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.expr.to_field(input_schema)
    }

    /// Partitioned aggregation is only supported for single-input,
    /// elementwise functions.
    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        if self.inputs.len() == 1 && self.flags.is_elementwise() {
            Some(self)
        } else {
            None
        }
    }

    /// A result is scalar if the flags say so, or if the function operates on
    /// scalars while preserving length.
    fn is_scalar(&self) -> bool {
        self.flags.returns_scalar()
            || (self.function_operates_on_scalar && self.flags.is_length_preserving())
    }
}
386
387
/// Elementwise application of `function` over multiple inputs, ignoring group
/// boundaries. The first aggregation context is reused as the carrier for the
/// result.
fn apply_multiple_elementwise<'a>(
    mut acs: Vec<AggregationContext<'a>>,
    function: &dyn ColumnsUdf,
    expr: &Expr,
    check_lengths: bool,
    returns_scalar: bool,
) -> PolarsResult<AggregationContext<'a>> {
    match acs.first().unwrap().agg_state() {
        // A fast path that doesn't drop groups of the first arg.
        // This doesn't require group re-computation.
        AggState::AggregatedList(s) => {
            let ca = s.list().unwrap();

            // Remaining inputs are flattened once up front and reused for
            // every inner apply.
            let other = acs[1..]
                .iter()
                .map(|ac| ac.flat_naive().into_owned())
                .collect::<Vec<_>>();

            let out = ca.apply_to_inner(&|s| {
                let mut args = Vec::with_capacity(other.len() + 1);
                args.push(s.into());
                args.extend_from_slice(&other);
                Ok(function
                    .call_udf(&mut args)?
                    .as_materialized_series()
                    .clone())
            })?;
            let mut ac = acs.swap_remove(0);
            ac.with_values(out.into_column(), true, None)?;
            Ok(ac)
        },
        first_as => {
            // Literal scalars broadcast, so length checking doesn't apply.
            let check_lengths = check_lengths && !matches!(first_as, AggState::LiteralScalar(_));
            // The result counts as aggregated when every input is aggregated
            // or a literal, and at least one is actually aggregated.
            let aggregated = acs.iter().all(|ac| ac.is_aggregated() | ac.is_literal())
                && acs.iter().any(|ac| ac.is_aggregated());
            let mut c = acs
                .iter_mut()
                .enumerate()
                .map(|(i, ac)| {
                    // Make sure the groups are updated because we are about to throw away
                    // the series length information, only on the first iteration.
                    if let (0, UpdateGroups::WithSeriesLen) = (i, &ac.update_groups) {
                        ac.groups();
                    }

                    ac.flat_naive().into_owned()
                })
                .collect::<Vec<_>>();

            let input_len = c[0].len();
            let c = function.call_udf(&mut c)?;
            if check_lengths {
                check_map_output_len(input_len, c.len(), expr)?;
            }

            // Take the first aggregation context, as that is the input series.
            let mut ac = acs.swap_remove(0);
            ac.with_values_and_args(c, aggregated, None, true, returns_scalar)?;
            Ok(ac)
        },
    }
}
449
450
impl PartitionedAggregation for ApplyExpr {
451
fn evaluate_partitioned(
452
&self,
453
df: &DataFrame,
454
groups: &GroupPositions,
455
state: &ExecutionState,
456
) -> PolarsResult<Column> {
457
let a = self.inputs[0].as_partitioned_aggregator().unwrap();
458
let s = a.evaluate_partitioned(df, groups, state)?;
459
460
if self.flags.contains(FunctionFlags::ALLOW_RENAME) {
461
self.eval_and_flatten(&mut [s])
462
} else {
463
let in_name = s.name().clone();
464
Ok(self.eval_and_flatten(&mut [s])?.with_name(in_name))
465
}
466
}
467
468
fn finalize(
469
&self,
470
partitioned: Column,
471
_groups: &GroupPositions,
472
_state: &ExecutionState,
473
) -> PolarsResult<Column> {
474
Ok(partitioned)
475
}
476
}
477
478