Path: blob/main/crates/polars-expr/src/expressions/aggregation.rs
use std::borrow::Cow;

use arrow::array::*;
use arrow::compute::concatenate::concatenate;
use arrow::legacy::utils::CustomIterTools;
use arrow::offset::Offsets;
use polars_compute::rolling::QuantileMethod;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::{_split_offsets, NoNull};
#[cfg(feature = "propagate_nans")]
use polars_ops::prelude::nan_propagating_aggregate;
use rayon::prelude::*;

use super::*;
use crate::expressions::AggState::AggregatedScalar;
use crate::expressions::{
    AggState, AggregationContext, PartitionedAggregation, PhysicalExpr, UpdateGroups,
};

#[derive(Debug, Clone, Copy)]
pub struct AggregationType {
    pub(crate) groupby: GroupByMethod,
    pub(crate) allow_threading: bool,
}

pub(crate) struct AggregationExpr {
    pub(crate) input: Arc<dyn PhysicalExpr>,
    pub(crate) agg_type: AggregationType,
    field: Option<Field>,
}

impl AggregationExpr {
    pub fn new(
        expr: Arc<dyn PhysicalExpr>,
        agg_type: AggregationType,
        field: Option<Field>,
    ) -> Self {
        Self {
            input: expr,
            agg_type,
            field,
        }
    }
}

impl PhysicalExpr for AggregationExpr {
    fn as_expression(&self) -> Option<&Expr> {
        None
    }

    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let s = self.input.evaluate(df, state)?;

        let AggregationType {
            groupby,
            allow_threading,
        } = self.agg_type;

        let is_float = s.dtype().is_float();
        let group_by = match groupby {
            GroupByMethod::NanMin if !is_float => GroupByMethod::Min,
            GroupByMethod::NanMax if !is_float => GroupByMethod::Max,
            gb => gb,
        };

        match group_by {
            GroupByMethod::Min => match s.is_sorted_flag() {
                IsSorted::Ascending | IsSorted::Descending => {
                    s.min_reduce().map(|sc| sc.into_column(s.name().clone()))
                },
                IsSorted::Not => parallel_op_columns(
                    |s| s.min_reduce().map(|sc| sc.into_column(s.name().clone())),
                    s,
                    allow_threading,
                ),
            },
            #[cfg(feature = "propagate_nans")]
            GroupByMethod::NanMin => parallel_op_columns(
                |s| {
                    Ok(polars_ops::prelude::nan_propagating_aggregate::nan_min_s(
                        s.as_materialized_series(),
                        s.name().clone(),
                    )
                    .into_column())
                },
                s,
                allow_threading,
            ),
            #[cfg(not(feature = "propagate_nans"))]
            GroupByMethod::NanMin => {
                panic!("activate 'propagate_nans' feature")
            },
            GroupByMethod::Max => match s.is_sorted_flag() {
                IsSorted::Ascending | IsSorted::Descending => {
                    s.max_reduce().map(|sc| sc.into_column(s.name().clone()))
                },
                IsSorted::Not => parallel_op_columns(
                    |s| s.max_reduce().map(|sc| sc.into_column(s.name().clone())),
                    s,
                    allow_threading,
                ),
            },
            #[cfg(feature = "propagate_nans")]
            GroupByMethod::NanMax => parallel_op_columns(
                |s| {
                    Ok(polars_ops::prelude::nan_propagating_aggregate::nan_max_s(
                        s.as_materialized_series(),
                        s.name().clone(),
                    )
                    .into_column())
                },
                s,
                allow_threading,
            ),
            #[cfg(not(feature = "propagate_nans"))]
            GroupByMethod::NanMax => {
                panic!("activate 'propagate_nans' feature")
            },
            GroupByMethod::Median => s.median_reduce().map(|sc| sc.into_column(s.name().clone())),
            GroupByMethod::Mean => Ok(s.mean_reduce().into_column(s.name().clone())),
            GroupByMethod::First => Ok(if s.is_empty() {
                Column::full_null(s.name().clone(), 1, s.dtype())
            } else {
                s.head(Some(1))
            }),
            GroupByMethod::Last => Ok(if s.is_empty() {
                Column::full_null(s.name().clone(), 1, s.dtype())
            } else {
                s.tail(Some(1))
            }),
            GroupByMethod::Sum => parallel_op_columns(
                |s| s.sum_reduce().map(|sc| sc.into_column(s.name().clone())),
                s,
                allow_threading,
            ),
            GroupByMethod::Groups => unreachable!(),
            GroupByMethod::NUnique => s.n_unique().map(|count| {
                IdxCa::from_slice(s.name().clone(), &[count as IdxSize]).into_column()
            }),
            GroupByMethod::Count { include_nulls } => {
                let count = s.len() - s.null_count() * !include_nulls as usize;

                Ok(IdxCa::from_slice(s.name().clone(), &[count as IdxSize]).into_column())
            },
            GroupByMethod::Implode => s.implode().map(|ca| ca.into_column()),
            GroupByMethod::Std(ddof) => s
                .std_reduce(ddof)
                .map(|sc| sc.into_column(s.name().clone())),
            GroupByMethod::Var(ddof) => s
                .var_reduce(ddof)
                .map(|sc| sc.into_column(s.name().clone())),
            GroupByMethod::Quantile(_, _) => unimplemented!(),
        }
    }
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
        // Don't let aggregations change names, as is done in polars-core.
        let keep_name = ac.get_values().name().clone();

        // Literals cannot be aggregated except for implode.
        polars_ensure!(!matches!(ac.agg_state(), AggState::LiteralScalar(_)), ComputeError: "cannot aggregate a literal");

        if let AggregatedScalar(_) = ac.agg_state() {
            match self.agg_type.groupby {
                GroupByMethod::Implode => {},
                _ => {
                    polars_bail!(ComputeError: "cannot aggregate as {}, the column is already aggregated", self.agg_type.groupby);
                },
            }
        }

        // SAFETY:
        // groups must always be in bounds.
        let out = unsafe {
            match self.agg_type.groupby {
                GroupByMethod::Min => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_min(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Max => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_max(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Median => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_median(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Mean => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_mean(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Sum => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_sum(&groups);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Count { include_nulls } => {
                    if include_nulls || ac.get_values().null_count() == 0 {
                        // a few fast paths that prevent materializing new groups
                        match ac.update_groups {
                            UpdateGroups::WithSeriesLen => {
                                let list = ac
                                    .get_values()
                                    .list()
                                    .expect("impl error, should be a list at this point");

                                let mut s = match list.chunks().len() {
                                    1 => {
                                        let arr = list.downcast_iter().next().unwrap();
                                        let offsets = arr.offsets().as_slice();

                                        let mut previous = 0i64;
                                        let counts: NoNull<IdxCa> = offsets[1..]
                                            .iter()
                                            .map(|&o| {
                                                let len = (o - previous) as IdxSize;
                                                previous = o;
                                                len
                                            })
                                            .collect_trusted();
                                        counts.into_inner()
                                    },
                                    _ => {
                                        let counts: NoNull<IdxCa> = list
                                            .amortized_iter()
                                            .map(|s| {
                                                if let Some(s) = s {
                                                    s.as_ref().len() as IdxSize
                                                } else {
                                                    1
                                                }
                                            })
                                            .collect_trusted();
                                        counts.into_inner()
                                    },
                                };
                                s.rename(keep_name);
                                AggregatedScalar(s.into_column())
                            },
                            UpdateGroups::WithGroupsLen => {
                                // No need to update the groups:
                                // we can just get the attribute, because we only need the length,
                                // not the correct order.
                                let mut ca = ac.groups.group_count();
                                ca.rename(keep_name);
                                AggregatedScalar(ca.into_column())
                            },
                            // materialize groups
                            _ => {
                                let mut ca = ac.groups().group_count();
                                ca.rename(keep_name);
                                AggregatedScalar(ca.into_column())
                            },
                        }
                    } else {
                        // TODO: optimize this/and write somewhere else.
                        match ac.agg_state() {
                            AggState::LiteralScalar(s) | AggState::AggregatedScalar(s) => {
                                AggregatedScalar(Column::new(
                                    keep_name,
                                    [(s.len() as IdxSize - s.null_count() as IdxSize)],
                                ))
                            },
                            AggState::AggregatedList(s) => {
                                let ca = s.list()?;
                                let out: IdxCa = ca
                                    .into_iter()
                                    .map(|opt_s| {
                                        opt_s
                                            .map(|s| s.len() as IdxSize - s.null_count() as IdxSize)
                                    })
                                    .collect();
                                AggregatedScalar(out.into_column().with_name(keep_name))
                            },
                            AggState::NotAggregated(s) => {
                                let s = s.clone();
                                let groups = ac.groups();
                                let out: IdxCa = if matches!(s.dtype(), &DataType::Null) {
                                    IdxCa::full(s.name().clone(), 0, groups.len())
                                } else {
                                    match groups.as_ref().as_ref() {
                                        GroupsType::Idx(idx) => {
                                            let s = s.rechunk();
                                            // @scalar-opt
                                            // @partition-opt
                                            let array = &s.as_materialized_series().chunks()[0];
                                            let validity = array.validity().unwrap();
                                            idx.iter()
                                                .map(|(_, g)| {
                                                    let mut count = 0 as IdxSize;
                                                    // Count valid values
                                                    g.iter().for_each(|i| {
                                                        count += validity
                                                            .get_bit_unchecked(*i as usize)
                                                            as IdxSize;
                                                    });
                                                    count
                                                })
                                                .collect_ca_trusted_with_dtype(keep_name, IDX_DTYPE)
                                        },
                                        GroupsType::Slice { groups, .. } => {
                                            // Slice and use computed null count
                                            groups
                                                .iter()
                                                .map(|g| {
                                                    let start = g[0];
                                                    let len = g[1];
                                                    len - s
                                                        .slice(start as i64, len as usize)
                                                        .null_count()
                                                        as IdxSize
                                                })
                                                .collect_ca_trusted_with_dtype(keep_name, IDX_DTYPE)
                                        },
                                    }
                                };
                                AggregatedScalar(out.into_column())
                            },
                        }
                    }
                },
                GroupByMethod::First => {
                    let (s, groups) = ac.get_final_aggregation();
                    let agg_s = s.agg_first(&groups);
                    AggregatedScalar(agg_s.with_name(keep_name))
                },
                GroupByMethod::Last => {
                    let (s, groups) = ac.get_final_aggregation();
                    let agg_s = s.agg_last(&groups);
                    AggregatedScalar(agg_s.with_name(keep_name))
                },
                GroupByMethod::NUnique => {
                    let (s, groups) = ac.get_final_aggregation();
                    let agg_s = s.agg_n_unique(&groups);
                    AggregatedScalar(agg_s.with_name(keep_name))
                },
                GroupByMethod::Implode => {
                    // If the aggregation is already in an aggregate flat state (AggregatedScalar), for instance by
                    // a mean() aggregation, we simply wrap into a list and maintain the AggregatedScalar state.
                    //
                    // If it is not, we traverse the groups and create a list per group.
                    let c = match ac.agg_state() {
                        // mean agg:
                        // -> f64 -> list<f64>
                        AggregatedScalar(c) => c
                            .cast(&DataType::List(Box::new(c.dtype().clone())))
                            .unwrap(),
                        // Auto-imploded
                        AggState::NotAggregated(_) | AggState::AggregatedList(_) => {
                            ac._implode_no_agg();
                            return Ok(ac);
                        },
                        _ => {
                            let agg = ac.aggregated();
                            agg.as_list().into_column()
                        },
                    };
                    match ac.agg_state() {
                        // An imploded scalar remains a scalar
                        AggregatedScalar(_) => AggregatedScalar(c.with_name(keep_name)),
                        _ => AggState::AggregatedList(c.with_name(keep_name)),
                    }
                },
                GroupByMethod::Groups => {
                    let mut column: ListChunked = ac.groups().as_list_chunked();
                    column.rename(keep_name);
                    AggregatedScalar(column.into_column())
                },
                GroupByMethod::Std(ddof) => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_std(&groups, ddof);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Var(ddof) => {
                    let (c, groups) = ac.get_final_aggregation();
                    let agg_c = c.agg_var(&groups, ddof);
                    AggregatedScalar(agg_c.with_name(keep_name))
                },
                GroupByMethod::Quantile(_, _) => {
                    // implemented explicitly in AggQuantile struct
                    unimplemented!()
                },
                GroupByMethod::NanMin => {
                    #[cfg(feature = "propagate_nans")]
                    {
                        let (c, groups) = ac.get_final_aggregation();
                        let agg_c = if c.dtype().is_float() {
                            nan_propagating_aggregate::group_agg_nan_min_s(
                                c.as_materialized_series(),
                                &groups,
                            )
                            .into_column()
                        } else {
                            c.agg_min(&groups)
                        };
                        AggregatedScalar(agg_c.with_name(keep_name))
                    }
                    #[cfg(not(feature = "propagate_nans"))]
                    {
                        panic!("activate 'propagate_nans' feature")
                    }
                },
                GroupByMethod::NanMax => {
                    #[cfg(feature = "propagate_nans")]
                    {
                        let (c, groups) = ac.get_final_aggregation();
                        let agg_c = if c.dtype().is_float() {
                            nan_propagating_aggregate::group_agg_nan_max_s(
                                c.as_materialized_series(),
                                &groups,
                            )
                            .into_column()
                        } else {
                            c.agg_max(&groups)
                        };
                        AggregatedScalar(agg_c.with_name(keep_name))
                    }
                    #[cfg(not(feature = "propagate_nans"))]
                    {
                        panic!("activate 'propagate_nans' feature")
                    }
                },
            }
        };

        Ok(AggregationContext::from_agg_state(
            out,
            Cow::Borrowed(groups),
        ))
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        if let Some(field) = self.field.as_ref() {
            Ok(field.clone())
        } else {
            self.input.to_field(input_schema)
        }
    }

    fn is_scalar(&self) -> bool {
        true
    }

    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        Some(self)
    }
}

impl PartitionedAggregation for AggregationExpr {
    fn evaluate_partitioned(
        &self,
        df: &DataFrame,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        let expr = self.input.as_partitioned_aggregator().unwrap();
        let column = expr.evaluate_partitioned(df, groups, state)?;

        // SAFETY:
        // groups are in bounds
        unsafe {
            match self.agg_type.groupby {
                #[cfg(feature = "dtype-struct")]
                GroupByMethod::Mean => {
                    let new_name = column.name().clone();

                    // Ensure we don't overflow:
                    // all 8 and 16 bit integers are already upcast to int16 on `agg_sum`.
                    let mut agg_s = if matches!(column.dtype(), DataType::Int32 | DataType::UInt32)
                    {
                        column.cast(&DataType::Int64).unwrap().agg_sum(groups)
                    } else {
                        column.agg_sum(groups)
                    };
                    agg_s.rename(new_name.clone());

                    if !agg_s.dtype().is_primitive_numeric() {
                        Ok(agg_s)
                    } else {
                        let agg_s = match agg_s.dtype() {
                            DataType::Float32 => agg_s,
                            _ => agg_s.cast(&DataType::Float64).unwrap(),
                        };
                        let mut count_s = column.agg_valid_count(groups);
                        count_s.rename(PlSmallStr::from_static("__POLARS_COUNT"));
                        Ok(
                            StructChunked::from_columns(new_name, agg_s.len(), &[agg_s, count_s])
                                .unwrap()
                                .into_column(),
                        )
                    }
                },
                GroupByMethod::Implode => {
                    let new_name = column.name().clone();
                    let mut agg = column.agg_list(groups);
                    agg.rename(new_name);
                    Ok(agg)
                },
                GroupByMethod::First => {
                    let mut agg = column.agg_first(groups);
                    agg.rename(column.name().clone());
                    Ok(agg)
                },
                GroupByMethod::Last => {
                    let mut agg = column.agg_last(groups);
                    agg.rename(column.name().clone());
                    Ok(agg)
                },
                GroupByMethod::Max => {
                    let mut agg = column.agg_max(groups);
                    agg.rename(column.name().clone());
                    Ok(agg)
                },
                GroupByMethod::Min => {
                    let mut agg = column.agg_min(groups);
                    agg.rename(column.name().clone());
                    Ok(agg)
                },
                GroupByMethod::Sum => {
                    let mut agg = column.agg_sum(groups);
                    agg.rename(column.name().clone());
                    Ok(agg)
                },
                GroupByMethod::Count {
                    include_nulls: true,
                } => {
                    let mut ca = groups.group_count();
                    ca.rename(column.name().clone());
                    Ok(ca.into_column())
                },
                _ => {
                    unimplemented!()
                },
            }
        }
    }

    fn finalize(
        &self,
        partitioned: Column,
        groups: &GroupPositions,
        _state: &ExecutionState,
    ) -> PolarsResult<Column> {
        match self.agg_type.groupby {
            GroupByMethod::Count {
                include_nulls: true,
            }
            | GroupByMethod::Sum => {
                let mut agg = unsafe { partitioned.agg_sum(groups) };
                agg.rename(partitioned.name().clone());
                Ok(agg)
            },
            #[cfg(feature = "dtype-struct")]
            GroupByMethod::Mean => {
                let new_name = partitioned.name().clone();
                match partitioned.dtype() {
                    DataType::Struct(_) => {
                        let ca = partitioned.struct_().unwrap();
                        let fields = ca.fields_as_series();
                        let sum = &fields[0];
                        let count = &fields[1];
                        let (agg_count, agg_s) =
                            unsafe { POOL.join(|| count.agg_sum(groups), || sum.agg_sum(groups)) };

                        // Ensure that we don't divide by zero by masking out zeros.
                        let agg_count = agg_count.idx().unwrap();
                        let mask = agg_count.equal(0 as IdxSize);
                        let agg_count = agg_count.set(&mask, None).unwrap().into_series();

                        let agg_s = &agg_s / &agg_count.cast(agg_s.dtype()).unwrap();
                        Ok(agg_s?.with_name(new_name).into_column())
                    },
                    _ => Ok(Column::full_null(
                        new_name,
                        groups.len(),
                        partitioned.dtype(),
                    )),
                }
            },
            GroupByMethod::Implode => {
                // The groups are scattered over multiple groups/sub-dataframes;
                // we now must collect them into a single group.
                let ca = partitioned.list().unwrap();
                let new_name = partitioned.name().clone();

                let mut values = Vec::with_capacity(groups.len());
                let mut can_fast_explode = true;

                let mut offsets = Vec::<i64>::with_capacity(groups.len() + 1);
                let mut length_so_far = 0i64;
                offsets.push(length_so_far);

                let mut process_group = |ca: ListChunked| -> PolarsResult<()> {
                    let s = ca.explode(false)?;
                    length_so_far += s.len() as i64;
                    offsets.push(length_so_far);
                    values.push(s.chunks()[0].clone());

                    if s.is_empty() {
                        can_fast_explode = false;
                    }
                    Ok(())
                };

                match groups.as_ref() {
                    GroupsType::Idx(groups) => {
                        for (_, idx) in groups {
                            let ca = unsafe {
                                // SAFETY:
                                // The indexes of the group_by operation are never out of bounds
                                ca.take_unchecked(idx)
                            };
                            process_group(ca)?;
                        }
                    },
                    GroupsType::Slice { groups, .. } => {
                        for [first, len] in groups {
                            let len = *len as usize;
                            let ca = ca.slice(*first as i64, len);
                            process_group(ca)?;
                        }
                    },
                }

                let vals = values.iter().map(|arr| &**arr).collect::<Vec<_>>();
                let values = concatenate(&vals).unwrap();

                let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
                // SAFETY: offsets are monotonically increasing.
                let arr = ListArray::<i64>::new(
                    dtype,
                    unsafe { Offsets::new_unchecked(offsets).into() },
                    values,
                    None,
                );
                let mut ca = ListChunked::with_chunk(new_name, arr);
                if can_fast_explode {
                    ca.set_fast_explode()
                }
                Ok(ca.into_series().as_list().into_column())
            },
            GroupByMethod::First => {
                let mut agg = unsafe { partitioned.agg_first(groups) };
                agg.rename(partitioned.name().clone());
                Ok(agg)
            },
            GroupByMethod::Last => {
                let mut agg = unsafe { partitioned.agg_last(groups) };
                agg.rename(partitioned.name().clone());
                Ok(agg)
            },
            GroupByMethod::Max => {
                let mut agg = unsafe { partitioned.agg_max(groups) };
                agg.rename(partitioned.name().clone());
                Ok(agg)
            },
            GroupByMethod::Min => {
                let mut agg = unsafe { partitioned.agg_min(groups) };
                agg.rename(partitioned.name().clone());
                Ok(agg)
            },
            _ => unimplemented!(),
        }
    }
}

pub struct AggQuantileExpr {
    pub(crate) input: Arc<dyn PhysicalExpr>,
    pub(crate) quantile: Arc<dyn PhysicalExpr>,
    pub(crate) method: QuantileMethod,
}

impl AggQuantileExpr {
    pub fn new(
        input: Arc<dyn PhysicalExpr>,
        quantile: Arc<dyn PhysicalExpr>,
        method: QuantileMethod,
    ) -> Self {
        Self {
            input,
            quantile,
            method,
        }
    }

    fn get_quantile(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<f64> {
        let quantile = self.quantile.evaluate(df, state)?;
        polars_ensure!(quantile.len() <= 1, ComputeError:
            "polars only supports computing a single quantile; \
            make sure the 'quantile' expression input produces a single quantile"
        );
        quantile.get(0).unwrap().try_extract()
    }
}

impl PhysicalExpr for AggQuantileExpr {
    fn as_expression(&self) -> Option<&Expr> {
        None
    }

    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let input = self.input.evaluate(df, state)?;
        let quantile = self.get_quantile(df, state)?;
        input
            .quantile_reduce(quantile, self.method)
            .map(|sc| sc.into_column(input.name().clone()))
    }
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
        // Don't let aggregations change names, as is done in polars-core.
        let keep_name = ac.get_values().name().clone();

        let quantile = self.get_quantile(df, state)?;

        // SAFETY:
        // groups are in bounds
        let mut agg = unsafe {
            ac.flat_naive()
                .into_owned()
                .agg_quantile(ac.groups(), quantile, self.method)
        };
        agg.rename(keep_name);
        Ok(AggregationContext::from_agg_state(
            AggregatedScalar(agg),
            Cow::Borrowed(groups),
        ))
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.input.to_field(input_schema)
    }

    fn is_scalar(&self) -> bool {
        true
    }
}

/// Simple wrapper to parallelize functions whose work can be divided over threads and
/// finally aggregated again in the main thread. This can be done for sum, min, max, etc.
fn parallel_op_columns<F>(f: F, s: Column, allow_threading: bool) -> PolarsResult<Column>
where
    F: Fn(Column) -> PolarsResult<Column> + Send + Sync,
{
    // Set low in debug builds so we mimic production-size data behavior.
    #[cfg(debug_assertions)]
    let thread_boundary = 0;

    #[cfg(not(debug_assertions))]
    let thread_boundary = 100_000;

    // Threading overhead / work-stealing splits are costly.

    if !allow_threading
        || s.len() < thread_boundary
        || POOL.current_thread_has_pending_tasks().unwrap_or(false)
    {
        return f(s);
    }
    let n_threads = POOL.current_num_threads();
    let splits = _split_offsets(s.len(), n_threads);

    let chunks = POOL.install(|| {
        splits
            .into_par_iter()
            .map(|(offset, len)| {
                let s = s.slice(offset as i64, len);
                f(s)
            })
            .collect::<PolarsResult<Vec<_>>>()
    })?;

    let mut iter = chunks.into_iter();
    let first = iter.next().unwrap();
    let dtype = first.dtype();
    let out = iter.fold(first.to_physical_repr(), |mut acc, s| {
        acc.append(&s.to_physical_repr()).unwrap();
        acc
    });

    unsafe { f(out.from_physical_unchecked(dtype).unwrap()) }
}
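Note on the helper above: `parallel_op_columns` follows a split/apply/re-apply pattern. The column is split into offset ranges, the reduction `f` runs over each slice on the rayon POOL, the partial results are appended in their physical representation, and `f` runs once more over the combined partials to produce the final scalar. Below is a minimal, illustrative sketch of that same idea using only the standard library; it is not part of aggregation.rs, and the names `parallel_sum` and `min_chunk` are hypothetical stand-ins rather than Polars API.

use std::thread;

// Hypothetical sketch of the split/apply/re-apply pattern behind `parallel_op_columns`,
// specialized to summing a slice of i64 with scoped std threads instead of rayon.
fn parallel_sum(values: &[i64], n_threads: usize, min_chunk: usize) -> i64 {
    // Small inputs or threading disabled: reduce on the current thread,
    // analogous to the `thread_boundary` early return above.
    if n_threads <= 1 || values.len() < min_chunk {
        return values.iter().sum();
    }
    // Split the input into roughly equal chunks, one per thread.
    let chunk = values.len().div_ceil(n_threads).max(1);
    // First pass: each thread reduces its own slice to a partial result.
    let partials: Vec<i64> = thread::scope(|scope| {
        let handles: Vec<_> = values
            .chunks(chunk)
            .map(|c| scope.spawn(move || c.iter().sum::<i64>()))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    });
    // Second pass: reduce the partials again on the calling thread,
    // mirroring the final `f(out)` call in `parallel_op_columns`.
    partials.iter().sum()
}

fn main() {
    let data: Vec<i64> = (1..=1_000).collect();
    assert_eq!(parallel_sum(&data, 4, 100), 500_500);
}

Re-applying the same reduction to the appended partials is only valid for operations where a reduction of partial reductions equals the full reduction, which is why the doc comment above limits the helper to aggregations like sum, min, and max.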