Path: blob/main/crates/polars-expr/src/reduce/convert.rs
8415 views
// use polars_core::error::feature_gated;1use polars_plan::prelude::*;2use polars_utils::arena::{Arena, Node};34use super::*;5use crate::reduce::any_all::{new_all_reduction, new_any_reduction};6#[cfg(feature = "approx_unique")]7use crate::reduce::approx_n_unique::new_approx_n_unique_reduction;8#[cfg(feature = "bitwise")]9use crate::reduce::bitwise::{10new_bitwise_and_reduction, new_bitwise_or_reduction, new_bitwise_xor_reduction,11};12use crate::reduce::count::{CountReduce, NullCountReduce};13use crate::reduce::first_last::{new_first_reduction, new_item_reduction, new_last_reduction};14use crate::reduce::first_last_nonnull::{new_first_nonnull_reduction, new_last_nonnull_reduction};15use crate::reduce::len::LenReduce;16use crate::reduce::mean::new_mean_reduction;17use crate::reduce::min_max::{new_max_reduction, new_min_reduction};18use crate::reduce::min_max_by::{new_max_by_reduction, new_min_by_reduction};19use crate::reduce::sum::new_sum_reduction;20use crate::reduce::var_std::new_var_std_reduction;2122/// Converts a node into a reduction + its associated selector expression.23pub fn into_reduction(24node: Node,25expr_arena: &mut Arena<AExpr>,26schema: &Schema,27is_aggregation_context: bool,28) -> PolarsResult<(Box<dyn GroupedReduction>, Vec<Node>)> {29let get_dt = |node| {30expr_arena31.get(node)32.to_dtype(&ToFieldContext::new(expr_arena, schema))?33.materialize_unknown(false)34};35let (gr, in_node) = match expr_arena.get(node) {36AExpr::Agg(agg) => match agg {37IRAggExpr::Sum(input) => (new_sum_reduction(get_dt(*input)?)?, *input),38IRAggExpr::Mean(input) => (new_mean_reduction(get_dt(*input)?)?, *input),39IRAggExpr::Min {40propagate_nans,41input,42} => (new_min_reduction(get_dt(*input)?, *propagate_nans)?, *input),43IRAggExpr::Max {44propagate_nans,45input,46} => (new_max_reduction(get_dt(*input)?, *propagate_nans)?, *input),47IRAggExpr::Var(input, ddof) => (48new_var_std_reduction(get_dt(*input)?, false, *ddof)?,49*input,50),51IRAggExpr::Std(input, ddof) => {52(new_var_std_reduction(get_dt(*input)?, true, *ddof)?, *input)53},54IRAggExpr::First(input) => (new_first_reduction(get_dt(*input)?), *input),55IRAggExpr::FirstNonNull(input) => {56(new_first_nonnull_reduction(get_dt(*input)?), *input)57},58IRAggExpr::Last(input) => (new_last_reduction(get_dt(*input)?), *input),59IRAggExpr::LastNonNull(input) => (new_last_nonnull_reduction(get_dt(*input)?), *input),60IRAggExpr::Item { input, allow_empty } => {61(new_item_reduction(get_dt(*input)?, *allow_empty), *input)62},63IRAggExpr::Count {64input,65include_nulls,66} => {67let count = Box::new(CountReduce::new(*include_nulls)) as Box<_>;68(count, *input)69},70IRAggExpr::Quantile { .. } => todo!(),71IRAggExpr::Median(_) => todo!(),72IRAggExpr::NUnique(_) => todo!(),73IRAggExpr::Implode(_) => todo!(),74IRAggExpr::AggGroups(_) => todo!(),75},76AExpr::Len => {77if let Some(first_column) = schema.iter_names().next() {78let out: Box<dyn GroupedReduction> = Box::new(LenReduce::default());79let expr = expr_arena.add(AExpr::Column(first_column.as_str().into()));8081(out, expr)82} else {83// Support len aggregation on 0-width morsels.84// Notes:85// * We do this instead of projecting a scalar, because scalar literals don't86// project to the height of the DataFrame (in the PhysicalExpr impl).87// * This approach is not sound for `update_groups()`, but currently that case is88// not hit (it would need group-by -> len on empty morsels).89polars_ensure!(90!is_aggregation_context,91ComputeError:92"not implemented: len() of groups with no columns"93);9495let out: Box<dyn GroupedReduction> = new_sum_reduction(DataType::IDX_DTYPE)?;96let expr = expr_arena.add(AExpr::Len);9798(out, expr)99}100},101102AExpr::Function {103input: inner_exprs,104function: IRFunctionExpr::NullCount,105options: _,106} => {107assert!(inner_exprs.len() == 1);108let input = inner_exprs[0].node();109let count = Box::new(NullCountReduce::new()) as Box<_>;110(count, input)111},112113#[cfg(feature = "approx_unique")]114AExpr::Function {115input: inner_exprs,116function: IRFunctionExpr::ApproxNUnique,117options: _,118} => {119assert!(inner_exprs.len() == 1);120let input = inner_exprs[0].node();121let out = new_approx_n_unique_reduction(get_dt(input)?)?;122(out, input)123},124125#[cfg(feature = "bitwise")]126AExpr::Function {127input: inner_exprs,128function: IRFunctionExpr::Bitwise(inner_fn),129options: _,130} => {131assert!(inner_exprs.len() == 1);132let input = inner_exprs[0].node();133match inner_fn {134IRBitwiseFunction::And => (new_bitwise_and_reduction(get_dt(input)?), input),135IRBitwiseFunction::Or => (new_bitwise_or_reduction(get_dt(input)?), input),136IRBitwiseFunction::Xor => (new_bitwise_xor_reduction(get_dt(input)?), input),137_ => unreachable!(),138}139},140141AExpr::Function {142input: inner_exprs,143function: IRFunctionExpr::Boolean(inner_fn),144options: _,145} => {146assert!(inner_exprs.len() == 1);147let input = inner_exprs[0].node();148match inner_fn {149IRBooleanFunction::Any { ignore_nulls } => {150(new_any_reduction(*ignore_nulls), input)151},152IRBooleanFunction::All { ignore_nulls } => {153(new_all_reduction(*ignore_nulls), input)154},155_ => unreachable!(),156}157},158159AExpr::Function {160input: inner_exprs,161function: IRFunctionExpr::MinBy,162options: _,163} => {164assert!(inner_exprs.len() == 2);165let input = inner_exprs[0].node();166let by = inner_exprs[1].node();167let gr = new_min_by_reduction(get_dt(input)?, get_dt(by)?)?;168return Ok((gr, vec![input, by]));169},170171AExpr::Function {172input: inner_exprs,173function: IRFunctionExpr::MaxBy,174options: _,175} => {176assert!(inner_exprs.len() == 2);177let input = inner_exprs[0].node();178let by = inner_exprs[1].node();179let gr = new_max_by_reduction(get_dt(input)?, get_dt(by)?)?;180return Ok((gr, vec![input, by]));181},182183AExpr::AnonymousAgg {184input: inner_exprs,185fmt_str: _,186function,187} => {188let ann_agg = function.materialize()?;189assert!(inner_exprs.len() == 1);190let input = inner_exprs[0].node();191let reduction = ann_agg.as_any();192let reduction = reduction193.downcast_ref::<Box<dyn GroupedReduction>>()194.unwrap();195(reduction.new_empty(), input)196},197_ => unreachable!(),198};199Ok((gr, vec![in_node]))200}201202203