Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/count.rs
6940 views
1
use std::borrow::Cow;
2
3
use polars_core::prelude::*;
4
use polars_plan::constants::LEN;
5
6
use super::*;
7
use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExpr};
8
9
pub struct CountExpr {
10
expr: Expr,
11
}
12
13
impl CountExpr {
14
pub(crate) fn new() -> Self {
15
Self { expr: Expr::Len }
16
}
17
}
18
19
impl PhysicalExpr for CountExpr {
20
fn as_expression(&self) -> Option<&Expr> {
21
Some(&self.expr)
22
}
23
24
fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column> {
25
Ok(Column::new_scalar(
26
PlSmallStr::from_static(LEN),
27
Scalar::from(df.height() as IdxSize),
28
1,
29
))
30
}
31
32
fn evaluate_on_groups<'a>(
33
&self,
34
_df: &DataFrame,
35
groups: &'a GroupPositions,
36
_state: &ExecutionState,
37
) -> PolarsResult<AggregationContext<'a>> {
38
let ca = groups.group_count().with_name(PlSmallStr::from_static(LEN));
39
let c = ca.into_column();
40
Ok(AggregationContext::new(c, Cow::Borrowed(groups), true))
41
}
42
43
fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
44
Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE))
45
}
46
47
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
48
Some(self)
49
}
50
51
fn is_scalar(&self) -> bool {
52
true
53
}
54
}
55
56
impl PartitionedAggregation for CountExpr {
57
#[allow(clippy::ptr_arg)]
58
fn evaluate_partitioned(
59
&self,
60
df: &DataFrame,
61
groups: &GroupPositions,
62
state: &ExecutionState,
63
) -> PolarsResult<Column> {
64
self.evaluate_on_groups(df, groups, state)
65
.map(|mut ac| ac.aggregated().into_column())
66
}
67
68
/// Called to merge all the partitioned results in a final aggregate.
69
#[allow(clippy::ptr_arg)]
70
fn finalize(
71
&self,
72
partitioned: Column,
73
groups: &GroupPositions,
74
_state: &ExecutionState,
75
) -> PolarsResult<Column> {
76
// SAFETY: groups are in bounds.
77
let agg = unsafe { partitioned.agg_sum(groups) };
78
Ok(agg.with_name(PlSmallStr::from_static(LEN)))
79
}
80
}
81
82