// Path: crates/polars-stream/src/nodes/group_by.rs
//! Streaming group-by node.
//!
//! Works in two phases:
//! 1. Sink phase: each pipeline keeps a `LocalGroupBySinkState` with a bounded
//!    "hot" hash table. Keys that don't fit stay "cold" and are stored (already
//!    pre-partitioned by key hash) together with evicted pre-aggregates.
//! 2. Finalization (`combine_locals`): one task per partition merges the cold
//!    morsels and pre-aggregates of every local state into a per-partition
//!    grouper, after which the node becomes an in-memory source of the result.

use std::sync::Arc;

use polars_core::POOL;
use polars_core::prelude::{IntoColumn, PlHashSet, PlRandomState};
use polars_core::schema::Schema;
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
use polars_expr::groups::Grouper;
use polars_expr::hash_keys::HashKeys;
use polars_expr::hot_groups::{HotGrouper, new_hash_hot_grouper};
use polars_expr::reduce::GroupedReduction;
use polars_utils::IdxSize;
use polars_utils::cardinality_sketch::CardinalitySketch;
use polars_utils::hashing::HashPartitioner;
use polars_utils::itertools::Itertools;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::sparse_init_vec::SparseInitVec;
use rayon::prelude::*;

use super::compute_node_prelude::*;
use crate::async_executor;
use crate::async_primitives::connector::Receiver;
use crate::expression::StreamExpr;
use crate::morsel::get_ideal_morsel_size;
use crate::nodes::in_memory_source::InMemorySourceNode;

// Default capacity of the per-pipeline hot table. Deliberately tiny in debug
// builds, presumably so the cold/eviction paths are exercised by tests —
// TODO confirm intent. Overridable via `POLARS_HOT_TABLE_SIZE` (see
// `GroupByNode::new`).
#[cfg(debug_assertions)]
const DEFAULT_HOT_TABLE_SIZE: usize = 4;
#[cfg(not(debug_assertions))]
const DEFAULT_HOT_TABLE_SIZE: usize = 4096;

/// Per-pipeline sink state: a bounded hot grouper plus the spilled
/// (cold / evicted) data, pre-partitioned by key hash.
struct LocalGroupBySinkState {
    // Bounded hash table holding the currently-hot groups; may evict.
    hot_grouper: Box<dyn HotGrouper>,
    // One grouped reduction per aggregation, tracking the hot groups.
    hot_grouped_reductions: Vec<Box<dyn GroupedReduction>>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // morsel_idxs_values_per_p[p][start..stop] contains the offsets into cold_morsels[i]
    // for partition p, where start, stop are:
    // let start = morsel_idxs_offsets[i * num_partitions + p];
    // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p];
    //
    // Each cold morsel stores (morsel sequence number, its hashed keys, the
    // payload columns needed by the reductions).
    cold_morsels: Vec<(u64, HashKeys, DataFrame)>,
    morsel_idxs_values_per_p: Vec<Vec<IdxSize>>,
    morsel_idxs_offsets_per_p: Vec<usize>,

    // Similar to the above, but for (evicted) pre-aggregates.
    pre_aggs: Vec<(HashKeys, Vec<Box<dyn GroupedReduction>>)>,
    pre_agg_idxs_values_per_p: Vec<Vec<IdxSize>>,
    pre_agg_idxs_offsets_per_p: Vec<usize>,
}

impl LocalGroupBySinkState {
    /// Creates an empty local state with `num_partitions` partition slots and
    /// a hot table of capacity `hot_table_size`.
    fn new(
        key_schema: Arc<Schema>,
        reductions: Vec<Box<dyn GroupedReduction>>,
        hot_table_size: usize,
        num_partitions: usize,
    ) -> Self {
        let hot_grouper = new_hash_hot_grouper(key_schema, hot_table_size);
        Self {
            hot_grouper,
            hot_grouped_reductions: reductions,

            sketch_per_p: vec![CardinalitySketch::new(); num_partitions],

            cold_morsels: Vec::new(),
            morsel_idxs_values_per_p: vec![Vec::new(); num_partitions],
            morsel_idxs_offsets_per_p: vec![0; num_partitions],

            pre_aggs: Vec::new(),
            pre_agg_idxs_values_per_p: vec![Vec::new(); num_partitions],
            pre_agg_idxs_offsets_per_p: vec![0; num_partitions],
        }
    }

    /// Moves the keys/reduction states evicted from the hot table into the
    /// partitioned pre-aggregate spill storage.
    fn flush_evictions(&mut self, partitioner: &HashPartitioner) {
        let hash_keys = self.hot_grouper.take_evicted_keys();
        let reductions = self
            .hot_grouped_reductions
            .iter_mut()
            .map(|hgr| hgr.take_evictions())
            .collect_vec();
        self.add_pre_agg(hash_keys, reductions, partitioner);
    }

    /// Stores one pre-aggregated (keys, reductions) chunk, computing its
    /// per-partition index lists and updating the cardinality sketches.
    fn add_pre_agg(
        &mut self,
        hash_keys: HashKeys,
        reductions: Vec<Box<dyn GroupedReduction>>,
        partitioner: &HashPartitioner,
    ) {
        hash_keys.gen_idxs_per_partition(
            partitioner,
            &mut self.pre_agg_idxs_values_per_p,
            &mut self.sketch_per_p,
            true,
        );
        // Record the new end offset of every partition's index list so slices
        // per (chunk, partition) can be recovered later (see the offsets
        // layout documented on `LocalGroupBySinkState`).
        self.pre_agg_idxs_offsets_per_p
            .extend(self.pre_agg_idxs_values_per_p.iter().map(|vp| vp.len()));
        self.pre_aggs.push((hash_keys, reductions));
    }
}

/// Sink-phase state shared by the whole group-by node.
struct GroupBySinkState {
    // Expressions evaluated against each morsel to produce the key columns.
    key_selectors: Vec<StreamExpr>,
    // Template grouper; cloned empty for every output partition.
    grouper: Box<dyn Grouper>,
    // Deduplicated input column names read by the reductions (used to drop
    // key-only columns from stored cold morsels).
    uniq_grouped_reduction_cols: Vec<PlSmallStr>,
    // Input column name per reduction, parallel to `grouped_reductions`.
    grouped_reduction_cols: Vec<PlSmallStr>,
    // Template reductions; cloned empty for every output partition.
    grouped_reductions: Vec<Box<dyn GroupedReduction>>,
    // One local state per pipeline.
    locals: Vec<LocalGroupBySinkState>,
    // Hash seed shared by all pipelines so equal keys hash equally.
    random_state: PlRandomState,
    partitioner: HashPartitioner,
    // Forwarded to `HotGrouper::insert_keys`; presumably enables order-aware
    // handling for order-sensitive aggregations — confirm against HotGrouper.
    has_order_sensitive_agg: bool,
}

impl GroupBySinkState {
    /// Spawns one sink task per receiver/pipeline that consumes morsels and
    /// updates that pipeline's `LocalGroupBySinkState`.
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        receivers: Vec<Receiver<Morsel>>,
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        for (mut recv, local) in receivers.into_iter().zip(&mut self.locals) {
            let key_selectors = &self.key_selectors;
            let uniq_grouped_reduction_cols = &self.uniq_grouped_reduction_cols;
            let grouped_reduction_cols = &self.grouped_reduction_cols;
            let random_state = &self.random_state;
            let partitioner = self.partitioner.clone();
            let has_order_sensitive_agg = self.has_order_sensitive_agg;
            join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                // Scratch buffers reused across morsels.
                let mut hot_idxs = Vec::new();
                let mut hot_group_idxs = Vec::new();
                let mut cold_idxs = Vec::new();
                while let Ok(morsel) = recv.recv().await {
                    // Compute hot group indices from key.
                    let seq = morsel.seq().to_u64();
                    let mut df = morsel.into_df();
                    let mut key_columns = Vec::new();
                    for selector in key_selectors {
                        let s = selector.evaluate(&df, &state.in_memory_exec_state).await?;
                        key_columns.push(s.into_column());
                    }
                    let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?;
                    let hash_keys = HashKeys::from_df(&keys, *random_state, true, false);

                    hot_idxs.clear();
                    hot_group_idxs.clear();
                    cold_idxs.clear();
                    // Splits the morsel's rows into hot (with their group ids)
                    // and cold row indices.
                    local.hot_grouper.insert_keys(
                        &hash_keys,
                        &mut hot_idxs,
                        &mut hot_group_idxs,
                        &mut cold_idxs,
                        has_order_sensitive_agg,
                    );

                    // Drop columns not used for reductions (key-only columns).
                    if uniq_grouped_reduction_cols.len() < grouped_reduction_cols.len() {
                        df = df._select_impl(uniq_grouped_reduction_cols).unwrap();
                    }
                    df.rechunk_mut(); // For gathers.

                    // Update hot reductions.
                    for (col, reduction) in grouped_reduction_cols
                        .iter()
                        .zip(&mut local.hot_grouped_reductions)
                    {
                        unsafe {
                            // SAFETY: we resize the reduction to the number of groups beforehand.
                            reduction.resize(local.hot_grouper.num_groups());
                            reduction.update_groups_while_evicting(
                                df.column(col).unwrap(),
                                &hot_idxs,
                                &hot_group_idxs,
                                seq,
                            )?;
                        }
                    }

                    // Store cold keys.
                    // TODO: don't always gather, if majority cold simply store all and remember offsets into it.
                    if !cold_idxs.is_empty() {
                        unsafe {
                            let cold_keys = hash_keys.gather_unchecked(&cold_idxs);
                            let cold_df = df.take_slice_unchecked_impl(&cold_idxs, false);

                            cold_keys.gen_idxs_per_partition(
                                &partitioner,
                                &mut local.morsel_idxs_values_per_p,
                                &mut local.sketch_per_p,
                                true,
                            );
                            local
                                .morsel_idxs_offsets_per_p
                                .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len()));
                            local.cold_morsels.push((seq, cold_keys, cold_df));
                        }
                    }

                    // If we have too many evicted rows, flush them.
                    if local.hot_grouper.num_evictions() >= get_ideal_morsel_size() {
                        local.flush_evictions(&partitioner);
                    }
                }
                Ok(())
            }));
        }
    }

    /// Merges every pipeline's local state into per-partition groupers and
    /// reduction states, in parallel (one task per partition).
    fn combine_locals(&mut self) -> PolarsResult<Vec<GroupByPartition>> {
        // Finalize pre-aggregations: flush any pending evictions and turn the
        // remaining hot-table contents into a final pre-aggregate chunk.
        POOL.install(|| {
            self.locals
                .as_mut_slice()
                .into_par_iter()
                .with_max_len(1)
                .for_each(|l| {
                    if l.hot_grouper.num_evictions() > 0 {
                        l.flush_evictions(&self.partitioner);
                    }
                    let hot_keys = l.hot_grouper.keys();
                    let hot_reductions = core::mem::take(&mut l.hot_grouped_reductions);
                    l.add_pre_agg(hot_keys, hot_reductions, &self.partitioner);
                });
        });

        // To reduce maximum memory usage we want to drop the morsels
        // as soon as they're processed, so we move into Arcs. The drops might
        // also be expensive, so instead of directly dropping we put that on
        // a work queue.
        let morsels_per_local = self
            .locals
            .iter_mut()
            .map(|l| Arc::new(core::mem::take(&mut l.cold_morsels)))
            .collect_vec();
        let pre_aggs_per_local = self
            .locals
            .iter_mut()
            .map(|l| Arc::new(core::mem::take(&mut l.pre_aggs)))
            .collect_vec();
        // Heterogeneous item for the drop work queue (morsel sets vs.
        // pre-aggregate sets).
        enum ToDrop<A, B> {
            A(A),
            B(B),
        }
        let (drop_q_send, drop_q_recv) = async_channel::bounded(self.locals.len());
        let num_partitions = self.locals[0].sketch_per_p.len();
        let output_per_partition: SparseInitVec<GroupByPartition> =
            SparseInitVec::with_capacity(num_partitions);
        let locals = &self.locals;
        let grouper_template = &self.grouper;
        let grouped_reductions_template = &self.grouped_reductions;
        let grouped_reduction_cols = &self.grouped_reduction_cols;

        async_executor::task_scope(|s| {
            // Wrap in outer Arc to move to each thread, performing the
            // expensive clone on that thread.
            let arc_morsels_per_local = Arc::new(morsels_per_local);
            let arc_pre_aggs_per_local = Arc::new(pre_aggs_per_local);
            let mut join_handles = Vec::new();
            for p in 0..num_partitions {
                let arc_morsels_per_local = Arc::clone(&arc_morsels_per_local);
                let arc_pre_aggs_per_local = Arc::clone(&arc_pre_aggs_per_local);
                let drop_q_send = drop_q_send.clone();
                let drop_q_recv = drop_q_recv.clone();
                let output_per_partition = &output_per_partition;
                join_handles.push(s.spawn_task(TaskPriority::High, async move {
                    // Extract from outer arc and drop outer arc.
                    let morsels_per_local = Arc::unwrap_or_clone(arc_morsels_per_local);
                    let pre_aggs_per_local = Arc::unwrap_or_clone(arc_pre_aggs_per_local);

                    // Compute cardinality estimate and total amount of
                    // payload for this partition.
                    let mut sketch = CardinalitySketch::new();
                    for l in locals {
                        sketch.combine(&l.sketch_per_p[p]);
                    }

                    // Allocate grouper and reductions, reserving 25% headroom
                    // over the sketch's estimate.
                    let est_num_groups = sketch.estimate() * 5 / 4;
                    let mut p_grouper = grouper_template.new_empty();
                    let mut p_reductions = grouped_reductions_template
                        .iter()
                        .map(|gr| gr.new_empty())
                        .collect_vec();
                    p_grouper.reserve(est_num_groups);
                    for r in &mut p_reductions {
                        r.reserve(est_num_groups);
                    }

                    // Insert morsels.
                    let mut skip_drop_attempt = false;
                    let mut group_idxs = Vec::new();
                    for (l, l_morsels) in locals.iter().zip(morsels_per_local) {
                        // Try to help with dropping.
                        if !skip_drop_attempt {
                            drop(drop_q_recv.try_recv());
                        }

                        for (i, morsel) in l_morsels.iter().enumerate() {
                            let (seq_id, keys, cols) = morsel;
                            unsafe {
                                // Slice of this morsel's row indices that
                                // belong to partition p (layout documented on
                                // `LocalGroupBySinkState`).
                                let p_morsel_idxs_start =
                                    l.morsel_idxs_offsets_per_p[i * num_partitions + p];
                                let p_morsel_idxs_stop =
                                    l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                    [p_morsel_idxs_start..p_morsel_idxs_stop];

                                group_idxs.clear();
                                p_grouper.insert_keys_subset(
                                    keys,
                                    p_morsel_idxs,
                                    Some(&mut group_idxs),
                                );
                                for (c, r) in grouped_reduction_cols.iter().zip(&mut p_reductions) {
                                    let values = cols.column(c.as_str()).unwrap();
                                    r.resize(p_grouper.num_groups());
                                    r.update_groups_subset(
                                        values,
                                        p_morsel_idxs,
                                        &group_idxs,
                                        *seq_id,
                                    )?;
                                }
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_morsels) {
                            // If we're the last thread to process this set of morsels we're probably
                            // falling behind the rest, since the drop can be quite expensive we skip
                            // a drop attempt hoping someone else will pick up the slack.
                            drop(drop_q_send.try_send(ToDrop::A(l)));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // Insert pre-aggregates.
                    for (l, l_pre_aggs) in locals.iter().zip(pre_aggs_per_local) {
                        // Try to help with dropping.
                        if !skip_drop_attempt {
                            drop(drop_q_recv.try_recv());
                        }

                        for (i, key_pre_aggs) in l_pre_aggs.iter().enumerate() {
                            let (keys, pre_aggs) = key_pre_aggs;
                            unsafe {
                                let p_pre_agg_idxs_start =
                                    l.pre_agg_idxs_offsets_per_p[i * num_partitions + p];
                                let p_pre_agg_idxs_stop =
                                    l.pre_agg_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_pre_agg_idxs = &l.pre_agg_idxs_values_per_p[p]
                                    [p_pre_agg_idxs_start..p_pre_agg_idxs_stop];

                                group_idxs.clear();
                                p_grouper.insert_keys_subset(
                                    keys,
                                    p_pre_agg_idxs,
                                    Some(&mut group_idxs),
                                );
                                for (pre_agg, r) in pre_aggs.iter().zip(&mut p_reductions) {
                                    r.resize(p_grouper.num_groups());
                                    r.combine_subset(&**pre_agg, p_pre_agg_idxs, &group_idxs)?;
                                }
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_pre_aggs) {
                            // If we're the last thread to process this set of morsels we're probably
                            // falling behind the rest, since the drop can be quite expensive we skip
                            // a drop attempt hoping someone else will pick up the slack.
                            drop(drop_q_send.try_send(ToDrop::B(l)));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // We're done, help others out by doing drops.
                    drop(drop_q_send); // So we don't deadlock trying to receive from ourselves.
                    while let Ok(to_drop) = drop_q_recv.recv().await {
                        drop(to_drop);
                    }

                    output_per_partition
                        .try_set(
                            p,
                            GroupByPartition {
                                grouper: p_grouper,
                                grouped_reductions: p_reductions,
                            },
                        )
                        .ok()
                        .unwrap();

                    PolarsResult::Ok(())
                }));
            }

            // Drop outer arc after spawning each thread so the inner arcs
            // can get dropped as soon as they're processed. We also have to
            // drop the drop queue sender so we don't deadlock waiting for it
            // to end.
            drop(arc_morsels_per_local);
            drop(arc_pre_aggs_per_local);
            drop(drop_q_send);

            polars_io::pl_async::get_runtime().block_on(async move {
                for handle in join_handles {
                    handle.await?;
                }
                PolarsResult::Ok(())
            })?;
            PolarsResult::Ok(())
        })?;

        // Drop remaining local state in parallel.
        POOL.install(|| {
            core::mem::take(&mut self.locals)
                .into_par_iter()
                .with_max_len(1)
                .for_each(drop);
        });

        Ok(output_per_partition.try_assume_init().ok().unwrap())
    }
}

/// Fully aggregated result for one hash partition.
struct GroupByPartition {
    grouper: Box<dyn Grouper>,
    grouped_reductions: Vec<Box<dyn GroupedReduction>>,
}

impl GroupByPartition {
    /// Materializes this partition as a DataFrame: key columns (in group
    /// order) followed by one finalized column per reduction, named after the
    /// trailing entries of `output_schema`.
    fn into_df(self, key_schema: &Schema, output_schema: &Schema) -> PolarsResult<DataFrame> {
        let mut out = self.grouper.get_keys_in_group_order(key_schema);
        let out_names = output_schema.iter_names().skip(out.width());
        for (mut r, name) in self.grouped_reductions.into_iter().zip(out_names) {
            unsafe {
                out.with_column_unchecked(r.finalize()?.with_name(name.clone()).into_column());
            }
        }
        Ok(out)
    }
}

/// Lifecycle of the node: aggregating input, serving results, finished.
enum GroupByState {
    Sink(GroupBySinkState),
    Source(InMemorySourceNode),
    Done,
}

pub struct GroupByNode {
    state: GroupByState,
    key_schema: Arc<Schema>,
    output_schema: Arc<Schema>,
}

impl GroupByNode {
    /// Builds a group-by node in its initial sink state, with one local sink
    /// state per pipeline and one output partition per pipeline.
    ///
    /// Note: `POLARS_HOT_TABLE_SIZE`, if set, overrides the hot table size and
    /// panics on an unparsable value.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        key_schema: Arc<Schema>,
        key_selectors: Vec<StreamExpr>,
        grouper: Box<dyn Grouper>,
        grouped_reduction_cols: Vec<PlSmallStr>,
        grouped_reductions: Vec<Box<dyn GroupedReduction>>,
        output_schema: Arc<Schema>,
        random_state: PlRandomState,
        num_pipelines: usize,
        has_order_sensitive_agg: bool,
    ) -> Self {
        let hot_table_size = std::env::var("POLARS_HOT_TABLE_SIZE")
            .map(|sz| sz.parse::<usize>().unwrap())
            .unwrap_or(DEFAULT_HOT_TABLE_SIZE);
        let num_partitions = num_pipelines;
        // Deduplicate the reduction input columns (several reductions may read
        // the same column).
        let uniq_grouped_reduction_cols = grouped_reduction_cols
            .iter()
            .cloned()
            .collect::<PlHashSet<_>>()
            .into_iter()
            .collect_vec();
        let locals = (0..num_pipelines)
            .map(|_| {
                let reductions = grouped_reductions.iter().map(|gr| gr.new_empty()).collect();
                LocalGroupBySinkState::new(
                    key_schema.clone(),
                    reductions,
                    hot_table_size,
                    num_partitions,
                )
            })
            .collect();
        let partitioner = HashPartitioner::new(num_partitions, 0);
        Self {
            state: GroupByState::Sink(GroupBySinkState {
                key_selectors,
                grouped_reductions,
                grouper,
                random_state,
                uniq_grouped_reduction_cols,
                grouped_reduction_cols,
                locals,
                partitioner,
                has_order_sensitive_agg,
            }),
            key_schema,
            output_schema,
        }
    }
}

impl ComputeNode for GroupByNode {
    fn name(&self) -> &str {
        "group-by"
    }

    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 1 && send.len() == 1);

        // State transitions.
        match &mut self.state {
            // If the output doesn't want any more data, transition to being done.
            _ if send[0] == PortState::Done => {
                self.state = GroupByState::Done;
            },
            // Input is done, transition to being a source.
            GroupByState::Sink(_) if matches!(recv[0], PortState::Done) => {
                let GroupByState::Sink(mut sink) =
                    core::mem::replace(&mut self.state, GroupByState::Done)
                else {
                    unreachable!()
                };
                let partitions = sink.combine_locals()?;
                // Materialize every partition into a DataFrame in parallel.
                let dfs = POOL.install(|| {
                    partitions
                        .into_par_iter()
                        .map(|p| p.into_df(&self.key_schema, &self.output_schema))
                        .collect::<Result<Vec<_>, _>>()
                })?;

                let df = accumulate_dataframes_vertical_unchecked(dfs);
                let source = InMemorySourceNode::new(Arc::new(df), MorselSeq::new(0));
                self.state = GroupByState::Source(source);
            },
            // Defer to source node implementation.
            GroupByState::Source(src) => {
                src.update_state(&mut [], send, state)?;
                if send[0] == PortState::Done {
                    self.state = GroupByState::Done;
                }
            },
            // Nothing to change.
            GroupByState::Done | GroupByState::Sink(_) => {},
        }

        // Communicate our state.
        match &self.state {
            GroupByState::Sink { .. } => {
                send[0] = PortState::Blocked;
                recv[0] = PortState::Ready;
            },
            GroupByState::Source(..) => {
                recv[0] = PortState::Done;
                send[0] = PortState::Ready;
            },
            GroupByState::Done => {
                recv[0] = PortState::Done;
                send[0] = PortState::Done;
            },
        }
        Ok(())
    }

    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(send_ports.len() == 1 && recv_ports.len() == 1);
        match &mut self.state {
            GroupByState::Sink(sink) => {
                assert!(send_ports[0].is_none());
                sink.spawn(
                    scope,
                    recv_ports[0].take().unwrap().parallel(),
                    state,
                    join_handles,
                )
            },
            GroupByState::Source(source) => {
                assert!(recv_ports[0].is_none());
                source.spawn(scope, &mut [], send_ports, state, join_handles);
            },
            GroupByState::Done => unreachable!(),
        }
    }
}