Path: blob/main/crates/polars-expr/src/reduce/convert.rs
6940 views
// use polars_core::error::feature_gated;1use polars_plan::prelude::*;2use polars_utils::arena::{Arena, Node};34use super::*;5use crate::reduce::any_all::{new_all_reduction, new_any_reduction};6#[cfg(feature = "bitwise")]7use crate::reduce::bitwise::{8new_bitwise_and_reduction, new_bitwise_or_reduction, new_bitwise_xor_reduction,9};10use crate::reduce::count::CountReduce;11use crate::reduce::first_last::{new_first_reduction, new_last_reduction};12use crate::reduce::len::LenReduce;13use crate::reduce::mean::new_mean_reduction;14use crate::reduce::min_max::{new_max_reduction, new_min_reduction};15use crate::reduce::sum::new_sum_reduction;16use crate::reduce::var_std::new_var_std_reduction;1718/// Converts a node into a reduction + its associated selector expression.19pub fn into_reduction(20node: Node,21expr_arena: &mut Arena<AExpr>,22schema: &Schema,23) -> PolarsResult<(Box<dyn GroupedReduction>, Node)> {24let get_dt = |node| {25expr_arena26.get(node)27.to_dtype(schema, expr_arena)?28.materialize_unknown(false)29};30let out = match expr_arena.get(node) {31AExpr::Agg(agg) => match agg {32IRAggExpr::Sum(input) => (new_sum_reduction(get_dt(*input)?), *input),33IRAggExpr::Mean(input) => (new_mean_reduction(get_dt(*input)?), *input),34IRAggExpr::Min {35propagate_nans,36input,37} => (new_min_reduction(get_dt(*input)?, *propagate_nans), *input),38IRAggExpr::Max {39propagate_nans,40input,41} => (new_max_reduction(get_dt(*input)?, *propagate_nans), *input),42IRAggExpr::Var(input, ddof) => {43(new_var_std_reduction(get_dt(*input)?, false, *ddof), *input)44},45IRAggExpr::Std(input, ddof) => {46(new_var_std_reduction(get_dt(*input)?, true, *ddof), *input)47},48IRAggExpr::First(input) => (new_first_reduction(get_dt(*input)?), *input),49IRAggExpr::Last(input) => (new_last_reduction(get_dt(*input)?), *input),50IRAggExpr::Count {51input,52include_nulls,53} => {54let count = Box::new(CountReduce::new(*include_nulls)) as Box<_>;55(count, *input)56},57IRAggExpr::Quantile { .. } => todo!(),58IRAggExpr::Median(_) => todo!(),59IRAggExpr::NUnique(_) => todo!(),60IRAggExpr::Implode(_) => todo!(),61IRAggExpr::AggGroups(_) => todo!(),62},63AExpr::Len => {64if let Some(first_column) = schema.iter_names().next() {65let out: Box<dyn GroupedReduction> = Box::new(LenReduce::default());66let expr = expr_arena.add(AExpr::Column(first_column.as_str().into()));6768(out, expr)69} else {70// Support len aggregation on 0-width morsels.71// Notes:72// * We do this instead of projecting a scalar, because scalar literals don't73// project to the height of the DataFrame (in the PhysicalExpr impl).74// * This approach is not sound for `update_groups()`, but currently that case is75// not hit (it would need group-by -> len on empty morsels).76let out: Box<dyn GroupedReduction> = new_sum_reduction(DataType::IDX_DTYPE);77let expr = expr_arena.add(AExpr::Len);7879(out, expr)80}81},82#[cfg(feature = "bitwise")]83AExpr::Function {84input: inner_exprs,85function: IRFunctionExpr::Bitwise(inner_fn),86options: _,87} => {88assert!(inner_exprs.len() == 1);89let input = inner_exprs[0].node();90match inner_fn {91IRBitwiseFunction::And => (new_bitwise_and_reduction(get_dt(input)?), input),92IRBitwiseFunction::Or => (new_bitwise_or_reduction(get_dt(input)?), input),93IRBitwiseFunction::Xor => (new_bitwise_xor_reduction(get_dt(input)?), input),94_ => unreachable!(),95}96},9798AExpr::Function {99input: inner_exprs,100function: IRFunctionExpr::Boolean(inner_fn),101options: _,102} => {103assert!(inner_exprs.len() == 1);104let input = inner_exprs[0].node();105match inner_fn {106IRBooleanFunction::Any { ignore_nulls } => {107(new_any_reduction(*ignore_nulls), input)108},109IRBooleanFunction::All { ignore_nulls } => {110(new_all_reduction(*ignore_nulls), input)111},112_ => unreachable!(),113}114},115_ => unreachable!(),116};117Ok(out)118}119120121