// Path: crates/polars-stream/src/physical_plan/lower_group_by.rs
//! Lowering of `IR::GroupBy` into streaming physical-plan nodes, with an
//! in-memory-engine fallback for group-bys the streaming engine cannot
//! (yet) express.

use std::sync::Arc;

use parking_lot::Mutex;
use polars_core::frame::DataFrame;
use polars_core::prelude::{Field, InitHashMaps, PlIndexMap, PlIndexSet, SortMultipleOptions};
use polars_core::schema::Schema;
use polars_error::{PolarsResult, polars_err};
use polars_expr::state::ExecutionState;
use polars_mem_engine::create_physical_plan;
use polars_plan::plans::expr_ir::{ExprIR, OutputName};
use polars_plan::plans::{AExpr, IR, IRAggExpr, IRFunctionExpr, NaiveExprMerger, write_group_by};
use polars_plan::prelude::{GroupbyOptions, *};
use polars_utils::arena::{Arena, Node};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::{IdxSize, unique_column_name};
use recursive::recursive;
use slotmap::SlotMap;

use super::{ExprCache, PhysNode, PhysNodeKey, PhysNodeKind, PhysStream, StreamingLowerIRContext};
use crate::physical_plan::lower_expr::{
    build_hstack_stream, build_select_stream, compute_output_schema, is_elementwise_rec_cached,
    is_fake_elementwise_function, is_input_independent,
};
use crate::physical_plan::lower_ir::{
    build_filter_stream, build_row_idx_stream, build_slice_stream,
};
use crate::utils::late_materialized_df::LateMaterializedDataFrame;

/// Builds a non-streaming fallback for a group-by: the input stream is
/// materialized into a `DataFrame` and the group-by is executed by the
/// in-memory engine via an `InMemoryMap` physical node.
///
/// A fresh one-node IR plan (`LateMaterializedDataFrame` -> `IR::GroupBy`) is
/// compiled with `create_physical_plan`; the resulting executor is stored in a
/// `Mutex` so the `InMemoryMap` closure (a `Fn`) can execute it.
#[allow(clippy::too_many_arguments)]
fn build_group_by_fallback(
    input: PhysStream,
    keys: &[ExprIR],
    aggs: &[ExprIR],
    output_schema: Arc<Schema>,
    maintain_order: bool,
    options: Arc<GroupbyOptions>,
    apply: Option<PlanCallback<DataFrame, DataFrame>>,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    format_str: Option<String>,
) -> PolarsResult<PhysStream> {
    let input_schema = phys_sm[input.node].output_schema.clone();
    let lmdf = Arc::new(LateMaterializedDataFrame::default());
    let mut lp_arena = Arena::default();
    let input_lp_node = lp_arena.add(lmdf.clone().as_ir_node(input_schema));
    let group_by_lp_node = lp_arena.add(IR::GroupBy {
        input: input_lp_node,
        keys: keys.to_vec(),
        aggs: aggs.to_vec(),
        schema: output_schema.clone(),
        maintain_order,
        options,
        apply,
    });
    let executor = Mutex::new(create_physical_plan(
        group_by_lp_node,
        &mut lp_arena,
        expr_arena,
        Some(crate::dispatch::build_streaming_query_executor),
    )?);

    let group_by_node = PhysNode {
        output_schema,
        kind: PhysNodeKind::InMemoryMap {
            input,
            map: Arc::new(move |df| {
                // Feed the collected input into the pre-built plan, then run it.
                lmdf.set_materialized_dataframe(df);
                let mut state = ExecutionState::new();
                executor.lock().execute(&mut state)
            }),
            format_str,
        },
    };

    Ok(PhysStream::first(phys_sm.insert(group_by_node)))
}

// Given an aggregate expression returns a column expression which is to
// represent the aggregate result in the post-select.
//
// For each input to this aggregate uniq_input_names is updated to map the
// unique id of the input expressions to an input columns the aggregate
// expression expects.
//
// uniq_agg_exprs is updated with the unique id of the aggregate mapping to
// the aggregate expression and vector of unique input ids for that aggregate.
#[allow(clippy::too_many_arguments)]
fn replace_agg_uniq(
    expr: Node,
    expr_merger: &mut NaiveExprMerger,
    expr_cache: &mut ExprCache,
    expr_arena: &mut Arena<AExpr>,
    agg_exprs: &mut Vec<ExprIR>,
    uniq_input_names: &mut PlIndexMap<u32, PlSmallStr>,
    uniq_agg_exprs: &mut PlIndexMap<u32, (ExprIR, Vec<u32>)>,
    uniq_elementwise_exprs: &mut PlIndexMap<u32, ExprIR>,
) -> Node {
    let aexpr = expr_arena.get(expr).clone();
    let mut inputs = Vec::new();
    aexpr.inputs_rev(&mut inputs);
    inputs.reverse();

    // Deduplicate on the merger-assigned unique id: identical aggregates are
    // only lowered (and computed) once.
    let agg_id = expr_merger.get_uniq_id(expr).unwrap();
    let name = uniq_agg_exprs
        .entry(agg_id)
        .or_insert_with(|| {
            let mut input_ids = Vec::new();
            let input_cols = inputs
                .iter()
                .map(|input| {
                    let (input_id, node) = replace_elementwise_components(
                        *input,
                        expr_merger,
                        expr_cache,
                        expr_arena,
                        uniq_input_names,
                        uniq_elementwise_exprs,
                    );
                    if let Some(id) = input_id {
                        // Already elementwise.
                        input_ids.push(id);
                        node
                    } else {
                        let input_id = expr_merger.add_and_get_uniq_id(node, expr_arena);
                        input_ids.push(input_id);
                        let input_col = uniq_input_names
                            .entry(input_id)
                            .or_insert_with(unique_column_name)
                            .clone();
                        expr_arena.add(AExpr::Column(input_col))
                    }
                })
                .collect::<Vec<_>>();
            let trans_agg_node = expr_arena.add(aexpr.replace_inputs(&input_cols));

            // Add to aggregation expressions and replace with a reference to its output.
            let agg_expr = ExprIR::new(trans_agg_node, OutputName::Alias(unique_column_name()));
            agg_exprs.push(agg_expr.clone());
            (agg_expr, input_ids)
        })
        .0
        .output_name()
        .clone();
    expr_arena.add(AExpr::Column(name))
}

/// Replaces all elementwise subexpressions with column references, storing the elementwise
/// expressions uniquely in expr_merger/uniq_elementwise_exprs keys.
#[recursive]
fn replace_elementwise_components(
    expr: Node,
    expr_merger: &mut NaiveExprMerger,
    expr_cache: &mut ExprCache,
    expr_arena: &mut Arena<AExpr>,
    uniq_input_names: &mut PlIndexMap<u32, PlSmallStr>,
    uniq_elementwise_exprs: &mut PlIndexMap<u32, ExprIR>,
) -> (Option<u32>, Node) {
    // Input-independent scalars are also treated as "elementwise" components
    // here: they broadcast and can be computed in the pre-select.
    if is_elementwise_rec_cached(expr, expr_arena, expr_cache)
        || (is_input_independent(expr, expr_arena, expr_cache) && is_scalar_ae(expr, expr_arena))
    {
        let id = expr_merger.add_and_get_uniq_id(expr, expr_arena);
        let name = uniq_input_names
            .entry(id)
            .or_insert_with(unique_column_name)
            .clone();
        let node = uniq_elementwise_exprs
            .entry(id)
            .or_insert_with(|| ExprIR::from_column_name(name, expr_arena))
            .node();
        (Some(id), node)
    } else {
        // Not elementwise at this level: recurse into the inputs and rebuild
        // the node with the (possibly replaced) children.
        let aexpr = expr_arena.get(expr).clone();
        let mut inputs = Vec::new();
        aexpr.inputs_rev(&mut inputs);
        inputs.reverse();

        for input in &mut inputs {
            *input = replace_elementwise_components(
                *input,
                expr_merger,
                expr_cache,
                expr_arena,
                uniq_input_names,
                uniq_elementwise_exprs,
            )
            .1;
        }
        let rec_node = expr_arena.add(aexpr.replace_inputs(&inputs));
        (None, rec_node)
    }
}

/// Tries to lower an expression as a 'elementwise scalar agg expression'.
///
/// Such an expression is defined as the elementwise combination of scalar
/// aggregations.
#[recursive]
#[allow(clippy::too_many_arguments)]
fn try_lower_elementwise_scalar_agg_expr(
    expr: Node,
    expr_merger: &mut NaiveExprMerger,
    expr_cache: &mut ExprCache,
    expr_arena: &mut Arena<AExpr>,
    agg_exprs: &mut Vec<ExprIR>,
    uniq_input_names: &mut PlIndexMap<u32, PlSmallStr>,
    uniq_agg_exprs: &mut PlIndexMap<u32, (ExprIR, Vec<u32>)>,
    uniq_elementwise_exprs: &mut PlIndexMap<u32, ExprIR>,
) -> Option<Node> {
    // Helper macros to simplify (recursive) calls.
    macro_rules! lower_rec {
        ($input:expr) => {
            try_lower_elementwise_scalar_agg_expr(
                $input,
                expr_merger,
                expr_cache,
                expr_arena,
                agg_exprs,
                uniq_input_names,
                uniq_agg_exprs,
                uniq_elementwise_exprs,
            )
        };
    }

    macro_rules! replace_agg_uniq {
        ($input:expr) => {
            replace_agg_uniq(
                $input,
                expr_merger,
                expr_cache,
                expr_arena,
                agg_exprs,
                uniq_input_names,
                uniq_agg_exprs,
                uniq_elementwise_exprs,
            )
        };
    }

    if is_input_independent(expr, expr_arena, expr_cache) {
        if expr_arena.get(expr).is_scalar(expr_arena) {
            return Some(expr);
        } else {
            // Non-scalar input-independent expressions get imploded so they
            // become a single (list) value per group.
            let agg = IRAggExpr::Implode(expr);
            return Some(expr_arena.add(AExpr::Agg(agg)));
        }
    }

    match expr_arena.get(expr) {
        // Should be handled separately in `Eval`.
        AExpr::Element => unreachable!(),

        AExpr::StructField(_) => {
            // Reflecting StructEval expr state is not yet supported.
            None
        },

        AExpr::Column(_) => {
            // Implicit implode not yet supported.
            None
        },

        AExpr::Literal(lit) => {
            if lit.is_scalar() {
                Some(expr)
            } else {
                None
            }
        },

        #[cfg(feature = "dynamic_group_by")]
        AExpr::Rolling { .. } => None,

        AExpr::Slice { .. }
        | AExpr::Over { .. }
        | AExpr::Sort { .. }
        | AExpr::SortBy { .. }
        | AExpr::Gather { .. } => None,

        // Explode and filter are row-separable and should thus in theory work
        // in a streaming fashion but they change the length of the input which
        // means the same filter/explode should also be applied to the key
        // column, which is not (yet) supported.
        AExpr::Explode { .. } | AExpr::Filter { .. } => None,

        AExpr::BinaryExpr { left, op, right } => {
            let (left, op, right) = (*left, *op, *right);
            let left = lower_rec!(left)?;
            let right = lower_rec!(right)?;
            Some(expr_arena.add(AExpr::BinaryExpr { left, op, right }))
        },

        AExpr::Eval {
            expr,
            evaluation,
            variant,
        } => {
            let (expr, evaluation, variant) = (*expr, *evaluation, *variant);
            let expr = lower_rec!(expr)?;
            Some(expr_arena.add(AExpr::Eval {
                expr,
                evaluation,
                variant,
            }))
        },

        AExpr::StructEval { expr, evaluation } => {
            // @TODO: Reflect the lowering result of `expr` into the respective
            // StructField lowering calls.
            let (expr, evaluation) = (*expr, evaluation.clone());
            let expr = lower_rec!(expr)?;

            let new_evaluation = evaluation
                .into_iter()
                .map(|i| {
                    let new_node = lower_rec!(i.node())?;
                    Some(ExprIR::new(
                        new_node,
                        OutputName::Alias(i.output_name().clone()),
                    ))
                })
                .collect::<Option<Vec<_>>>()?;

            Some(expr_arena.add(AExpr::StructEval {
                expr,
                evaluation: new_evaluation,
            }))
        },

        AExpr::Ternary {
            predicate,
            truthy,
            falsy,
        } => {
            let (predicate, truthy, falsy) = (*predicate, *truthy, *falsy);
            let predicate = lower_rec!(predicate)?;
            let truthy = lower_rec!(truthy)?;
            let falsy = lower_rec!(falsy)?;
            Some(expr_arena.add(AExpr::Ternary {
                predicate,
                truthy,
                falsy,
            }))
        },

        #[cfg(feature = "bitwise")]
        AExpr::Function {
            function:
                IRFunctionExpr::Bitwise(
                    IRBitwiseFunction::And | IRBitwiseFunction::Or | IRBitwiseFunction::Xor,
                ),
            ..
        } => Some(replace_agg_uniq!(expr)),

        #[cfg(feature = "approx_unique")]
        AExpr::Function {
            function: IRFunctionExpr::ApproxNUnique,
            ..
        } => Some(replace_agg_uniq!(expr)),

        AExpr::Function {
            function:
                IRFunctionExpr::Boolean(IRBooleanFunction::Any { .. } | IRBooleanFunction::All { .. })
                | IRFunctionExpr::MinBy
                | IRFunctionExpr::MaxBy
                | IRFunctionExpr::NullCount,
            ..
        } => Some(replace_agg_uniq!(expr)),

        AExpr::AnonymousAgg { .. } => Some(replace_agg_uniq!(expr)),

        node @ AExpr::Function { input, options, .. }
        | node @ AExpr::AnonymousFunction { input, options, .. }
            if options.is_elementwise() && !is_fake_elementwise_function(node) =>
        {
            let node = node.clone();
            let input = input.clone();
            let new_input = input
                .into_iter()
                .map(|i| {
                    // The function may be sensitive to names (e.g. pl.struct), so we restore them.
                    let new_node = lower_rec!(i.node())?;
                    Some(ExprIR::new(
                        new_node,
                        OutputName::Alias(i.output_name().clone()),
                    ))
                })
                .collect::<Option<Vec<_>>>()?;

            let mut new_node = node;
            match &mut new_node {
                AExpr::Function { input, .. } | AExpr::AnonymousFunction { input, .. } => {
                    *input = new_input;
                },
                _ => unreachable!(),
            }
            Some(expr_arena.add(new_node))
        },

        AExpr::Function { .. } | AExpr::AnonymousFunction { .. } => None,

        AExpr::Cast {
            expr,
            dtype,
            options,
        } => {
            let (expr, dtype, options) = (*expr, dtype.clone(), *options);
            let expr = lower_rec!(expr)?;
            Some(expr_arena.add(AExpr::Cast {
                expr,
                dtype,
                options,
            }))
        },

        AExpr::Agg(agg) => {
            match agg {
                IRAggExpr::Min { .. }
                | IRAggExpr::Max { .. }
                | IRAggExpr::First(_)
                | IRAggExpr::FirstNonNull(_)
                | IRAggExpr::Last(_)
                | IRAggExpr::LastNonNull(_)
                | IRAggExpr::Item { .. }
                | IRAggExpr::Mean(_)
                | IRAggExpr::Sum(_)
                | IRAggExpr::Var(..)
                | IRAggExpr::Std(..)
                | IRAggExpr::Count { .. } => Some(replace_agg_uniq!(expr)),
                IRAggExpr::NUnique(uniq_input) => {
                    // Rewrite n_unique(x) as count(unique(x), include_nulls),
                    // which the streaming engine can evaluate.
                    let function = IRFunctionExpr::Unique(false);
                    let uniq_input_expr = ExprIR::from_node(*uniq_input, expr_arena);
                    let uniq_node = expr_arena.add(AExpr::Function {
                        input: vec![uniq_input_expr],
                        options: function.function_options(),
                        function,
                    });

                    let count = IRAggExpr::Count {
                        input: uniq_node,
                        include_nulls: true,
                    };
                    let count_node = expr_arena.add(AExpr::Agg(count));
                    // Register the synthesized node so replace_agg_uniq can
                    // look up its unique id.
                    expr_merger.add_expr(count_node, expr_arena);
                    Some(replace_agg_uniq!(count_node))
                },
                IRAggExpr::Median(..)
                | IRAggExpr::Implode(..)
                | IRAggExpr::Quantile { .. }
                | IRAggExpr::AggGroups(..) => None, // TODO: allow all aggregates,
            }
        },
        AExpr::Len => {
            // len() takes no inputs; register it directly as an aggregate.
            let agg_id = expr_merger.get_uniq_id(expr).unwrap();
            let name = uniq_agg_exprs
                .entry(agg_id)
                .or_insert_with(|| {
                    let agg_expr = ExprIR::new(expr, OutputName::Alias(unique_column_name()));
                    agg_exprs.push(agg_expr.clone());
                    (agg_expr, Vec::new())
                })
                .0
                .output_name()
                .clone();
            Some(expr_arena.add(AExpr::Column(name)))
        },
    }
}

/// Tries to lower a non-elementwise aggregate *input* expression (currently
/// `unique()` and `filter()` chains over elementwise expressions) into its own
/// input stream for the group-by.
///
/// Returns `Ok(Some((stream, node, all_keys_included)))` where `node` is the
/// column expression replacing `expr` on `stream`, and `all_keys_included`
/// indicates whether every key group is guaranteed present in that stream
/// (false for `filter`, which can drop whole groups).
#[allow(clippy::too_many_arguments)]
fn try_lower_agg_input_expr(
    input_stream: PhysStream,
    keys: &[ExprIR],
    expr: Node,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
    ctx: StreamingLowerIRContext,
) -> PolarsResult<Option<(PhysStream, Node, /* all_keys_included */ bool)>> {
    if is_elementwise_rec_cached(expr, expr_arena, expr_cache) {
        return Ok(Some((input_stream, expr, true)));
    }

    match expr_arena.get(expr) {
        AExpr::Function {
            input: uniq_input,
            function: IRFunctionExpr::Unique(stable),
            options: _,
        } => {
            assert!(uniq_input.len() == 1);
            let input_node = uniq_input[0].node();
            let maintain_order = *stable;

            // Lower the input of unique() first (it may itself be a
            // filter/unique chain).
            let Some((stream, node, all_keys_included)) = try_lower_agg_input_expr(
                input_stream,
                keys,
                input_node,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?
            else {
                return Ok(None);
            };

            // unique() per group is implemented as a group-by on
            // (keys..., value) with no aggregations.
            let output_name = unique_column_name();
            let mut gb_keys = keys.to_vec();
            gb_keys.push(ExprIR::new(node, OutputName::Alias(output_name.clone())));

            let aggs = &[];
            let options = Arc::new(GroupbyOptions::default());
            let Some(stream) = try_build_streaming_group_by(
                stream,
                &gb_keys,
                aggs,
                maintain_order,
                options,
                None,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?
            else {
                return Ok(None);
            };

            let trans_output = expr_arena.add(AExpr::Column(output_name));
            Ok(Some((stream, trans_output, all_keys_included)))
        },

        AExpr::Filter {
            input: filter_input,
            by: predicate,
        } => {
            if !is_elementwise_rec_cached(*filter_input, expr_arena, expr_cache)
                || !is_elementwise_rec_cached(*predicate, expr_arena, expr_cache)
            {
                return Ok(None);
            }

            // Select keys + value + predicate, then filter rows; groups whose
            // rows are all filtered out disappear, hence all_keys_included =
            // false below.
            let output_name = unique_column_name();
            let predicate_name = unique_column_name();
            let mut select_exprs = keys.to_vec();
            select_exprs.push(ExprIR::new(
                *filter_input,
                OutputName::Alias(output_name.clone()),
            ));
            select_exprs.push(ExprIR::new(
                *predicate,
                OutputName::Alias(predicate_name.clone()),
            ));

            let mut stream = build_select_stream(
                input_stream,
                &select_exprs,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?;
            stream = build_filter_stream(
                stream,
                ExprIR::from_column_name(predicate_name, expr_arena),
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?;

            let trans_output = expr_arena.add(AExpr::Column(output_name));
            Ok(Some((stream, trans_output, false)))
        },
        _ => Ok(None),
    }
}

/// Tries to lower the group-by into a fully streaming `PhysNodeKind::GroupBy`
/// plan: pre-select (keys + elementwise agg inputs) -> per-input streams for
/// non-elementwise agg inputs -> multi-input group-by node -> post-select that
/// restores output names (plus an order-restoring sort and optional slice).
///
/// Returns `Ok(None)` when the group-by cannot be streamed (e.g. `apply`,
/// rolling/dynamic options, unsupported aggregates) so the caller can fall
/// back to the in-memory engine.
#[allow(clippy::too_many_arguments)]
fn try_build_streaming_group_by(
    mut input: PhysStream,
    keys: &[ExprIR],
    aggs: &[ExprIR],
    maintain_order: bool,
    options: Arc<GroupbyOptions>,
    apply: Option<PlanCallback<DataFrame, DataFrame>>,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
    ctx: StreamingLowerIRContext,
) -> PolarsResult<Option<PhysStream>> {
    if apply.is_some() {
        return Ok(None); // TODO
    }

    #[cfg(feature = "dynamic_group_by")]
    if options.dynamic.is_some() || options.rolling.is_some() {
        return Ok(None); // TODO
    }

    if keys.is_empty() {
        return Err(
            polars_err!(ComputeError: "at least one key is required in a group_by operation"),
        );
    }

    // Not supported yet.
    let all_independent = keys
        .iter()
        .chain(aggs.iter())
        .all(|expr| is_input_independent(expr.node(), expr_arena, expr_cache));
    if all_independent {
        return Ok(None);
    }

    // Augment with row index if maintaining order.
    let row_idx_name = unique_column_name();
    let row_idx_node = expr_arena.add(AExpr::Column(row_idx_name.clone()));
    let mut agg_storage;
    let aggs = if maintain_order {
        // first(row_idx) per group gives us the order to restore afterwards.
        input = build_row_idx_stream(input, row_idx_name.clone(), None, phys_sm);
        let first_agg_node = expr_arena.add(AExpr::Agg(IRAggExpr::First(row_idx_node)));
        agg_storage = aggs.to_vec();
        agg_storage.push(ExprIR::from_node(first_agg_node, expr_arena));
        &agg_storage
    } else {
        aggs
    };

    // Fill all expressions into the merger, letting us extract common subexpressions later.
    let mut expr_merger = NaiveExprMerger::default();
    for key in keys {
        expr_merger.add_expr(key.node(), expr_arena);
    }
    for agg in aggs {
        expr_merger.add_expr(agg.node(), expr_arena);
    }

    // Extract aggregates, input expressions for those aggregates and replace
    // with agg node output columns.
    let mut uniq_input_names = PlIndexMap::new();
    let mut key_ids = PlIndexSet::new();
    let mut trans_agg_exprs = Vec::new();
    let mut trans_keys = Vec::new();
    let mut trans_output_exprs = Vec::new();
    for key in keys {
        let key_id = expr_merger.get_uniq_id(key.node()).unwrap();
        key_ids.insert(key_id);
        let key_name = uniq_input_names
            .entry(key_id)
            .or_insert_with(|| {
                let key_name = unique_column_name();
                trans_keys.push(ExprIR::from_column_name(key_name.clone(), expr_arena));
                key_name
            })
            .clone();

        let output_name = OutputName::Alias(key.output_name().clone());
        let trans_output_node = expr_arena.add(AExpr::Column(key_name));
        trans_output_exprs.push(ExprIR::new(trans_output_node, output_name));
    }

    // Maps aggregation expression ids to output column expression and a vec
    // of input expression ids.
    let mut uniq_agg_exprs = PlIndexMap::new();

    // Maps elementwise input expression ids to column expression.
    let mut uniq_elementwise_exprs = PlIndexMap::new();

    for agg in aggs {
        let Some(trans_node) = try_lower_elementwise_scalar_agg_expr(
            agg.node(),
            &mut expr_merger,
            expr_cache,
            expr_arena,
            &mut trans_agg_exprs,
            &mut uniq_input_names,
            &mut uniq_agg_exprs,
            &mut uniq_elementwise_exprs,
        ) else {
            return Ok(None);
        };
        let output_name = OutputName::Alias(agg.output_name().clone());
        trans_output_exprs.push(ExprIR::new(trans_node, output_name));
    }

    // We must lower the keys together with the elementwise inputs to the aggregations.
    let mut pre_select_input_ids = key_ids.clone();
    pre_select_input_ids.extend(uniq_elementwise_exprs.keys());

    let mut pre_select_exprs = Vec::new();
    for uniq_id in pre_select_input_ids {
        let name = &uniq_input_names[&uniq_id];
        let node = expr_merger.get_node(uniq_id).unwrap();
        pre_select_exprs.push(ExprIR::new(node, OutputName::Alias(name.clone())));
    }

    // If all inputs are input independent add a dummy column so the group sizes are correct. See #23868.
    let mut direct_input_needed = false;
    if pre_select_exprs
        .iter()
        .all(|e| is_input_independent(e.node(), expr_arena, expr_cache))
    {
        direct_input_needed = true;
        let dummy_col_name = phys_sm[input.node].output_schema.get_at_index(0).unwrap().0;
        let dummy_col = expr_arena.add(AExpr::Column(dummy_col_name.clone()));
        pre_select_exprs.push(ExprIR::new(
            dummy_col,
            OutputName::ColumnLhs(dummy_col_name.clone()),
        ));
    }

    // Create pre-select.
    let pre_select = build_select_stream(
        input,
        &pre_select_exprs,
        expr_arena,
        phys_sm,
        expr_cache,
        ctx,
    )?;

    // Create input streams.
    let mut all_keys_included_in_other_inputs = false;
    let mut aggs_with_elementwise_inputs = Vec::new();
    let mut other_agg_input_streams = PlIndexMap::new();
    for (_uniq_agg_id, (agg_expr, input_ids)) in uniq_agg_exprs.iter() {
        if input_ids
            .iter()
            .all(|i| uniq_elementwise_exprs.contains_key(i))
        {
            // All inputs are elementwise: this agg can read directly from the
            // pre-select stream.
            aggs_with_elementwise_inputs.push(agg_expr.clone());
            direct_input_needed = true;
            continue;
        }

        // More than one non-elementwise input to this agg, unsure how to handle this.
        if input_ids.len() != 1 {
            return Ok(None);
        }

        let input_id = input_ids[0];
        let input_node = expr_merger.get_node(input_id).unwrap();
        let input_name = uniq_input_names[&input_id].clone();
        if !other_agg_input_streams.contains_key(&input_id) {
            let Some((stream, trans_node, keys_included)) = try_lower_agg_input_expr(
                pre_select,
                &trans_keys,
                input_node,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?
            else {
                return Ok(None);
            };
            all_keys_included_in_other_inputs |= keys_included;
            let mut trans_stream_outputs = trans_keys.clone();
            trans_stream_outputs.push(ExprIR::new(trans_node, OutputName::Alias(input_name)));
            let stream = build_select_stream(
                stream,
                &trans_stream_outputs,
                expr_arena,
                phys_sm,
                expr_cache,
                ctx,
            )?;
            other_agg_input_streams.insert(input_id, (stream, Vec::new()));
        }

        other_agg_input_streams[&input_id].1.push(agg_expr.clone());
    }

    // Reconstruct the output schema of this node.
    let mut group_by_output_schema = Schema::default();
    let mut inputs = Vec::new();
    let mut key_per_input = Vec::new();
    let mut aggs_per_input = Vec::new();
    if direct_input_needed || !all_keys_included_in_other_inputs {
        let this_input_schema = &phys_sm[pre_select.node].output_schema;
        let exprs = [
            trans_keys.as_slice(),
            aggs_with_elementwise_inputs.as_slice(),
        ]
        .concat();
        let elementwise_out_schema =
            compute_output_schema(this_input_schema, &exprs, expr_arena).unwrap();
        group_by_output_schema.merge((*elementwise_out_schema).clone());
        inputs.push(pre_select);
        key_per_input.push(trans_keys.clone());
        aggs_per_input.push(aggs_with_elementwise_inputs);
    }
    for (_input_id, (stream, aggs)) in other_agg_input_streams {
        let this_input_schema = &phys_sm[stream.node].output_schema;
        let exprs = [trans_keys.as_slice(), aggs.as_slice()].concat();
        let this_out_schema = compute_output_schema(this_input_schema, &exprs, expr_arena).unwrap();
        group_by_output_schema.merge((*this_out_schema).clone());
        inputs.push(stream);
        key_per_input.push(trans_keys.clone());
        aggs_per_input.push(aggs);
    }
    let group_by_output_schema = Arc::new(group_by_output_schema);

    let agg_node = phys_sm.insert(PhysNode::new(
        group_by_output_schema.clone(),
        PhysNodeKind::GroupBy {
            inputs,
            key_per_input,
            aggs_per_input,
        },
    ));

    // Sort the input based on the first row index if maintaining order.
    let post_select_input = if maintain_order {
        let sort_node = phys_sm.insert(PhysNode::new(
            group_by_output_schema,
            PhysNodeKind::Sort {
                input: PhysStream::first(agg_node),
                by_column: vec![trans_output_exprs.last().unwrap().clone()],
                slice: None,
                sort_options: SortMultipleOptions::new(),
            },
        ));
        trans_output_exprs.pop(); // Remove row idx from post-select.
        PhysStream::first(sort_node)
    } else {
        PhysStream::first(agg_node)
    };

    let post_select = build_select_stream(
        post_select_input,
        &trans_output_exprs,
        expr_arena,
        phys_sm,
        expr_cache,
        ctx,
    )?;

    let out = if let Some((offset, len)) = options.slice {
        build_slice_stream(post_select, offset, len, phys_sm)
    } else {
        post_select
    };
    Ok(Some(out))
}

/// Tries to lower the group-by into a `PhysNodeKind::SortedGroupBy` plan.
///
/// Multiple or nested keys are combined into a single ordered row-encoded
/// key column (decoded and unnested again afterwards). If the keys are not
/// already sorted, a stable sort on (key, row index) is inserted first —
/// which is why `!are_keys_sorted && maintain_order` is rejected only when
/// that sort can't stand in for the original order.
#[expect(clippy::too_many_arguments)]
pub fn try_build_sorted_group_by(
    input: PhysStream,
    keys: &[ExprIR],
    aggs: &[ExprIR],
    output_schema: Arc<Schema>,
    maintain_order: bool,
    options: Arc<GroupbyOptions>,
    apply: Option<PlanCallback<DataFrame, DataFrame>>,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
    ctx: StreamingLowerIRContext,
    are_keys_sorted: bool,
) -> PolarsResult<Option<PhysStream>> {
    let input_schema = phys_sm[input.node].output_schema.as_ref();

    if keys.is_empty()
        || apply.is_some()
        || options.is_rolling()
        || options.is_dynamic()
        || (!are_keys_sorted && maintain_order)
        || keys.iter().any(|k| {
            k.dtype(input_schema, expr_arena)
                .is_ok_and(|dtype| dtype.contains_unknown())
        })
    {
        return Ok(None);
    }

    let mut input = input;
    let mut input_column = unique_column_name();
    let mut projected = false;
    let mut row_encoded: Option<Vec<Field>> = None;

    if keys.len() > 1 || keys[0].dtype(input_schema, expr_arena)?.is_nested() {
        // Combine keys into one ordered row-encoded column.
        let key_fields = keys
            .iter()
            .map(|k| k.field(input_schema, expr_arena))
            .collect::<PolarsResult<Vec<_>>>()?;
        let expr = AExprBuilder::function(
            keys.to_vec(),
            IRFunctionExpr::RowEncode(
                key_fields.iter().map(|k| k.dtype().clone()).collect(),
                RowEncodingVariant::Ordered {
                    descending: None,
                    nulls_last: None,
                    broadcast_nulls: None,
                },
            ),
            expr_arena,
        )
        .expr_ir(input_column.clone());
        input = build_hstack_stream(input, &[expr], expr_arena, phys_sm, expr_cache, ctx)?;
        projected = true;
        row_encoded = Some(key_fields);
    } else if !matches!(expr_arena.get(keys[0].node()), AExpr::Column(c) if c == keys[0].output_name())
    {
        // Single computed key: materialize it under a fresh name.
        input = build_hstack_stream(
            input,
            &[keys[0].with_alias(input_column.clone())],
            expr_arena,
            phys_sm,
            expr_cache,
            ctx,
        )?;
        projected = true;
    } else {
        // Single plain column key: use it directly.
        input_column = keys[0].output_name().clone();
    }

    let key = AExprBuilder::col(input_column.clone(), expr_arena).expr_ir(input_column.clone());

    let schema = phys_sm[input.node].output_schema.clone();
    if !are_keys_sorted {
        // Stable sort on (key, original row index) to create sorted runs.
        let row_idx_name = unique_column_name();
        input = build_row_idx_stream(input, row_idx_name.clone(), None, phys_sm);

        let row_idx_expr =
            AExprBuilder::col(row_idx_name.clone(), expr_arena).expr_ir(row_idx_name.clone());

        input = PhysStream::first(phys_sm.insert(PhysNode {
            output_schema: phys_sm[input.node].output_schema.clone(),
            kind: PhysNodeKind::Sort {
                input,
                by_column: vec![key, row_idx_expr],
                slice: None,
                sort_options: SortMultipleOptions::default(),
            },
        }));
    }

    let mut gb_output_schema = Schema::with_capacity(aggs.len() + 1);
    gb_output_schema.insert(
        input_column.clone(),
        schema.get(input_column.as_str()).unwrap().clone(),
    );
    for agg in aggs {
        let field = agg.field(schema.as_ref(), expr_arena)?;
        // Non-scalar aggregation outputs become lists.
        let dtype = if agg.is_scalar(expr_arena) {
            field.dtype
        } else {
            field.dtype.implode()
        };
        gb_output_schema.insert(field.name, dtype);
    }
    input = PhysStream::first(
        phys_sm.insert(PhysNode {
            output_schema: Arc::new(gb_output_schema.clone()),
            kind: PhysNodeKind::SortedGroupBy {
                input,
                key: input_column.clone(),
                aggs: aggs.to_vec(),
                // Non-negative slice offsets can be pushed into the node;
                // negative offsets are handled by a separate slice below.
                slice: options
                    .slice
                    .filter(|(o, _)| *o >= 0)
                    .map(|(o, l)| (o as IdxSize, l as IdxSize)),
            },
        }),
    );
    if let Some((offset, length)) = options.slice.as_ref().filter(|(o, _)| *o < 0) {
        input = build_slice_stream(input, *offset, *length, phys_sm);
    }

    if projected {
        if let Some(key_fields) = row_encoded {
            // Decode the row-encoded key back into the original key columns.
            let expr =
                AExprBuilder::col(input_column.clone(), expr_arena).expr_ir(input_column.clone());
            let expr = AExprBuilder::function(
                vec![expr],
                IRFunctionExpr::RowDecode(
                    key_fields,
                    RowEncodingVariant::Ordered {
                        descending: None,
                        nulls_last: None,
                        broadcast_nulls: None,
                    },
                ),
                expr_arena,
            )
            .expr_ir(input_column.clone());
            input = build_hstack_stream(input, &[expr], expr_arena, phys_sm, expr_cache, ctx)?;

            // Unnest the row encoded columns.
            input = PhysStream::first(phys_sm.insert(PhysNode {
                output_schema: output_schema.clone(),
                kind: PhysNodeKind::Map {
                    input,
                    map: Arc::new(move |df: DataFrame| df.unnest([input_column.clone()], None))
                        as _,
                    format_str: ctx.prepare_visualization.then(|| "UNNEST".to_string()),
                },
            }));

            let exprs = output_schema
                .iter_names()
                .map(|name| AExprBuilder::col(name.clone(), expr_arena).expr_ir(name.clone()))
                .collect::<Vec<_>>();
            input = build_select_stream(input, &exprs, expr_arena, phys_sm, expr_cache, ctx)?;
        } else {
            // Rename the temporary key column back to the requested output
            // name; remaining columns keep their names.
            let exprs = std::iter::once(input_column)
                .map(|name| (name, output_schema.get_at_index(0).unwrap().0.clone()))
                .chain(
                    output_schema
                        .iter_names_cloned()
                        .skip(1)
                        .map(|name| (name.clone(), name.clone())),
                )
                .map(|(col_name, out_name)| {
                    AExprBuilder::col(col_name, expr_arena).expr_ir(out_name)
                })
                .collect::<Vec<_>>();
            input = build_select_stream(input, &exprs, expr_arena, phys_sm, expr_cache, ctx)?;
        }
    }

    Ok(Some(input))
}

/// Entry point: lowers a group-by IR node into a physical stream.
///
/// Dispatch order: keyless rolling/dynamic group-bys get dedicated nodes;
/// then a sorted group-by is attempted (when keys are sorted, or forced via
/// `POLARS_FORCE_SORTED_GROUP_BY=1`); then a streaming hash group-by; and
/// finally the in-memory fallback.
#[allow(clippy::too_many_arguments)]
pub fn build_group_by_stream(
    input: PhysStream,
    keys: &[ExprIR],
    aggs: &[ExprIR],
    output_schema: Arc<Schema>,
    maintain_order: bool,
    options: Arc<GroupbyOptions>,
    apply: Option<PlanCallback<DataFrame, DataFrame>>,
    expr_arena: &mut Arena<AExpr>,
    phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>,
    expr_cache: &mut ExprCache,
    ctx: StreamingLowerIRContext,
    are_keys_sorted: bool,
) -> PolarsResult<PhysStream> {
    #[cfg(feature = "dynamic_group_by")]
    if let Some(rolling_options) = options.as_ref().rolling.as_ref()
        && keys.is_empty()
        && apply.is_none()
    {
        let mut input = PhysStream::first(
            phys_sm.insert(PhysNode::new(
                output_schema.clone(),
                PhysNodeKind::RollingGroupBy {
                    input,
                    index_column: rolling_options.index_column.clone(),
                    period: rolling_options.period,
                    offset: rolling_options.offset,
                    closed: rolling_options.closed_window,
                    slice: options
                        .slice
                        .filter(|(o, _)| *o >= 0)
                        .map(|(o, l)| (o as IdxSize, l as IdxSize)),
                    aggs: aggs.to_vec(),
                },
            )),
        );
        if let Some((offset, length)) = options.slice.as_ref().filter(|(o, _)| *o < 0) {
            input = build_slice_stream(input, *offset, *length, phys_sm);
        }
        return Ok(input);
    } else if let Some(dynamic_options) = options.as_ref().dynamic.as_ref()
        && keys.is_empty()
        && apply.is_none()
    {
        let mut input = PhysStream::first(
            phys_sm.insert(PhysNode::new(
                output_schema.clone(),
                PhysNodeKind::DynamicGroupBy {
                    input,
                    options: dynamic_options.clone(),
                    aggs: aggs.to_vec(),
                    slice: options
                        .slice
                        .filter(|(o, _)| *o >= 0)
                        .map(|(o, l)| (o as IdxSize, l as IdxSize)),
                },
            )),
        );
        if let Some((offset, length)) = options.slice.as_ref().filter(|(o, _)| *o < 0) {
            input = build_slice_stream(input, *offset, *length, phys_sm);
        }
        return Ok(input);
    }

    if (are_keys_sorted || std::env::var("POLARS_FORCE_SORTED_GROUP_BY").is_ok_and(|v| v == "1"))
        && let Some(stream) = try_build_sorted_group_by(
            input,
            keys,
            aggs,
            output_schema.clone(),
            maintain_order,
            options.clone(),
            apply.clone(),
            expr_arena,
            phys_sm,
            expr_cache,
            ctx,
            are_keys_sorted,
        )?
    {
        Ok(stream)
    } else if let Some(stream) = try_build_streaming_group_by(
        input,
        keys,
        aggs,
        maintain_order,
        options.clone(),
        apply.clone(),
        expr_arena,
        phys_sm,
        expr_cache,
        ctx,
    )? {
        Ok(stream)
    } else {
        let format_str = ctx.prepare_visualization.then(|| {
            let mut buffer = String::new();
            write_group_by(
                &mut buffer,
                0,
                expr_arena,
                keys,
                aggs,
                apply.as_ref(),
                maintain_order,
            )
            .unwrap();
            buffer
        });
        build_group_by_fallback(
            input,
            keys,
            aggs,
            output_schema,
            maintain_order,
            options,
            apply,
            expr_arena,
            phys_sm,
            format_str,
        )
    }
}