// NOTE(review): removed web-scrape page chrome (CoCalc navigation/header) that was not part of
// the source file, and the line-number artifacts interleaved by the extraction.
// Original file: crates/polars-expr/src/expressions/aggregation.rs (pola-rs/polars).

use std::borrow::Cow;

use arrow::legacy::utils::CustomIterTools;
use polars_compute::rolling::QuantileMethod;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::{_split_offsets, NoNull};
use polars_ops::prelude::ArgAgg;
#[cfg(feature = "propagate_nans")]
use polars_ops::prelude::nan_propagating_aggregate;
use polars_utils::itertools::Itertools;
use rayon::prelude::*;

use super::*;
use crate::expressions::AggState::AggregatedScalar;
use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
use crate::reduce::GroupedReduction;
/// Describes which group-by aggregation to run and whether evaluation over a
/// flat (non-grouped) column may use the thread pool.
#[derive(Debug, Clone, Copy)]
pub struct AggregationType {
    // The aggregation kind (min/max/sum/mean/count/...).
    pub(crate) groupby: GroupByMethod,
    // When true, `parallel_op_columns` may split the work over the rayon pool.
    pub(crate) allow_threading: bool,
}
/// Physical expression that applies a [`GroupByMethod`] aggregation to the
/// column produced by an inner physical expression.
pub(crate) struct AggregationExpr {
    // Expression producing the column to aggregate.
    pub(crate) input: Arc<dyn PhysicalExpr>,
    // Which aggregation to apply and its threading policy.
    pub(crate) agg_type: AggregationType,
    // Pre-computed output schema field; `to_field` returns it verbatim.
    pub(crate) output_field: Field,
}
impl AggregationExpr {
33
pub fn new(
34
expr: Arc<dyn PhysicalExpr>,
35
agg_type: AggregationType,
36
output_field: Field,
37
) -> Self {
38
Self {
39
input: expr,
40
agg_type,
41
output_field,
42
}
43
}
44
}
45
46
// Aggregation over either a flat column (`evaluate`) or per-group (`evaluate_on_groups`).
impl PhysicalExpr for AggregationExpr {
    fn as_expression(&self) -> Option<&Expr> {
        // This physical node has no logical-plan counterpart.
        None
    }

    /// Aggregates the whole input column down to a single-row (or, for
    /// `Implode`/`Groups`, list-shaped) result column.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let s = self.input.evaluate(df, state)?;

        let AggregationType {
            groupby,
            allow_threading,
        } = self.agg_type;

        // NaN-propagating min/max only differs from plain min/max on floats;
        // downgrade to the plain variant for non-float dtypes.
        let is_float = s.dtype().is_float();
        let group_by = match groupby {
            GroupByMethod::NanMin if !is_float => GroupByMethod::Min,
            GroupByMethod::NanMax if !is_float => GroupByMethod::Max,
            gb => gb,
        };

        match group_by {
            // Sorted columns skip the parallel split; the reduce is cheap enough as-is.
            GroupByMethod::Min => match s.is_sorted_flag() {
                IsSorted::Ascending | IsSorted::Descending => {
                    s.min_reduce().map(|sc| sc.into_column(s.name().clone()))
                },
                IsSorted::Not => parallel_op_columns(
                    |s| s.min_reduce().map(|sc| sc.into_column(s.name().clone())),
                    s,
                    allow_threading,
                ),
            },
            #[cfg(feature = "propagate_nans")]
            GroupByMethod::NanMin => parallel_op_columns(
                |s| {
                    Ok(polars_ops::prelude::nan_propagating_aggregate::nan_min_s(
                        s.as_materialized_series(),
                        s.name().clone(),
                    )
                    .into_column())
                },
                s,
                allow_threading,
            ),
            #[cfg(not(feature = "propagate_nans"))]
            GroupByMethod::NanMin => {
                panic!("activate 'propagate_nans' feature")
            },
            GroupByMethod::Max => match s.is_sorted_flag() {
                IsSorted::Ascending | IsSorted::Descending => {
                    s.max_reduce().map(|sc| sc.into_column(s.name().clone()))
                },
                IsSorted::Not => parallel_op_columns(
                    |s| s.max_reduce().map(|sc| sc.into_column(s.name().clone())),
                    s,
                    allow_threading,
                ),
            },
            #[cfg(feature = "propagate_nans")]
            GroupByMethod::NanMax => parallel_op_columns(
                |s| {
                    Ok(polars_ops::prelude::nan_propagating_aggregate::nan_max_s(
                        s.as_materialized_series(),
                        s.name().clone(),
                    )
                    .into_column())
                },
                s,
                allow_threading,
            ),
            #[cfg(not(feature = "propagate_nans"))]
            GroupByMethod::NanMax => {
                panic!("activate 'propagate_nans' feature")
            },
            GroupByMethod::Median => s.median_reduce().map(|sc| sc.into_column(s.name().clone())),
            GroupByMethod::Mean => s.mean_reduce().map(|sc| sc.into_column(s.name().clone())),
            // First/Last on an empty column yield a single null rather than an empty column.
            GroupByMethod::First => Ok(if s.is_empty() {
                Column::full_null(s.name().clone(), 1, s.dtype())
            } else {
                s.head(Some(1))
            }),
            GroupByMethod::FirstNonNull => Ok(s
                .as_materialized_series_maintain_scalar()
                .first_non_null()
                .into_column(s.name().clone())),
            GroupByMethod::Last => Ok(if s.is_empty() {
                Column::full_null(s.name().clone(), 1, s.dtype())
            } else {
                s.tail(Some(1))
            }),
            GroupByMethod::LastNonNull => Ok(s
                .as_materialized_series_maintain_scalar()
                .last_non_null()
                .into_column(s.name().clone())),
            // `Item` requires exactly one row (or zero when `allow_empty`).
            GroupByMethod::Item { allow_empty } => Ok(match s.len() {
                0 if allow_empty => Column::full_null(s.name().clone(), 1, s.dtype()),
                1 => s,
                n => polars_bail!(item_agg_count_not_one = n, allow_empty = allow_empty),
            }),
            GroupByMethod::Sum => parallel_op_columns(
                |s| s.sum_reduce().map(|sc| sc.into_column(s.name().clone())),
                s,
                allow_threading,
            ),
            // `Groups` only makes sense with group information; handled in `evaluate_on_groups`.
            GroupByMethod::Groups => unreachable!(),
            GroupByMethod::NUnique => s.n_unique().map(|count| {
                IdxCa::from_slice(s.name().clone(), &[count as IdxSize]).into_column()
            }),
            GroupByMethod::Count { include_nulls } => {
                // `(!include_nulls) as usize` is 0 or 1, so nulls are subtracted
                // only when they must be excluded from the count.
                let count = s.len() - s.null_count() * !include_nulls as usize;

                Ok(IdxCa::from_slice(s.name().clone(), &[count as IdxSize]).into_column())
            },
            GroupByMethod::Implode => s.implode().map(|ca| ca.into_column()),
            GroupByMethod::Std(ddof) => s
                .std_reduce(ddof)
                .map(|sc| sc.into_column(s.name().clone())),
            GroupByMethod::Var(ddof) => s
                .var_reduce(ddof)
                .map(|sc| sc.into_column(s.name().clone())),
            // Quantile is implemented by the dedicated `AggQuantileExpr`.
            GroupByMethod::Quantile(_, _) => unimplemented!(),
            // arg_min/arg_max return a null row when no extremum exists (e.g. empty input).
            GroupByMethod::ArgMin => {
                let opt = s.as_materialized_series().arg_min();
                Ok(opt.map_or_else(
                    || Column::full_null(s.name().clone(), 1, &IDX_DTYPE),
                    |idx| {
                        Column::new_scalar(
                            s.name().clone(),
                            Scalar::new_idxsize(idx.try_into().unwrap()),
                            1,
                        )
                    },
                ))
            },
            GroupByMethod::ArgMax => {
                let opt = s.as_materialized_series().arg_max();
                Ok(opt.map_or_else(
                    || Column::full_null(s.name().clone(), 1, &IDX_DTYPE),
                    |idx| {
                        Column::new_scalar(
                            s.name().clone(),
                            Scalar::new_idxsize(idx.try_into().unwrap()),
                            1,
                        )
                    },
                ))
            },
        }
    }

    /// Aggregates the input per group, producing one scalar per group
    /// (`AggregatedScalar`) in the returned aggregation context.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut ac = self.input.evaluate_on_groups(df, groups, state)?;

        // don't change names by aggregations as is done in polars-core
        let keep_name = ac.get_values().name().clone();

        // A literal is identical for every group: aggregate it once via the
        // flat `evaluate` path and reuse the result.
        if let AggState::LiteralScalar(c) = &mut ac.state {
            *c = self.evaluate(df, state)?;
            return Ok(ac);
        }

        // AggregatedScalar has no defined group structure. We fix it up here, so that we can
        // reliably call `agg_*` functions with the groups.
        ac.set_groups_for_undefined_agg_states();

        // SAFETY:
        // groups must always be in bounds.
        let out = unsafe {
            match self.agg_type.groupby {
                GroupByMethod::Min => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_min(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Max => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_max(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::ArgMin => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_arg_min(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::ArgMax => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_arg_max(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Median => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_median(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Mean => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_mean(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Sum => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_sum(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Count { include_nulls } => {
                    // With no nulls (or nulls included) the count equals each
                    // group's length, so the values never need to be touched.
                    if include_nulls || ac.get_values().null_count() == 0 {
                        // a few fast paths that prevent materializing new groups
                        match ac.update_groups {
                            UpdateGroups::WithSeriesLen => {
                                let list = ac
                                    .get_values()
                                    .list()
                                    .expect("impl error, should be a list at this point");

                                let mut s = match list.chunks().len() {
                                    // Single chunk: derive lengths directly from
                                    // consecutive list offsets, no iteration over values.
                                    1 => {
                                        let arr = list.downcast_iter().next().unwrap();
                                        let offsets = arr.offsets().as_slice();

                                        let mut previous = 0i64;
                                        let counts: NoNull<IdxCa> = offsets[1..]
                                            .iter()
                                            .map(|&o| {
                                                let len = (o - previous) as IdxSize;
                                                previous = o;
                                                len
                                            })
                                            .collect_trusted();
                                        counts.into_inner()
                                    },
                                    _ => {
                                        let counts: NoNull<IdxCa> = list
                                            .amortized_iter()
                                            .map(|s| {
                                                if let Some(s) = s {
                                                    s.as_ref().len() as IdxSize
                                                } else {
                                                    1
                                                }
                                            })
                                            .collect_trusted();
                                        counts.into_inner()
                                    },
                                };
                                s.rename(keep_name);
                                AggregatedScalar(s.into_column())
                            },
                            UpdateGroups::WithGroupsLen => {
                                // no need to update the groups
                                // we can just get the attribute, because we only need the length,
                                // not the correct order
                                let mut ca = ac.groups.group_count();
                                ca.rename(keep_name);
                                AggregatedScalar(ca.into_column())
                            },
                            // materialize groups
                            _ => {
                                let mut ca = ac.groups().group_count();
                                ca.rename(keep_name);
                                AggregatedScalar(ca.into_column())
                            },
                        }
                    } else {
                        // Nulls must be excluded: count valid values per group.
                        // TODO: optimize this/and write somewhere else.
                        match ac.agg_state() {
                            AggState::LiteralScalar(_) => unreachable!(),
                            // One value per group: its count is 0 or 1 depending on validity.
                            AggState::AggregatedScalar(c) => AggregatedScalar(
                                c.is_not_null().cast(&IDX_DTYPE).unwrap().into_column(),
                            ),
                            AggState::AggregatedList(s) => {
                                let ca = s.list()?;
                                let out: IdxCa = ca
                                    .into_iter()
                                    .map(|opt_s| {
                                        opt_s
                                            .map(|s| s.len() as IdxSize - s.null_count() as IdxSize)
                                    })
                                    .collect();
                                AggregatedScalar(out.into_column().with_name(keep_name))
                            },
                            AggState::NotAggregated(s) => {
                                let s = s.clone();
                                let groups = ac.groups();
                                let out: IdxCa = if matches!(s.dtype(), &DataType::Null) {
                                    // All-null dtype: every group counts zero valid values.
                                    IdxCa::full(s.name().clone(), 0, groups.len())
                                } else {
                                    match groups.as_ref().as_ref() {
                                        GroupsType::Idx(idx) => {
                                            let s = s.rechunk();
                                            // @scalar-opt
                                            // @partition-opt
                                            // null_count > 0 here, so a validity bitmap exists.
                                            let array = &s.as_materialized_series().chunks()[0];
                                            let validity = array.validity().unwrap();
                                            idx.iter()
                                                .map(|(_, g)| {
                                                    let mut count = 0 as IdxSize;
                                                    // Count valid values
                                                    g.iter().for_each(|i| {
                                                        count += validity
                                                            .get_bit_unchecked(*i as usize)
                                                            as IdxSize;
                                                    });
                                                    count
                                                })
                                                .collect_ca_trusted_with_dtype(keep_name, IDX_DTYPE)
                                        },
                                        GroupsType::Slice { groups, .. } => {
                                            // Slice and use computed null count
                                            groups
                                                .iter()
                                                .map(|g| {
                                                    let start = g[0];
                                                    let len = g[1];
                                                    len - s
                                                        .slice(start as i64, len as usize)
                                                        .null_count()
                                                        as IdxSize
                                                })
                                                .collect_ca_trusted_with_dtype(keep_name, IDX_DTYPE)
                                        },
                                    }
                                };
                                AggregatedScalar(out.into_column())
                            },
                        }
                    }
                },
                GroupByMethod::First => {
                    let (s, groups) = ac.get_final_aggregation();
                    let agg_s = s.agg_first(&groups);
                    AggregatedScalar(agg_s.with_name(keep_name))
                },
                GroupByMethod::FirstNonNull => {
                    let (s, groups) = ac.get_final_aggregation();
                    let agg_s = s.agg_first_non_null(&groups);
                    AggregatedScalar(agg_s.with_name(keep_name))
                },
                GroupByMethod::Last => {
                    let (s, groups) = ac.get_final_aggregation();
                    let agg_s = s.agg_last(&groups);
                    AggregatedScalar(agg_s.with_name(keep_name))
                },
                GroupByMethod::LastNonNull => {
                    let (s, groups) = ac.get_final_aggregation();
                    let agg_s = s.agg_last_non_null(&groups);
                    AggregatedScalar(agg_s.with_name(keep_name))
                },
                GroupByMethod::Item { allow_empty } => {
                    let (s, groups) = ac.get_final_aggregation();
                    // Every group must contain exactly one element
                    // (or zero when `allow_empty`); otherwise error out.
                    for gc in groups.group_count().iter() {
                        match gc {
                            Some(0) if allow_empty => continue,
                            None | Some(1) => continue,
                            Some(n) => {
                                polars_bail!(item_agg_count_not_one = n, allow_empty = allow_empty);
                            },
                        }
                    }
                    let agg_s = s.agg_first(&groups);
                    AggregatedScalar(agg_s.with_name(keep_name))
                },
                GroupByMethod::NUnique => {
                    let (s, groups) = ac.get_final_aggregation();
                    let agg_s = s.agg_n_unique(&groups);
                    AggregatedScalar(agg_s.with_name(keep_name))
                },
                GroupByMethod::Implode => AggregatedScalar(match ac.agg_state() {
                    AggState::LiteralScalar(_) => unreachable!(), // handled above
                    AggState::AggregatedScalar(c) => c.as_list().into_column(),
                    AggState::NotAggregated(_) | AggState::AggregatedList(_) => ac.aggregated(),
                }),
                GroupByMethod::Groups => {
                    let mut column: ListChunked = ac.groups().as_list_chunked();
                    column.rename(keep_name);
                    AggregatedScalar(column.into_column())
                },
                GroupByMethod::Std(ddof) => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_std(&groups, ddof);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Var(ddof) => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_var(&groups, ddof);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Quantile(_, _) => {
                    // implemented explicitly in AggQuantile struct
                    unimplemented!()
                },
                GroupByMethod::NanMin => {
                    #[cfg(feature = "propagate_nans")]
                    {
                        let (c, groups) = ac.get_final_aggregation();
                        // NaN propagation only applies to floats; other dtypes
                        // fall back to the regular min.
                        let agg_c = if c.dtype().is_float() {
                            nan_propagating_aggregate::group_agg_nan_min_s(
                                c.as_materialized_series(),
                                &groups,
                            )
                            .into_column()
                        } else {
                            c.agg_min(&groups)
                        };
                        AggregatedScalar(agg_c.with_name(keep_name))
                    }
                    #[cfg(not(feature = "propagate_nans"))]
                    {
                        panic!("activate 'propagate_nans' feature")
                    }
                },
                GroupByMethod::NanMax => {
                    #[cfg(feature = "propagate_nans")]
                    {
                        let (c, groups) = ac.get_final_aggregation();
                        let agg_c = if c.dtype().is_float() {
                            nan_propagating_aggregate::group_agg_nan_max_s(
                                c.as_materialized_series(),
                                &groups,
                            )
                            .into_column()
                        } else {
                            c.agg_max(&groups)
                        };
                        AggregatedScalar(agg_c.with_name(keep_name))
                    }
                    #[cfg(not(feature = "propagate_nans"))]
                    {
                        panic!("activate 'propagate_nans' feature")
                    }
                },
            }
        };

        Ok(AggregationContext::from_agg_state(
            out,
            Cow::Borrowed(groups),
        ))
    }

    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
        // Output schema was resolved when this node was built.
        Ok(self.output_field.clone())
    }

    fn is_scalar(&self) -> bool {
        // Aggregations reduce each group to a single value.
        true
    }
}
/// Physical expression computing a quantile aggregation; quantile values come
/// from a second expression (a single quantile or a list of quantiles).
pub struct AggQuantileExpr {
    // Expression producing the column to take quantiles of.
    input: Arc<dyn PhysicalExpr>,
    // Expression producing the quantile value(s) in [0, 1].
    quantile: Arc<dyn PhysicalExpr>,
    // Interpolation method used by the quantile computation.
    method: QuantileMethod,
}
impl AggQuantileExpr {
506
pub fn new(
507
input: Arc<dyn PhysicalExpr>,
508
quantile: Arc<dyn PhysicalExpr>,
509
method: QuantileMethod,
510
) -> Self {
511
Self {
512
input,
513
quantile,
514
method,
515
}
516
}
517
}
518
519
impl PhysicalExpr for AggQuantileExpr {
    fn as_expression(&self) -> Option<&Expr> {
        // This physical node has no logical-plan counterpart.
        None
    }

    /// Computes the quantile(s) of the flat input column. A list-typed
    /// quantile expression yields multiple quantiles at once.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let input = self.input.evaluate(df, state)?;

        let quantile = self.quantile.evaluate(df, state)?;

        // The quantile expression must collapse to one row (a scalar or one list).
        polars_ensure!(quantile.len() <= 1, ComputeError:
            "polars does not support varying quantiles yet, \
make sure the 'quantile' expression input produces a single quantile or a list of quantiles"
        );

        let s = quantile.as_materialized_series();

        match s.dtype() {
            // A list of quantiles: compute them all in one reduce.
            DataType::List(_) => {
                let list = s.list()?;
                let inner_s = list.get_as_series(0).unwrap();
                if inner_s.has_nulls() {
                    polars_bail!(ComputeError: "quantile expression contains null values");
                }

                let v: Vec<f64> = inner_s
                    .cast(&DataType::Float64)?
                    .f64()?
                    .into_no_null_iter()
                    .collect();

                input
                    .quantiles_reduce(&v, self.method)
                    .map(|sc| sc.into_column(input.name().clone()))
            },
            // A single scalar quantile.
            _ => {
                let q: f64 = quantile.get(0).unwrap().try_extract()?;
                input
                    .quantile_reduce(q, self.method)
                    .map(|sc| sc.into_column(input.name().clone()))
            },
        }
    }

    /// Computes a single quantile per group.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut ac = self.input.evaluate_on_groups(df, groups, state)?;

        // AggregatedScalar has no defined group structure. We fix it up here, so that we can
        // reliably call `agg_quantile` functions with the groups.
        ac.set_groups_for_undefined_agg_states();

        // don't change names by aggregations as is done in polars-core
        let keep_name = ac.get_values().name().clone();

        let quantile_column = self.quantile.evaluate(df, state)?;
        // Unlike the flat path, only a single scalar quantile is supported here.
        polars_ensure!(quantile_column.len() <= 1, ComputeError:
            "polars only supports computing a single quantile in a groupby aggregation context"
        );
        let quantile: f64 = quantile_column.get(0).unwrap().try_extract()?;

        // A literal is identical for every group: reduce it once and return.
        if let AggState::LiteralScalar(c) = &mut ac.state {
            *c = c
                .quantile_reduce(quantile, self.method)?
                .into_column(keep_name);
            return Ok(ac);
        }

        // SAFETY:
        // groups are in bounds
        let mut agg = unsafe {
            ac.flat_naive()
                .into_owned()
                .agg_quantile(ac.groups(), quantile, self.method)
        };
        agg.rename(keep_name);
        Ok(AggregationContext::from_agg_state(
            AggregatedScalar(agg),
            Cow::Borrowed(groups),
        ))
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        // If the quantile expression is a literal that yields a list of floats,
        // the aggregation returns a list of quantiles (one list per row/group).
        // In that case, report `List(Float64)` as the output field.
        let input_field = self.input.to_field(input_schema)?;
        match self.quantile.to_field(input_schema) {
            Ok(qf) => match qf.dtype() {
                DataType::List(inner) => {
                    if inner.is_float() {
                        Ok(Field::new(
                            input_field.name().clone(),
                            DataType::List(Box::new(DataType::Float64)),
                        ))
                    } else {
                        // fallback to input field
                        Ok(input_field)
                    }
                },
                _ => Ok(input_field),
            },
            // Unresolvable quantile field: fall back to the input's field.
            Err(_) => Ok(input_field),
        }
    }

    fn is_scalar(&self) -> bool {
        // Quantile reduces each group to a single value (or single list).
        true
    }
}
/// Physical expression for `min_by`/`max_by`: selects the value of `input` at
/// the position where `by` attains its extremum.
pub struct AggMinMaxByExpr {
    // Expression whose value is returned.
    input: Arc<dyn PhysicalExpr>,
    // Expression whose extremum determines which row is picked.
    by: Arc<dyn PhysicalExpr>,
    // true => max_by, false => min_by.
    is_max_by: bool,
}
impl AggMinMaxByExpr {
642
pub fn new_min_by(input: Arc<dyn PhysicalExpr>, by: Arc<dyn PhysicalExpr>) -> Self {
643
Self {
644
input,
645
by,
646
is_max_by: false,
647
}
648
}
649
650
pub fn new_max_by(input: Arc<dyn PhysicalExpr>, by: Arc<dyn PhysicalExpr>) -> Self {
651
Self {
652
input,
653
by,
654
is_max_by: true,
655
}
656
}
657
}
658
659
impl PhysicalExpr for AggMinMaxByExpr {
    fn as_expression(&self) -> Option<&Expr> {
        // This physical node has no logical-plan counterpart.
        None
    }

    /// Flat (non-grouped) min_by/max_by: one-row result holding the value of
    /// `input` at the position of `by`'s extremum, or null when no extremum
    /// exists (e.g. empty input).
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let input = self.input.evaluate(df, state)?;
        let by = self.by.evaluate(df, state)?;
        let name = if self.is_max_by { "max_by" } else { "min_by" };
        polars_ensure!(
            input.len() == by.len(),
            ShapeMismatch: "'by' column in {} expression has incorrect length: expected {}, got {}",
            name, input.len(), by.len()
        );
        let arg_extremum = if self.is_max_by {
            by.as_materialized_series_maintain_scalar().arg_max()
        } else {
            by.as_materialized_series_maintain_scalar().arg_min()
        };
        let out = if let Some(idx) = arg_extremum {
            input.slice(idx as i64, 1)
        } else {
            // No extremum found: return a single null of the input's dtype.
            let dtype = input.dtype().clone();
            Column::new_scalar(input.name().clone(), Scalar::null(dtype), 1)
        };
        Ok(out)
    }

    /// Grouped min_by/max_by: per group, find the extremum position within
    /// `by` and gather the corresponding value from `input`.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let ac = self.input.evaluate_on_groups(df, groups, state)?;
        let ac_by = self.by.evaluate_on_groups(df, groups, state)?;
        assert!(ac.groups.len() == ac_by.groups.len());

        // Don't change names by aggregations as is done in polars-core
        let keep_name = ac.get_values().name().clone();

        let (input_col, input_groups) = ac.get_final_aggregation();
        let (by_col, by_groups) = ac_by.get_final_aggregation();
        GroupsType::check_lengths(&input_groups, &by_groups)?;

        // Dispatch to arg_min/arg_max and then gather
        // SAFETY: Groups are correct.
        let idxs_in_groups = if self.is_max_by {
            unsafe { by_col.agg_arg_max(&by_groups) }
        } else {
            unsafe { by_col.agg_arg_min(&by_groups) }
        };
        let idxs_in_groups: &IdxCa = idxs_in_groups.as_materialized_series().as_ref().as_ref();
        // Translate each per-group index into a flat index into `input_col`.
        let flat_gather_idxs = match input_groups.as_ref().as_ref() {
            GroupsType::Idx(g) => idxs_in_groups
                .into_no_null_iter()
                .enumerate()
                .map(|(group_idx, idx_in_group)| g.all()[group_idx][idx_in_group as usize])
                .collect_vec(),
            GroupsType::Slice { groups, .. } => idxs_in_groups
                .into_no_null_iter()
                .enumerate()
                // Contiguous groups: flat index is the group's start plus the offset.
                .map(|(group_idx, idx_in_group)| groups[group_idx][0] + idx_in_group)
                .collect_vec(),
        };

        // SAFETY: All indices are within input_col's groups.
        let gathered = unsafe { input_col.take_slice_unchecked(&flat_gather_idxs) };
        let agg_state = AggregatedScalar(gathered.with_name(keep_name));
        Ok(AggregationContext::from_agg_state(
            agg_state,
            Cow::Borrowed(groups),
        ))
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        // Output has the `input` expression's field; `by` only picks the row.
        self.input.to_field(input_schema)
    }

    fn is_scalar(&self) -> bool {
        // min_by/max_by reduce each group to a single value.
        true
    }
}
/// Physical expression that delegates the aggregation to a user/engine
/// supplied [`GroupedReduction`] instead of a built-in `GroupByMethod`.
pub(crate) struct AnonymousAggregationExpr {
    // Input expressions; currently only a single input is supported.
    pub(crate) inputs: Vec<Arc<dyn PhysicalExpr>>,
    // Prototype reduction; cloned via `new_empty` for each evaluation.
    pub(crate) grouped_reduction: Box<dyn GroupedReduction>,
    // Pre-computed output schema field; `to_field` returns it verbatim.
    pub(crate) output_field: Field,
}
impl AnonymousAggregationExpr {
751
pub fn new(
752
inputs: Vec<Arc<dyn PhysicalExpr>>,
753
grouped_reduction: Box<dyn GroupedReduction>,
754
output_field: Field,
755
) -> Self {
756
Self {
757
inputs,
758
grouped_reduction,
759
output_field,
760
}
761
}
762
}
763
764
impl PhysicalExpr for AnonymousAggregationExpr {
    fn as_expression(&self) -> Option<&Expr> {
        // This physical node has no logical-plan counterpart.
        None
    }

    /// Flat evaluation: feeds the whole input column into a fresh reduction as
    /// a single group (group 0) and finalizes it to a one-value column.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        polars_ensure!(
            self.inputs.len() == 1,
            ComputeError: "AnonymousAggregationExpr with more than one input is not supported"
        );

        let col = self.inputs[0].evaluate(df, state)?;
        let mut gr = self.grouped_reduction.new_empty();
        gr.resize(1);
        gr.update_group(&[&col], 0, 0)?;
        let out_series = gr.finalize()?;
        Ok(Column::new(col.name().clone(), out_series))
    }

    /// Grouped evaluation: maps every input row to its group index and updates
    /// the reduction for all groups in one call.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        polars_ensure!(
            self.inputs.len() == 1,
            ComputeError: "AnonymousAggregationExpr with more than one input is not supported"
        );

        let input = &self.inputs[0];
        let mut ac = input.evaluate_on_groups(df, groups, state)?;

        // don't change names by aggregations as is done in polars-core
        let input_column_name = ac.get_values().name().clone();

        // A literal is identical for every group: aggregate it once via the
        // flat `evaluate` path and reuse the result.
        if let AggState::LiteralScalar(input_column) = &mut ac.state {
            *input_column = self.evaluate(df, state)?;
            return Ok(ac);
        }

        let (input_column, resolved_groups) = ac.get_final_aggregation();

        let mut gr = self.grouped_reduction.new_empty();
        gr.resize(groups.len() as IdxSize);

        // Overlapping groups would assign one row to several groups, which the
        // row -> group-index mapping below cannot represent.
        assert!(
            !resolved_groups.is_overlapping(),
            "Aggregating with overlapping groups is a logic error"
        );

        // Every row participates, so the subset is simply 0..len.
        let subset = (0..input_column.len() as IdxSize).collect::<Vec<IdxSize>>();

        // Build the inverse mapping: for each row, the index of its group.
        let mut group_idxs = Vec::with_capacity(input_column.len());
        match &**resolved_groups {
            GroupsType::Idx(group_indices) => {
                group_idxs.resize(input_column.len(), 0);
                for (group_idx, indices_in_group) in group_indices.all().iter().enumerate() {
                    for pos in indices_in_group.iter() {
                        group_idxs[*pos as usize] = group_idx as IdxSize;
                    }
                }
            },
            GroupsType::Slice { groups, .. } => {
                // Contiguous groups: emit each group index `len` times in order.
                for (group_idx, [_start, len]) in groups.iter().enumerate() {
                    group_idxs.extend(std::iter::repeat_n(group_idx as IdxSize, *len as usize));
                }
            },
        };
        assert_eq!(group_idxs.len(), input_column.len());

        // `update_groups_subset` needs a single chunk.
        let input_column_rechunked = input_column.rechunk();

        // Single call so no need to resolve ordering.
        let seq_id = 0;

        // SAFETY:
        // - `subset` is in-bounds because it is 0..N
        // - `group_idxs` is in-bounds because we checked that it matches `input_column.len()` *and*
        //   is filled with values <= `input_column.len()` since they are derived from it via
        //   `enumerate`.
        unsafe {
            gr.update_groups_subset(&[&input_column_rechunked], &subset, &group_idxs, seq_id)?;
        }

        let out_series = gr.finalize()?;
        let out = AggregatedScalar(Column::new(input_column_name, out_series));

        Ok(AggregationContext::from_agg_state(out, resolved_groups))
    }

    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
        // Output schema was resolved when this node was built.
        Ok(self.output_field.clone())
    }

    fn is_scalar(&self) -> bool {
        // The reduction yields a single value per group.
        true
    }
}
/// Simple wrapper to parallelize functions that can be divided over threads aggregated and
/// finally aggregated in the main thread. This can be done for sum, min, max, etc.
///
/// `f` must be an associative reduce: it is applied once per slice and then
/// once more over the appended partial results. Falls back to a single
/// sequential call when threading is disallowed, the column is small, or the
/// current rayon thread already has queued work.
fn parallel_op_columns<F>(f: F, s: Column, allow_threading: bool) -> PolarsResult<Column>
where
    F: Fn(Column) -> PolarsResult<Column> + Send + Sync,
{
    // set during debug low so
    // we mimic production size data behavior
    #[cfg(debug_assertions)]
    let thread_boundary = 0;

    #[cfg(not(debug_assertions))]
    let thread_boundary = 100_000;

    // threading overhead/ splitting work stealing is costly..

    if !allow_threading
        || s.len() < thread_boundary
        || POOL.current_thread_has_pending_tasks().unwrap_or(false)
    {
        return f(s);
    }
    let n_threads = POOL.current_num_threads();
    let splits = _split_offsets(s.len(), n_threads);

    // Run `f` on each contiguous slice inside the rayon pool; the first
    // error short-circuits the collect.
    let chunks = POOL.install(|| {
        splits
            .into_par_iter()
            .map(|(offset, len)| {
                let s = s.slice(offset as i64, len);
                f(s)
            })
            .collect::<PolarsResult<Vec<_>>>()
    })?;

    // Combine the partial results: append them in physical representation
    // (so logical types like dates append as their backing integers), then
    // restore the logical dtype before the final reduce.
    let mut iter = chunks.into_iter();
    let first = iter.next().unwrap();
    let dtype = first.dtype();
    let out = iter.fold(first.to_physical_repr(), |mut acc, s| {
        acc.append(&s.to_physical_repr()).unwrap();
        acc
    });

    // SAFETY: `out` was produced from columns of `dtype` via `to_physical_repr`,
    // so converting back to `dtype` is valid.
    unsafe { f(out.from_physical_unchecked(dtype).unwrap()) }
}