// Path: crates/polars-stream/src/physical_plan/lower_group_by.rs
use std::sync::Arc;12use parking_lot::Mutex;3use polars_core::frame::DataFrame;4use polars_core::prelude::{InitHashMaps, PlIndexMap, SortMultipleOptions};5use polars_core::schema::Schema;6use polars_error::{PolarsResult, polars_err};7use polars_expr::state::ExecutionState;8use polars_mem_engine::create_physical_plan;9use polars_plan::plans::expr_ir::{ExprIR, OutputName};10use polars_plan::plans::{AExpr, IR, IRAggExpr, IRFunctionExpr, NaiveExprMerger, write_group_by};11use polars_plan::prelude::{GroupbyOptions, *};12use polars_utils::arena::{Arena, Node};13use polars_utils::pl_str::PlSmallStr;14use polars_utils::unique_column_name;15use recursive::recursive;16use slotmap::SlotMap;1718use super::{ExprCache, PhysNode, PhysNodeKey, PhysNodeKind, PhysStream, StreamingLowerIRContext};19use crate::physical_plan::lower_expr::{20build_select_stream, compute_output_schema, is_elementwise_rec_cached,21is_fake_elementwise_function, is_input_independent,22};23use crate::physical_plan::lower_ir::{build_row_idx_stream, build_slice_stream};24use crate::utils::late_materialized_df::LateMaterializedDataFrame;2526#[allow(clippy::too_many_arguments)]27fn build_group_by_fallback(28input: PhysStream,29keys: &[ExprIR],30aggs: &[ExprIR],31output_schema: Arc<Schema>,32maintain_order: bool,33options: Arc<GroupbyOptions>,34apply: Option<PlanCallback<DataFrame, DataFrame>>,35expr_arena: &mut Arena<AExpr>,36phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,37format_str: Option<String>,38) -> PolarsResult<PhysStream> {39let input_schema = phys_sm[input.node].output_schema.clone();40let lmdf = Arc::new(LateMaterializedDataFrame::default());41let mut lp_arena = Arena::default();42let input_lp_node = lp_arena.add(lmdf.clone().as_ir_node(input_schema));43let group_by_lp_node = lp_arena.add(IR::GroupBy {44input: input_lp_node,45keys: keys.to_vec(),46aggs: aggs.to_vec(),47schema: output_schema.clone(),48maintain_order,49options,50apply,51});52let executor = 
Mutex::new(create_physical_plan(53group_by_lp_node,54&mut lp_arena,55expr_arena,56None,57)?);5859let group_by_node = PhysNode {60output_schema,61kind: PhysNodeKind::InMemoryMap {62input,63map: Arc::new(move |df| {64lmdf.set_materialized_dataframe(df);65let mut state = ExecutionState::new();66executor.lock().execute(&mut state)67}),68format_str,69},70};7172Ok(PhysStream::first(phys_sm.insert(group_by_node)))73}7475/// Tries to lower an expression as a 'elementwise scalar agg expression'.76///77/// Such an expression is defined as the elementwise combination of scalar78/// aggregations of elementwise combinations of the input columns / scalar literals.79#[recursive]80#[allow(clippy::too_many_arguments)]81fn try_lower_elementwise_scalar_agg_expr(82expr: Node,83outer_name: Option<PlSmallStr>,84expr_merger: &NaiveExprMerger,85expr_cache: &mut ExprCache,86expr_arena: &mut Arena<AExpr>,87agg_exprs: &mut Vec<ExprIR>,88uniq_input_exprs: &mut PlIndexMap<u32, PlSmallStr>,89uniq_agg_exprs: &mut PlIndexMap<u32, PlSmallStr>,90) -> Option<Node> {91// Helper macro to simplify recursive calls.92macro_rules! lower_rec {93($input:expr) => {94try_lower_elementwise_scalar_agg_expr(95$input,96None,97expr_merger,98expr_cache,99expr_arena,100agg_exprs,101uniq_input_exprs,102uniq_agg_exprs,103)104};105}106107match expr_arena.get(expr) {108AExpr::Column(_) => {109// Implicit implode not yet supported.110None111},112113AExpr::Literal(lit) => {114if lit.is_scalar() {115Some(expr)116} else {117None118}119},120121AExpr::Slice { .. }122| AExpr::Window { .. }123| AExpr::Sort { .. }124| AExpr::SortBy { .. }125| AExpr::Gather { .. } => None,126127// Explode and filter are row-separable and should thus in theory work128// in a streaming fashion but they change the length of the input which129// means the same filter/explode should also be applied to the key130// column, which is not (yet) supported.131AExpr::Explode { .. } | AExpr::Filter { .. 
} => None,132133AExpr::BinaryExpr { left, op, right } => {134let (left, op, right) = (*left, *op, *right);135let left = lower_rec!(left)?;136let right = lower_rec!(right)?;137Some(expr_arena.add(AExpr::BinaryExpr { left, op, right }))138},139140AExpr::Eval {141expr,142evaluation,143variant,144} => {145let (expr, evaluation, variant) = (*expr, *evaluation, *variant);146let expr = lower_rec!(expr)?;147Some(expr_arena.add(AExpr::Eval {148expr,149evaluation,150variant,151}))152},153154AExpr::Ternary {155predicate,156truthy,157falsy,158} => {159let (predicate, truthy, falsy) = (*predicate, *truthy, *falsy);160let predicate = lower_rec!(predicate)?;161let truthy = lower_rec!(truthy)?;162let falsy = lower_rec!(falsy)?;163Some(expr_arena.add(AExpr::Ternary {164predicate,165truthy,166falsy,167}))168},169170#[cfg(feature = "bitwise")]171AExpr::Function {172input: inner_exprs,173function:174IRFunctionExpr::Bitwise(175inner_fn @ (IRBitwiseFunction::And176| IRBitwiseFunction::Or177| IRBitwiseFunction::Xor),178),179options,180} => {181assert!(inner_exprs.len() == 1);182183let input = inner_exprs[0].clone().node();184let inner_fn = *inner_fn;185let options = *options;186187if is_input_independent(input, expr_arena, expr_cache) {188// TODO: we could simply return expr here, but we first need an is_scalar function, because if189// it is not a scalar we need to return expr.implode().190return None;191}192193if !is_elementwise_rec_cached(input, expr_arena, expr_cache) {194return None;195}196197let agg_id = expr_merger.get_uniq_id(expr).unwrap();198let name = uniq_agg_exprs199.entry(agg_id)200.or_insert_with(|| {201let input_id = expr_merger.get_uniq_id(input).unwrap();202let input_col = uniq_input_exprs203.entry(input_id)204.or_insert_with(unique_column_name)205.clone();206let input_col_node = expr_arena.add(AExpr::Column(input_col));207let trans_agg_node = expr_arena.add(AExpr::Function {208input: vec![ExprIR::from_node(input_col_node, expr_arena)],209function: 
IRFunctionExpr::Bitwise(inner_fn),210options,211});212213// Add to aggregation expressions and replace with a reference to its output.214let agg_expr = if let Some(name) = outer_name {215ExprIR::new(trans_agg_node, OutputName::Alias(name))216} else {217ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name()))218};219agg_exprs.push(agg_expr.clone());220agg_expr.output_name().clone()221})222.clone();223let result_node = expr_arena.add(AExpr::Column(name));224Some(result_node)225},226227AExpr::Function {228input: inner_exprs,229function:230IRFunctionExpr::Boolean(231inner_fn @ (IRBooleanFunction::Any { .. } | IRBooleanFunction::All { .. }),232),233options,234} => {235assert!(inner_exprs.len() == 1);236237let input = inner_exprs[0].clone().node();238let inner_fn = inner_fn.clone();239let options = *options;240241if is_input_independent(input, expr_arena, expr_cache) {242// TODO: we could simply return expr here, but we first need an is_scalar function, because if243// it is not a scalar we need to return expr.implode().244return None;245}246247if !is_elementwise_rec_cached(input, expr_arena, expr_cache) {248return None;249}250251let agg_id = expr_merger.get_uniq_id(expr).unwrap();252let name = uniq_agg_exprs253.entry(agg_id)254.or_insert_with(|| {255let input_id = expr_merger.get_uniq_id(input).unwrap();256let input_col = uniq_input_exprs257.entry(input_id)258.or_insert_with(unique_column_name)259.clone();260let input_col_node = expr_arena.add(AExpr::Column(input_col));261let trans_agg_node = expr_arena.add(AExpr::Function {262input: vec![ExprIR::from_node(input_col_node, expr_arena)],263function: IRFunctionExpr::Boolean(inner_fn),264options,265});266267// Add to aggregation expressions and replace with a reference to its output.268let agg_expr = if let Some(name) = outer_name {269ExprIR::new(trans_agg_node, OutputName::Alias(name))270} else {271ExprIR::new(trans_agg_node, 
OutputName::Alias(unique_column_name()))272};273agg_exprs.push(agg_expr.clone());274agg_expr.output_name().clone()275})276.clone();277let result_node = expr_arena.add(AExpr::Column(name));278Some(result_node)279},280281node @ AExpr::Function { input, options, .. }282| node @ AExpr::AnonymousFunction { input, options, .. }283if options.is_elementwise() && !is_fake_elementwise_function(node) =>284{285let node = node.clone();286let input = input.clone();287let new_input = input288.into_iter()289.map(|i| {290// The function may be sensitive to names (e.g. pl.struct), so we restore them.291let new_node = lower_rec!(i.node())?;292Some(ExprIR::new(293new_node,294OutputName::Alias(i.output_name().clone()),295))296})297.collect::<Option<Vec<_>>>()?;298299let mut new_node = node;300match &mut new_node {301AExpr::Function { input, .. } | AExpr::AnonymousFunction { input, .. } => {302*input = new_input;303},304_ => unreachable!(),305}306Some(expr_arena.add(new_node))307},308309AExpr::Function { .. } | AExpr::AnonymousFunction { .. } => None,310311AExpr::Cast {312expr,313dtype,314options,315} => {316let (expr, dtype, options) = (*expr, dtype.clone(), *options);317let expr = lower_rec!(expr)?;318Some(expr_arena.add(AExpr::Cast {319expr,320dtype,321options,322}))323},324325AExpr::Agg(agg) => {326match agg {327IRAggExpr::Min { input, .. }328| IRAggExpr::Max { input, .. }329| IRAggExpr::First(input)330| IRAggExpr::Last(input)331| IRAggExpr::Mean(input)332| IRAggExpr::Sum(input)333| IRAggExpr::Var(input, ..)334| IRAggExpr::Std(input, ..)335| IRAggExpr::Count { input, .. 
} => {336let agg = agg.clone();337let input = *input;338if is_input_independent(input, expr_arena, expr_cache) {339// TODO: we could simply return expr here, but we first need an is_scalar function, because if340// it is not a scalar we need to return expr.implode().341return None;342}343344if !is_elementwise_rec_cached(input, expr_arena, expr_cache) {345return None;346}347348let agg_id = expr_merger.get_uniq_id(expr).unwrap();349let name = uniq_agg_exprs350.entry(agg_id)351.or_insert_with(|| {352let mut trans_agg = agg;353let input_id = expr_merger.get_uniq_id(input).unwrap();354let input_col = uniq_input_exprs355.entry(input_id)356.or_insert_with(unique_column_name)357.clone();358let input_col_node = expr_arena.add(AExpr::Column(input_col));359trans_agg.set_input(input_col_node);360let trans_agg_node = expr_arena.add(AExpr::Agg(trans_agg));361362// Add to aggregation expressions and replace with a reference to its output.363let agg_expr = if let Some(name) = outer_name {364ExprIR::new(trans_agg_node, OutputName::Alias(name))365} else {366ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name()))367};368agg_exprs.push(agg_expr.clone());369agg_expr.output_name().clone()370})371.clone();372373let result_node = expr_arena.add(AExpr::Column(name));374Some(result_node)375},376IRAggExpr::Median(..)377| IRAggExpr::NUnique(..)378| IRAggExpr::Implode(..)379| IRAggExpr::Quantile { .. }380| IRAggExpr::AggGroups(..) 
=> None, // TODO: allow all aggregates,381}382},383AExpr::Len => {384let agg_expr = if let Some(name) = outer_name {385ExprIR::new(expr, OutputName::Alias(name))386} else {387ExprIR::new(expr, OutputName::Alias(unique_column_name()))388};389let result_node = expr_arena.add(AExpr::Column(agg_expr.output_name().clone()));390agg_exprs.push(agg_expr);391Some(result_node)392},393}394}395396#[allow(clippy::too_many_arguments)]397fn try_build_streaming_group_by(398mut input: PhysStream,399keys: &[ExprIR],400aggs: &[ExprIR],401maintain_order: bool,402options: Arc<GroupbyOptions>,403apply: Option<PlanCallback<DataFrame, DataFrame>>,404expr_arena: &mut Arena<AExpr>,405phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,406expr_cache: &mut ExprCache,407ctx: StreamingLowerIRContext,408) -> Option<PolarsResult<PhysStream>> {409if apply.is_some() {410return None; // TODO411}412413#[cfg(feature = "dynamic_group_by")]414if options.dynamic.is_some() || options.rolling.is_some() {415return None; // TODO416}417418if keys.is_empty() {419return Some(Err(420polars_err!(ComputeError: "at least one key is required in a group_by operation"),421));422}423424// Augment with row index if maintaining order.425let row_idx_name = unique_column_name();426let row_idx_node = expr_arena.add(AExpr::Column(row_idx_name.clone()));427let mut agg_storage;428let aggs = if maintain_order {429input = build_row_idx_stream(input, row_idx_name.clone(), None, phys_sm);430let first_agg_node = expr_arena.add(AExpr::Agg(IRAggExpr::First(row_idx_node)));431agg_storage = aggs.to_vec();432agg_storage.push(ExprIR::from_node(first_agg_node, expr_arena));433&agg_storage434} else {435aggs436};437438let all_independent = keys439.iter()440.chain(aggs.iter())441.all(|expr| is_input_independent(expr.node(), expr_arena, expr_cache));442if all_independent {443return None;444}445446// Fill all expressions into the merger, letting us extract common subexpressions later.447let mut expr_merger = NaiveExprMerger::default();448for key in 
keys {449expr_merger.add_expr(key.node(), expr_arena);450}451for agg in aggs {452expr_merger.add_expr(agg.node(), expr_arena);453}454455// Extract aggregates, input expressions for those aggregates and replace456// with agg node output columns.457let mut uniq_input_exprs = PlIndexMap::new();458let mut trans_agg_exprs = Vec::new();459let mut trans_keys = Vec::new();460let mut trans_output_exprs = Vec::new();461for key in keys {462let key_id = expr_merger.get_uniq_id(key.node()).unwrap();463let uniq_col = uniq_input_exprs464.entry(key_id)465.or_insert_with(unique_column_name)466.clone();467468// Keys might refer to the same column multiple times, we have to give a unique name to it.469let uniq_name = unique_column_name();470let trans_key_node = expr_arena.add(AExpr::Column(uniq_col));471trans_keys.push(ExprIR::new(472trans_key_node,473OutputName::Alias(uniq_name.clone()),474));475let output_name = OutputName::Alias(key.output_name().clone());476let trans_output_node = expr_arena.add(AExpr::Column(uniq_name));477trans_output_exprs.push(ExprIR::new(trans_output_node, output_name));478}479480let mut uniq_agg_exprs = PlIndexMap::new();481for agg in aggs {482let trans_node = try_lower_elementwise_scalar_agg_expr(483agg.node(),484Some(agg.output_name().clone()),485&expr_merger,486expr_cache,487expr_arena,488&mut trans_agg_exprs,489&mut uniq_input_exprs,490&mut uniq_agg_exprs,491)?;492let output_name = OutputName::Alias(agg.output_name().clone());493trans_output_exprs.push(ExprIR::new(trans_node, output_name));494}495496// We must lower the keys together with the input to the aggregations.497let mut input_exprs = Vec::new();498for (uniq_id, name) in uniq_input_exprs.iter() {499let node = expr_merger.get_node(*uniq_id).unwrap();500input_exprs.push(ExprIR::new(node, OutputName::Alias(name.clone())));501}502503// If all inputs are input independent add a dummy column so the group sizes are correct. 
See #23868.504if input_exprs505.iter()506.all(|e| is_input_independent(e.node(), expr_arena, expr_cache))507{508let dummy_col_name = phys_sm[input.node].output_schema.get_at_index(0).unwrap().0;509let dummy_col = expr_arena.add(AExpr::Column(dummy_col_name.clone()));510input_exprs.push(ExprIR::new(511dummy_col,512OutputName::ColumnLhs(dummy_col_name.clone()),513));514}515516let pre_select =517build_select_stream(input, &input_exprs, expr_arena, phys_sm, expr_cache, ctx).ok()?;518519let input_schema = &phys_sm[pre_select.node].output_schema;520let group_by_output_schema = compute_output_schema(521input_schema,522&[trans_keys.as_slice(), trans_agg_exprs.as_slice()].concat(),523expr_arena,524)525.unwrap();526let agg_node = phys_sm.insert(PhysNode::new(527group_by_output_schema.clone(),528PhysNodeKind::GroupBy {529input: pre_select,530key: trans_keys,531aggs: trans_agg_exprs,532},533));534535// Sort the input based on the first row index if maintaining order.536let post_select_input = if maintain_order {537let sort_node = phys_sm.insert(PhysNode::new(538group_by_output_schema,539PhysNodeKind::Sort {540input: PhysStream::first(agg_node),541by_column: vec![ExprIR::from_node(row_idx_node, expr_arena)],542slice: None,543sort_options: SortMultipleOptions::new(),544},545));546trans_output_exprs.pop(); // Remove row idx from post-select.547PhysStream::first(sort_node)548} else {549PhysStream::first(agg_node)550};551552let post_select = build_select_stream(553post_select_input,554&trans_output_exprs,555expr_arena,556phys_sm,557expr_cache,558ctx,559);560561let out = if let Some((offset, len)) = options.slice {562post_select.map(|s| build_slice_stream(s, offset, len, phys_sm))563} else {564post_select565};566Some(out)567}568569#[allow(clippy::too_many_arguments)]570pub fn build_group_by_stream(571input: PhysStream,572keys: &[ExprIR],573aggs: &[ExprIR],574output_schema: Arc<Schema>,575maintain_order: bool,576options: Arc<GroupbyOptions>,577apply: Option<PlanCallback<DataFrame, 
DataFrame>>,578expr_arena: &mut Arena<AExpr>,579phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,580expr_cache: &mut ExprCache,581ctx: StreamingLowerIRContext,582) -> PolarsResult<PhysStream> {583let streaming = try_build_streaming_group_by(584input,585keys,586aggs,587maintain_order,588options.clone(),589apply.clone(),590expr_arena,591phys_sm,592expr_cache,593ctx,594);595if let Some(stream) = streaming {596stream597} else {598let format_str = ctx.prepare_visualization.then(|| {599let mut buffer = String::new();600write_group_by(601&mut buffer,6020,603expr_arena,604keys,605aggs,606apply.as_ref(),607maintain_order,608)609.unwrap();610buffer611});612build_group_by_fallback(613input,614keys,615aggs,616output_schema,617maintain_order,618options,619apply,620expr_arena,621phys_sm,622format_str,623)624}625}626627628