Path: blob/main/crates/polars-mem-engine/src/executors/group_by_streaming.rs
use std::borrow::Cow;
use std::sync::Arc;

use polars_core::frame::DataFrame;
#[cfg(feature = "dtype-categorical")]
use polars_core::prelude::DataType;
use polars_core::prelude::{Column, GroupsType};
use polars_core::schema::{Schema, SchemaRef};
use polars_core::series::IsSorted;
use polars_error::PolarsResult;
use polars_expr::prelude::PhysicalExpr;
use polars_expr::state::ExecutionState;
use polars_plan::plans::{AExpr, IR, IRPlan};
use polars_utils::arena::{Arena, Node};

use super::{Executor, check_expand_literals, group_by_helper};
use crate::StreamingExecutorBuilder;

pub struct GroupByStreamingExec {
    input_exec: Box<dyn Executor>,
    input_scan_node: Node,
    plan: IRPlan,
    builder: StreamingExecutorBuilder,

    phys_keys: Vec<Arc<dyn PhysicalExpr>>,
    phys_aggs: Vec<Arc<dyn PhysicalExpr>>,
    maintain_order: bool,
    output_schema: SchemaRef,
    slice: Option<(i64, usize)>,
    from_partitioned_ds: bool,
}

impl GroupByStreamingExec {
    #[expect(clippy::too_many_arguments)]
    pub fn new(
        input: Box<dyn Executor>,
        builder: StreamingExecutorBuilder,
        root: Node,
        lp_arena: &mut Arena<IR>,
        expr_arena: &Arena<AExpr>,

        phys_keys: Vec<Arc<dyn PhysicalExpr>>,
        phys_aggs: Vec<Arc<dyn PhysicalExpr>>,
        maintain_order: bool,
        output_schema: SchemaRef,
        slice: Option<(i64, usize)>,
        from_partitioned_ds: bool,
    ) -> Self {
        // Create a DataFrame scan for injecting the input result.
        let scan = lp_arena.add(IR::DataFrameScan {
            df: Arc::new(DataFrame::empty()),
            schema: Arc::new(Schema::default()),
            output_schema: None,
        });

        let IR::GroupBy {
            input: gb_input, ..
        } = lp_arena.get_mut(root)
        else {
            unreachable!();
        };

        // Set the scan as the group by input.
        *gb_input = scan;

        // Prune the subplan into separate arenas.
        let mut new_ir_arena = Arena::new();
        let mut new_expr_arena = Arena::new();
        let [new_root, new_scan] = polars_plan::plans::prune::prune(
            &[root, scan],
            lp_arena,
            expr_arena,
            &mut new_ir_arena,
            &mut new_expr_arena,
        )
        .try_into()
        .unwrap();

        let plan = IRPlan {
            lp_top: new_root,
            lp_arena: new_ir_arena,
            expr_arena: new_expr_arena,
        };

        Self {
            input_exec: input,
            input_scan_node: new_scan,
            plan,
            builder,
            phys_keys,
            phys_aggs,
            maintain_order,
            output_schema,
            slice,
            from_partitioned_ds,
        }
    }

    fn keys(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Vec<Column>> {
        compute_keys(&self.phys_keys, df, state)
    }
}

fn compute_keys(
    keys: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    state: &ExecutionState,
) -> PolarsResult<Vec<Column>> {
    let evaluated = keys
        .iter()
        .map(|s| s.evaluate(df, state))
        .collect::<PolarsResult<_>>()?;
    let df = check_expand_literals(df, keys, evaluated, false, Default::default())?;
    Ok(df.into_columns())
}

fn estimate_unique_count(keys: &[Column], mut sample_size: usize) -> PolarsResult<usize> {
    // https://stats.stackexchange.com/a/19090/147321
    // Estimated unique count:
    //   u + (ui / m) * (s - m)
    // s: set_size
    // m: sample_size
    // u: total unique groups counted in the sample
    // ui: groups with a single unique value counted in the sample
    let set_size = keys[0].len();
    if set_size < sample_size {
        sample_size = set_size;
    }

    let finish = |groups: &GroupsType| {
        let u = groups.len() as f64;
        let ui = if groups.len() == sample_size {
            u
        } else {
            groups.iter().filter(|g| g.len() == 1).count() as f64
        };

        (u + (ui / sample_size as f64) * (set_size - sample_size) as f64) as usize
    };

    if keys.len() == 1 {
        // We sample with replacement, as that also works with sorted data.
        // Note that sampling without replacement is *very* expensive; don't do that.
        let s = keys[0].sample_n(sample_size, true, false, None).unwrap();
        // Fast multi-threaded way to get the unique groups.
        let groups = s.as_materialized_series().group_tuples(true, false)?;
        Ok(finish(&groups))
    } else {
        let offset = (keys[0].len() / 2) as i64;
        let df = unsafe { DataFrame::new_unchecked_infer_height(keys.to_vec()) };
        let df = df.slice(offset, sample_size);
        let names = df.get_column_names().into_iter().cloned();
        let gb = df.group_by(names).unwrap();
        Ok(finish(gb.get_groups()))
    }
}
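// A minimal worked example of the estimator above, as a sketch with made-up
// sample statistics (the numbers and the module/function names here are
// illustrative, not taken from a real query): singleton groups in the sample
// are evidence of many more unseen groups, so they are scaled up by the
// unsampled remainder, while repeated groups are not.
#[cfg(test)]
mod estimate_unique_count_sketch {
    // Plain-f64 restatement of the `finish` closure above.
    fn estimate(u: f64, ui: f64, sample_size: f64, set_size: f64) -> f64 {
        u + (ui / sample_size) * (set_size - sample_size)
    }

    #[test]
    fn worked_example() {
        // 150 distinct groups in a 1_024-row sample, 32 of them singletons,
        // drawn from 1_000_000 rows in total:
        //   150 + (32 / 1_024) * (1_000_000 - 1_024) = 150 + 31_218 = 31_368
        let est = estimate(150.0, 32.0, 1_024.0, 1_000_000.0);
        assert_eq!(est as usize, 31_368);
    }
}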
// Lower this in debug builds so that we hit this code path in the test suite.
#[cfg(debug_assertions)]
const PARTITION_LIMIT: usize = 15;
#[cfg(not(debug_assertions))]
const PARTITION_LIMIT: usize = 1000;

// Checks whether we should run a partitioned or the default aggregation,
// by sampling the data.
fn can_run_partitioned(
    keys: &[Column],
    original_df: &DataFrame,
    state: &ExecutionState,
    from_partitioned_ds: bool,
) -> PolarsResult<bool> {
    if !keys
        .iter()
        .take(1)
        .all(|s| matches!(s.is_sorted_flag(), IsSorted::Not))
    {
        if state.verbose() {
            eprintln!("FOUND SORTED KEY: running default HASH AGGREGATION")
        }
        Ok(false)
    } else if std::env::var("POLARS_NO_PARTITION").is_ok() {
        if state.verbose() {
            eprintln!("POLARS_NO_PARTITION set: running default HASH AGGREGATION")
        }
        Ok(false)
    } else if std::env::var("POLARS_FORCE_PARTITION").is_ok() {
        if state.verbose() {
            eprintln!("POLARS_FORCE_PARTITION set: running partitioned HASH AGGREGATION")
        }
        Ok(true)
    } else if original_df.height() < PARTITION_LIMIT && !cfg!(test) {
        if state.verbose() {
            eprintln!("DATAFRAME < {PARTITION_LIMIT} rows: running default HASH AGGREGATION")
        }
        Ok(false)
    } else {
        // Below this boundary we assume the partitioned group_by will be faster.
        let unique_count_boundary = std::env::var("POLARS_PARTITION_UNIQUE_COUNT")
            .map(|s| s.parse::<usize>().unwrap())
            .unwrap_or(1000);

        let (unique_estimate, sampled_method) = match (keys.len(), keys[0].dtype()) {
            #[cfg(feature = "dtype-categorical")]
            (1, DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) => {
                (mapping.num_cats_upper_bound(), "known")
            },
            _ => {
                // sqrt(N) is a good sample size, as it stays low for large N;
                // it is better than taking a fixed fraction, as it saturates.
                let sample_size = (original_df.height() as f64).powf(0.5) as usize;

                // We never sample fewer than 100 data points.
                let sample_size = std::cmp::max(100, sample_size);
                (estimate_unique_count(keys, sample_size)?, "estimated")
            },
        };
        if state.verbose() {
            eprintln!("{sampled_method} unique values: {unique_estimate}");
        }

        if from_partitioned_ds {
            let estimated_cardinality = unique_estimate as f32 / original_df.height() as f32;
            if estimated_cardinality < 0.4 {
                if state.verbose() {
                    eprintln!("PARTITIONED DS");
                }
                Ok(true)
            } else {
                if state.verbose() {
                    eprintln!(
                        "PARTITIONED DS: estimated cardinality: {estimated_cardinality} exceeded the boundary: 0.4, running default HASH AGGREGATION"
                    );
                }
                Ok(false)
            }
        } else if unique_estimate > unique_count_boundary {
            if state.verbose() {
                eprintln!(
                    "estimated unique count: {unique_estimate} exceeded the boundary: {unique_count_boundary}, running default HASH AGGREGATION"
                )
            }
            Ok(false)
        } else {
            Ok(true)
        }
    }
}
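// A short usage sketch of the overrides checked above, with illustrative
// values (only the variables' presence matters, via `is_ok()`, except for
// POLARS_PARTITION_UNIQUE_COUNT, which must also parse as a `usize`):
//
//   POLARS_NO_PARTITION=1               never partition
//   POLARS_FORCE_PARTITION=1            always partition; note that the
//                                       sorted-key check and POLARS_NO_PARTITION
//                                       take precedence, as they come first in
//                                       the chain above
//   POLARS_PARTITION_UNIQUE_COUNT=5000  raise the unique-count boundary from
//                                       its default of 1000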
"streaming_group_by";251state.should_stop()?;252#[cfg(debug_assertions)]253{254if state.verbose() {255eprintln!("run {name}")256}257}258let input_df = self.input_exec.execute(state)?;259260let profile_name = if state.has_node_timer() {261Cow::Owned(format!(".{name}()"))262} else {263Cow::Borrowed("")264};265266let keys = self.keys(&input_df, state)?;267268if !can_run_partitioned(&keys, &input_df, state, self.from_partitioned_ds)? {269return group_by_helper(270input_df,271keys,272&self.phys_aggs,273None,274state,275self.maintain_order,276&self.output_schema,277self.slice,278);279}280281// Insert the input DataFrame into our DataFrame scan node282if let IR::DataFrameScan { df, schema, .. } =283self.plan.lp_arena.get_mut(self.input_scan_node)284{285*schema = input_df.schema().clone();286*df = Arc::new(input_df);287} else {288unreachable!();289}290291let mut streaming_exec = (self.builder)(292self.plan.lp_top,293&mut self.plan.lp_arena,294&mut self.plan.expr_arena,295)?;296297state298.clone()299.record(|| streaming_exec.execute(state), profile_name)300}301}302303304