Path: blob/main/crates/polars-stream/src/nodes/group_by.rs
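//! Streaming group-by node.
//!
//! The node runs in two phases. While sinking, every pipeline keeps a
//! `LocalGroupBySinkState`: a fixed-size "hot" grouper per input stream with
//! matching grouped reductions, while keys that miss the hot table (and
//! groups evicted from it) are partitioned by hash for later processing.
//! Once all inputs are done, the local states are merged per partition in
//! parallel and the node turns into an in-memory source over the result.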
use std::sync::Arc;

use polars_core::POOL;
use polars_core::prelude::{IntoColumn, PlHashSet, PlRandomState};
use polars_core::schema::Schema;
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
use polars_expr::groups::Grouper;
use polars_expr::hash_keys::HashKeys;
use polars_expr::hot_groups::{HotGrouper, new_hash_hot_grouper};
use polars_expr::reduce::GroupedReduction;
use polars_utils::cardinality_sketch::CardinalitySketch;
use polars_utils::hashing::HashPartitioner;
use polars_utils::itertools::Itertools;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::sparse_init_vec::SparseInitVec;
use polars_utils::{IdxSize, UnitVec};
use rayon::prelude::*;
use tokio::sync::mpsc::{Receiver, channel};

use super::compute_node_prelude::*;
use crate::async_executor;
use crate::expression::StreamExpr;
use crate::morsel::get_ideal_morsel_size;
use crate::nodes::in_memory_source::InMemorySourceNode;

#[cfg(debug_assertions)]
const DEFAULT_HOT_TABLE_SIZE: usize = 4;
#[cfg(not(debug_assertions))]
const DEFAULT_HOT_TABLE_SIZE: usize = 4096;

struct PreAgg {
    keys: HashKeys,
    reduction_idxs: UnitVec<usize>,
    reductions: Vec<Box<dyn GroupedReduction>>,
}

struct LocalGroupBySinkState {
    hot_grouper_per_input: Vec<Box<dyn HotGrouper>>,
    hot_grouped_reductions: Vec<Box<dyn GroupedReduction>>,

    // A cardinality sketch per partition for the keys seen by this builder.
    sketch_per_p: Vec<CardinalitySketch>,

    // morsel_idxs_values_per_p[p][start..stop] contains the row indices into
    // cold_morsels[i] belonging to partition p, where start, stop are:
    // let start = morsel_idxs_offsets[i * num_partitions + p];
    // let stop = morsel_idxs_offsets[(i + 1) * num_partitions + p];
    // Each cold morsel is (input index, morsel seq, cold keys, cold payload).
    cold_morsels: Vec<(usize, u64, HashKeys, DataFrame)>,
    morsel_idxs_values_per_p: Vec<Vec<IdxSize>>,
    morsel_idxs_offsets_per_p: Vec<usize>,

    // Similar to the above, but for (evicted) pre-aggregates.
    // The UnitVec contains the indices of the grouped reductions.
    pre_aggs: Vec<PreAgg>,
    pre_agg_idxs_values_per_p: Vec<Vec<IdxSize>>,
    pre_agg_idxs_offsets_per_p: Vec<usize>,
}
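
// For example, with num_partitions = 2, the rows of cold morsel i = 1 that
// belong to partition 0 are morsel_idxs_values_per_p[0][start..stop] with
// start = morsel_idxs_offsets_per_p[1 * 2 + 0] and
// stop = morsel_idxs_offsets_per_p[(1 + 1) * 2 + 0]. A runnable sketch of
// this layout lives in the test module at the bottom of this file.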
impl LocalGroupBySinkState {
    fn new(
        key_schema: Arc<Schema>,
        reductions: Vec<Box<dyn GroupedReduction>>,
        hot_table_size: usize,
        num_partitions: usize,
        num_inputs: usize,
    ) -> Self {
        let hot_grouper_per_input = (0..num_inputs)
            .map(|_| new_hash_hot_grouper(key_schema.clone(), hot_table_size))
            .collect();
        Self {
            hot_grouper_per_input,
            hot_grouped_reductions: reductions,

            sketch_per_p: vec![CardinalitySketch::new(); num_partitions],

            cold_morsels: Vec::new(),
            morsel_idxs_values_per_p: vec![Vec::new(); num_partitions],
            morsel_idxs_offsets_per_p: vec![0; num_partitions],

            pre_aggs: Vec::new(),
            pre_agg_idxs_values_per_p: vec![Vec::new(); num_partitions],
            pre_agg_idxs_offsets_per_p: vec![0; num_partitions],
        }
    }

    fn flush_evictions(
        &mut self,
        input_idx: usize,
        reduction_idxs: &[usize],
        partitioner: &HashPartitioner,
    ) {
        let hash_keys = self.hot_grouper_per_input[input_idx].take_evicted_keys();
        let reductions = reduction_idxs
            .iter()
            .map(|r| self.hot_grouped_reductions[*r].take_evictions())
            .collect_vec();
        self.add_pre_agg(hash_keys, reduction_idxs, reductions, partitioner);
    }

    fn add_pre_agg(
        &mut self,
        hash_keys: HashKeys,
        reduction_idxs: &[usize],
        reductions: Vec<Box<dyn GroupedReduction>>,
        partitioner: &HashPartitioner,
    ) {
        hash_keys.gen_idxs_per_partition(
            partitioner,
            &mut self.pre_agg_idxs_values_per_p,
            &mut self.sketch_per_p,
            true,
        );
        self.pre_agg_idxs_offsets_per_p
            .extend(self.pre_agg_idxs_values_per_p.iter().map(|vp| vp.len()));
        let pre_agg = PreAgg {
            keys: hash_keys,
            reduction_idxs: UnitVec::from_slice(reduction_idxs),
            reductions,
        };
        self.pre_aggs.push(pre_agg);
    }
}
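
// Sink-phase state shared by all pipelines: one local state per pipeline in
// `locals`, plus the templates (grouper, grouped reductions) and the hashing
// and partitioning configuration they all share.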
struct GroupBySinkState {
    key_selectors_per_input: Vec<Vec<StreamExpr>>,
    reductions_per_input: Vec<Vec<usize>>,
    grouper: Box<dyn Grouper>,
    uniq_grouped_reduction_cols_per_input: Vec<Vec<PlSmallStr>>,
    grouped_reduction_cols: Vec<Vec<PlSmallStr>>,
    grouped_reductions: Vec<Box<dyn GroupedReduction>>,
    locals: Vec<LocalGroupBySinkState>,
    random_state: PlRandomState,
    partitioner: HashPartitioner,
    has_order_sensitive_agg: bool,
}

impl GroupBySinkState {
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        receivers: Vec<Receiver<(usize, Morsel)>>,
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        for (mut recv, local) in receivers.into_iter().zip(&mut self.locals) {
            let key_selectors_per_input = &self.key_selectors_per_input;
            let reductions_per_input = &self.reductions_per_input;
            let uniq_grouped_reduction_cols_per_input =
                &self.uniq_grouped_reduction_cols_per_input;
            let grouped_reduction_cols = &self.grouped_reduction_cols;
            let random_state = &self.random_state;
            let partitioner = self.partitioner.clone();
            let has_order_sensitive_agg = self.has_order_sensitive_agg;
            join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                let mut hot_idxs = Vec::new();
                let mut hot_group_idxs = Vec::new();
                let mut cold_idxs = Vec::new();
                let mut in_cols = Vec::new();
                while let Some((input_idx, morsel)) = recv.recv().await {
                    // Compute hot group indices from key.
                    let seq = morsel.seq().to_u64();
                    let mut df = morsel.into_df();
                    let mut key_columns = Vec::new();
                    for selector in &key_selectors_per_input[input_idx] {
                        let s = selector.evaluate(&df, &state.in_memory_exec_state).await?;
                        key_columns.push(s.into_column());
                    }
                    let keys = unsafe {
                        DataFrame::new_unchecked_with_broadcast(df.height(), key_columns)?
                    };
                    let hash_keys = HashKeys::from_df(&keys, random_state.clone(), true, false);

                    let hot_grouper = &mut local.hot_grouper_per_input[input_idx];
                    hot_idxs.clear();
                    hot_group_idxs.clear();
                    cold_idxs.clear();
                    hot_grouper.insert_keys(
                        &hash_keys,
                        &mut hot_idxs,
                        &mut hot_group_idxs,
                        &mut cold_idxs,
                        has_order_sensitive_agg,
                    );

                    // Drop columns not used for reductions (key-only columns).
                    let uniq_grouped_reduction_cols =
                        &uniq_grouped_reduction_cols_per_input[input_idx];
                    if uniq_grouped_reduction_cols.len() < df.width() {
                        df = unsafe { df.select_unchecked(uniq_grouped_reduction_cols.as_slice()) }
                            .unwrap();
                    }
                    df.rechunk_mut(); // For gathers.

                    // Update hot reductions.
                    for red_idx in &reductions_per_input[input_idx] {
                        let cols = &grouped_reduction_cols[*red_idx];
                        let reduction = &mut local.hot_grouped_reductions[*red_idx];
                        for col in cols {
                            in_cols.push(df.column(col).unwrap());
                        }
                        unsafe {
                            // SAFETY: we resize the reduction to the number of groups beforehand.
                            reduction.resize(hot_grouper.num_groups());
                            reduction.update_groups_while_evicting(
                                &in_cols,
                                &hot_idxs,
                                &hot_group_idxs,
                                seq,
                            )?;
                        }
                        in_cols.clear();
                        in_cols = in_cols.into_iter().map(|_| unreachable!()).collect(); // Clear lifetimes.
                    }

                    // Store cold keys.
                    // TODO: don't always gather, if majority cold simply store all and
                    // remember offsets into it.
                    if !cold_idxs.is_empty() {
                        unsafe {
                            let cold_keys = hash_keys.gather_unchecked(&cold_idxs);
                            let cold_df = df.take_slice_unchecked_impl(&cold_idxs, false);

                            cold_keys.gen_idxs_per_partition(
                                &partitioner,
                                &mut local.morsel_idxs_values_per_p,
                                &mut local.sketch_per_p,
                                true,
                            );
                            local
                                .morsel_idxs_offsets_per_p
                                .extend(local.morsel_idxs_values_per_p.iter().map(|vp| vp.len()));
                            local
                                .cold_morsels
                                .push((input_idx, seq, cold_keys, cold_df));
                        }
                    }

                    // If we have too many evicted rows, flush them.
                    if hot_grouper.num_evictions() >= get_ideal_morsel_size() {
                        local.flush_evictions(
                            input_idx,
                            &reductions_per_input[input_idx],
                            &partitioner,
                        );
                    }
                }
                Ok(())
            }));
        }
    }
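
    // Merges all per-pipeline local states into one GroupByPartition per
    // partition: remaining evictions and the hot tables themselves are first
    // turned into pre-aggregates, then one task per partition inserts its
    // slice of every cold morsel and pre-aggregate into a fresh grouper,
    // combining reductions as it goes.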
    fn combine_locals(&mut self) -> PolarsResult<Vec<GroupByPartition>> {
        // Finalize pre-aggregations.
        POOL.install(|| {
            self.locals
                .as_mut_slice()
                .into_par_iter()
                .with_max_len(1)
                .for_each(|l| {
                    for (input_idx, r_idxs) in self.reductions_per_input.iter().enumerate() {
                        let hot_grouper = &mut l.hot_grouper_per_input[input_idx];
                        if hot_grouper.num_evictions() > 0 {
                            l.flush_evictions(input_idx, r_idxs, &self.partitioner);
                        }
                    }

                    let mut opt_hot_reductions =
                        l.hot_grouped_reductions.drain(..).map(Some).collect_vec();
                    for (input_idx, r_idxs) in self.reductions_per_input.iter().enumerate() {
                        let hot_grouper = &mut l.hot_grouper_per_input[input_idx];
                        let hot_keys = hot_grouper.keys();
                        let hot_reductions = r_idxs
                            .iter()
                            .map(|r| opt_hot_reductions[*r].take().unwrap())
                            .collect_vec();
                        l.add_pre_agg(hot_keys, r_idxs, hot_reductions, &self.partitioner);
                    }
                });
        });

        // To reduce maximum memory usage we want to drop the morsels
        // as soon as they're processed, so we move into Arcs. The drops might
        // also be expensive, so instead of directly dropping we put that on
        // a work queue.
        let morsels_per_local = self
            .locals
            .iter_mut()
            .map(|l| Arc::new(core::mem::take(&mut l.cold_morsels)))
            .collect_vec();
        let pre_aggs_per_local = self
            .locals
            .iter_mut()
            .map(|l| Arc::new(core::mem::take(&mut l.pre_aggs)))
            .collect_vec();
        enum ToDrop<A, B> {
            A(A),
            B(B),
        }
        let (drop_q_send, drop_q_recv) = async_channel::bounded(self.locals.len());
        let num_partitions = self.locals[0].sketch_per_p.len();
        let output_per_partition: SparseInitVec<GroupByPartition> =
            SparseInitVec::with_capacity(num_partitions);
        let locals = &self.locals;
        let grouper_template = &self.grouper;
        let reductions_per_input = &self.reductions_per_input;
        let grouped_reductions_template = &self.grouped_reductions;
        let grouped_reduction_cols = &self.grouped_reduction_cols;

        async_executor::task_scope(|s| {
            // Wrap in outer Arc to move to each thread, performing the
            // expensive clone on that thread.
            let arc_morsels_per_local = Arc::new(morsels_per_local);
            let arc_pre_aggs_per_local = Arc::new(pre_aggs_per_local);
            let mut join_handles = Vec::new();
            for p in 0..num_partitions {
                let arc_morsels_per_local = Arc::clone(&arc_morsels_per_local);
                let arc_pre_aggs_per_local = Arc::clone(&arc_pre_aggs_per_local);
                let drop_q_send = drop_q_send.clone();
                let drop_q_recv = drop_q_recv.clone();
                let output_per_partition = &output_per_partition;
                join_handles.push(s.spawn_task(TaskPriority::High, async move {
                    // Extract from outer arc and drop outer arc.
                    let morsels_per_local = Arc::unwrap_or_clone(arc_morsels_per_local);
                    let pre_aggs_per_local = Arc::unwrap_or_clone(arc_pre_aggs_per_local);

                    // Compute the cardinality estimate for this partition.
                    let mut sketch = CardinalitySketch::new();
                    for l in locals {
                        sketch.combine(&l.sketch_per_p[p]);
                    }

                    // Allocate grouper and reductions, with 25% headroom on
                    // the sketch estimate.
                    let est_num_groups = sketch.estimate() * 5 / 4;
                    let mut p_grouper = grouper_template.new_empty();
                    let mut p_reductions = grouped_reductions_template
                        .iter()
                        .map(|gr| gr.new_empty())
                        .collect_vec();
                    p_grouper.reserve(est_num_groups);
                    for r in &mut p_reductions {
                        r.reserve(est_num_groups);
                    }
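
                    // Morsels and pre-aggregates below are consumed
                    // cooperatively: Arc::into_inner returns Some only for
                    // the last remaining reference, so exactly one partition
                    // task ends up owning each local's data and pushes the
                    // (expensive) drop onto the drop queue, which idle tasks
                    // help drain.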

                    // Insert morsels.
                    let mut skip_drop_attempt = false;
                    let mut group_idxs = Vec::new();
                    let mut in_cols = Vec::new();
                    for (l, l_morsels) in locals.iter().zip(morsels_per_local) {
                        // Try to help with dropping.
                        if !skip_drop_attempt {
                            drop(drop_q_recv.try_recv());
                        }

                        for (i, morsel) in l_morsels.iter().enumerate() {
                            let (input_idx, seq_id, keys, morsel_df) = morsel;
                            unsafe {
                                let p_morsel_idxs_start =
                                    l.morsel_idxs_offsets_per_p[i * num_partitions + p];
                                let p_morsel_idxs_stop =
                                    l.morsel_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_morsel_idxs = &l.morsel_idxs_values_per_p[p]
                                    [p_morsel_idxs_start..p_morsel_idxs_stop];

                                group_idxs.clear();
                                p_grouper.insert_keys_subset(
                                    keys,
                                    p_morsel_idxs,
                                    Some(&mut group_idxs),
                                );

                                for red_idx in &reductions_per_input[*input_idx] {
                                    let cols = &grouped_reduction_cols[*red_idx];
                                    let reduction = &mut p_reductions[*red_idx];
                                    for col in cols {
                                        in_cols.push(morsel_df.column(col).unwrap());
                                    }
                                    reduction.resize(p_grouper.num_groups());
                                    reduction.update_groups_subset(
                                        &in_cols,
                                        p_morsel_idxs,
                                        &group_idxs,
                                        *seq_id,
                                    )?;
                                    in_cols.clear();
                                }
                            }
                        }
                        in_cols = in_cols.into_iter().map(|_| unreachable!()).collect(); // Clear lifetimes.

                        if let Some(l) = Arc::into_inner(l_morsels) {
                            // If we're the last thread to process this set of morsels we're probably
                            // falling behind the rest; since the drop can be quite expensive we skip
                            // a drop attempt hoping someone else will pick up the slack.
                            drop(drop_q_send.try_send(ToDrop::A(l)));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // Insert pre-aggregates.
                    for (l, l_pre_aggs) in locals.iter().zip(pre_aggs_per_local) {
                        // Try to help with dropping.
                        if !skip_drop_attempt {
                            drop(drop_q_recv.try_recv());
                        }

                        for (i, key_pre_aggs) in l_pre_aggs.iter().enumerate() {
                            let PreAgg {
                                keys,
                                reduction_idxs: r_idxs,
                                reductions: pre_aggs,
                            } = key_pre_aggs;
                            unsafe {
                                let p_pre_agg_idxs_start =
                                    l.pre_agg_idxs_offsets_per_p[i * num_partitions + p];
                                let p_pre_agg_idxs_stop =
                                    l.pre_agg_idxs_offsets_per_p[(i + 1) * num_partitions + p];
                                let p_pre_agg_idxs = &l.pre_agg_idxs_values_per_p[p]
                                    [p_pre_agg_idxs_start..p_pre_agg_idxs_stop];

                                group_idxs.clear();
                                p_grouper.insert_keys_subset(
                                    keys,
                                    p_pre_agg_idxs,
                                    Some(&mut group_idxs),
                                );
                                for (pre_agg, r_idx) in pre_aggs.iter().zip(r_idxs.iter()) {
                                    let r = &mut p_reductions[*r_idx];
                                    r.resize(p_grouper.num_groups());
                                    r.combine_subset(&**pre_agg, p_pre_agg_idxs, &group_idxs)?;
                                }
                            }
                        }

                        if let Some(l) = Arc::into_inner(l_pre_aggs) {
                            // If we're the last thread to process this set of pre-aggregates we're
                            // probably falling behind the rest; since the drop can be quite expensive
                            // we skip a drop attempt hoping someone else will pick up the slack.
                            drop(drop_q_send.try_send(ToDrop::B(l)));
                            skip_drop_attempt = true;
                        } else {
                            skip_drop_attempt = false;
                        }
                    }

                    // We're done, help others out by doing drops.
                    drop(drop_q_send); // So we don't deadlock trying to receive from ourselves.
                    while let Ok(to_drop) = drop_q_recv.recv().await {
                        drop(to_drop);
                    }

                    output_per_partition
                        .try_set(
                            p,
                            GroupByPartition {
                                grouper: p_grouper,
                                grouped_reductions: p_reductions,
                            },
                        )
                        .ok()
                        .unwrap();

                    PolarsResult::Ok(())
                }));
            }

            // Drop outer arc after spawning each thread so the inner arcs
            // can get dropped as soon as they're processed. We also have to
            // drop the drop queue sender so we don't deadlock waiting for it
            // to end.
            drop(arc_morsels_per_local);
            drop(arc_pre_aggs_per_local);
            drop(drop_q_send);

            polars_io::pl_async::get_runtime().block_on(async move {
                for handle in join_handles {
                    handle.await?;
                }
                PolarsResult::Ok(())
            })?;
            PolarsResult::Ok(())
        })?;

        // Drop remaining local state in parallel.
        POOL.install(|| {
            core::mem::take(&mut self.locals)
                .into_par_iter()
                .with_max_len(1)
                .for_each(drop);
        });

        Ok(output_per_partition.try_assume_init().ok().unwrap())
    }
}
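
// The per-partition result of the sink phase: a fully merged grouper plus its
// grouped reductions, materialized into a DataFrame by into_df.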
struct GroupByPartition {
    grouper: Box<dyn Grouper>,
    grouped_reductions: Vec<Box<dyn GroupedReduction>>,
}

impl GroupByPartition {
    fn into_df(self, key_schema: &Schema, output_schema: &Schema) -> PolarsResult<DataFrame> {
        let mut out = self.grouper.get_keys_in_group_order(key_schema);
        let out_names = output_schema.iter_names().skip(out.width());
        for (mut r, name) in self.grouped_reductions.into_iter().zip(out_names) {
            unsafe {
                out.push_column_unchecked(r.finalize()?.with_name(name.clone()).into_column());
            }
        }
        Ok(out)
    }
}

enum GroupByState {
    Sink(GroupBySinkState),
    Source(InMemorySourceNode),
    Done,
}

pub struct GroupByNode {
    state: GroupByState,
    key_schema: Arc<Schema>,
    num_inputs: usize,
    num_pipelines: usize,
    output_schema: Arc<Schema>,
}

impl GroupByNode {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        key_schema: Arc<Schema>,
        // Input stream i selects keys with key_selectors_per_input[i].
        key_selectors_per_input: Vec<Vec<StreamExpr>>,
        // Input stream i feeds grouped_reductions[k] for each k in reductions_per_input[i].
        reductions_per_input: Vec<Vec<usize>>,
        grouper: Box<dyn Grouper>,
        // grouped_reductions[k] is passed input cols grouped_reduction_cols[k].
        grouped_reduction_cols: Vec<Vec<PlSmallStr>>,
        grouped_reductions: Vec<Box<dyn GroupedReduction>>,
        output_schema: Arc<Schema>,
        random_state: PlRandomState,
        num_pipelines: usize,
        has_order_sensitive_agg: bool,
    ) -> Self {
        let hot_table_size = std::env::var("POLARS_HOT_TABLE_SIZE")
            .map(|sz| sz.parse::<usize>().unwrap())
            .unwrap_or(DEFAULT_HOT_TABLE_SIZE);
        let num_inputs = key_selectors_per_input.len();
        let num_partitions = num_pipelines;
        let uniq_grouped_reduction_cols_per_input = reductions_per_input
            .iter()
            .map(|rs| {
                rs.iter()
                    .flat_map(|k| grouped_reduction_cols[*k].iter())
                    .cloned()
                    .collect::<PlHashSet<_>>()
                    .into_iter()
                    .collect_vec()
            })
            .collect_vec();
        let locals = (0..num_pipelines)
            .map(|_| {
                let reductions = grouped_reductions.iter().map(|gr| gr.new_empty()).collect();
                LocalGroupBySinkState::new(
                    key_schema.clone(),
                    reductions,
                    hot_table_size,
                    num_partitions,
                    num_inputs,
                )
            })
            .collect();
        let partitioner = HashPartitioner::new(num_partitions, 0);
        Self {
            state: GroupByState::Sink(GroupBySinkState {
                key_selectors_per_input,
                reductions_per_input,
                grouped_reductions,
                grouper,
                random_state,
                uniq_grouped_reduction_cols_per_input,
                grouped_reduction_cols,
                locals,
                partitioner,
                has_order_sensitive_agg,
            }),
            key_schema,
            num_inputs,
            num_pipelines,
            output_schema,
        }
    }
}
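
// Node lifecycle: while sinking, the inputs are Ready and the output is
// Blocked; once every input is Done the locals are combined and the node
// becomes an in-memory Source; when the output no longer wants data (or the
// source is drained) the node is Done.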
impl ComputeNode for GroupByNode {
    fn name(&self) -> &str {
        "group-by"
    }

    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == self.num_inputs && send.len() == 1);

        // State transitions.
        match &mut self.state {
            // If the output doesn't want any more data, transition to being done.
            _ if send[0] == PortState::Done => {
                self.state = GroupByState::Done;
            },
            // All inputs are done, transition to being a source.
            GroupByState::Sink(_) if recv.iter().all(|r| matches!(r, PortState::Done)) => {
                let GroupByState::Sink(mut sink) =
                    core::mem::replace(&mut self.state, GroupByState::Done)
                else {
                    unreachable!()
                };
                let partitions = sink.combine_locals()?;
                let dfs = POOL.install(|| {
                    partitions
                        .into_par_iter()
                        .map(|p| p.into_df(&self.key_schema, &self.output_schema))
                        .collect::<Result<Vec<_>, _>>()
                })?;

                let df = accumulate_dataframes_vertical_unchecked(dfs);
                let source = InMemorySourceNode::new(Arc::new(df), MorselSeq::new(0));
                self.state = GroupByState::Source(source);
            },
            // Defer to source node implementation.
            GroupByState::Source(src) => {
                src.update_state(&mut [], send, state)?;
                if send[0] == PortState::Done {
                    self.state = GroupByState::Done;
                }
            },
            // Nothing to change.
            GroupByState::Done | GroupByState::Sink(_) => {},
        }

        // Communicate our state.
        match &self.state {
            GroupByState::Sink { .. } => {
                recv.fill(PortState::Ready);
                send[0] = PortState::Blocked;
            },
            GroupByState::Source(..) => {
                recv.fill(PortState::Done);
                send[0] = PortState::Ready;
            },
            GroupByState::Done => {
                recv.fill(PortState::Done);
                send[0] = PortState::Done;
            },
        }
        Ok(())
    }

    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(send_ports.len() == 1 && recv_ports.len() == self.num_inputs);
        match &mut self.state {
            GroupByState::Sink(sink) => {
                assert!(send_ports[0].is_none());
                assert!(recv_ports.iter().any(|r| r.is_some()));

                // If we have multiple input streams, merge them into one (still
                // identifying which input stream each morsel came from).
                let (senders, receivers): (Vec<_>, Vec<_>) =
                    (0..self.num_pipelines).map(|_| channel(1)).unzip();
                for (i, recv_port) in recv_ports.iter_mut().enumerate() {
                    if let Some(recv_port) = recv_port.take() {
                        for (mut r, s) in recv_port
                            .parallel()
                            .into_iter()
                            .zip(senders.iter().cloned())
                        {
                            join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                                while let Ok(morsel) = r.recv().await {
                                    if s.send((i, morsel)).await.is_err() {
                                        break;
                                    }
                                }

                                Ok(())
                            }));
                        }
                    }
                }
                sink.spawn(scope, receivers, state, join_handles)
            },
            GroupByState::Source(source) => {
                assert!(recv_ports[0].is_none());
                source.spawn(scope, &mut [], send_ports, state, join_handles);
            },
            GroupByState::Done => unreachable!(),
        }
    }
}
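
// A minimal, self-contained sketch of the flattened per-partition offset
// layout documented on LocalGroupBySinkState. The data below is made up for
// illustration and is independent of the engine types above.
#[cfg(test)]
mod offset_layout_example {
    #[test]
    fn flattened_offsets_roundtrip() {
        let num_partitions = 2;
        // Morsel 0 sends rows [0, 2] to partition 0 and row [1] to partition 1;
        // morsel 1 sends row [0] to partition 0 and rows [1, 2, 3] to partition 1.
        let per_morsel: Vec<Vec<Vec<u32>>> = vec![
            vec![vec![0, 2], vec![1]],
            vec![vec![0], vec![1, 2, 3]],
        ];

        // Build the flattened representation the way the sink does: append
        // each morsel's indices per partition, then record the new lengths
        // as offsets.
        let mut values_per_p: Vec<Vec<u32>> = vec![Vec::new(); num_partitions];
        let mut offsets_per_p: Vec<usize> = vec![0; num_partitions];
        for morsel in &per_morsel {
            for (p, idxs) in morsel.iter().enumerate() {
                values_per_p[p].extend_from_slice(idxs);
            }
            offsets_per_p.extend(values_per_p.iter().map(|vp| vp.len()));
        }

        // Recover morsel i's indices for partition p with the same start/stop
        // arithmetic used in combine_locals.
        for (i, morsel) in per_morsel.iter().enumerate() {
            for p in 0..num_partitions {
                let start = offsets_per_p[i * num_partitions + p];
                let stop = offsets_per_p[(i + 1) * num_partitions + p];
                assert_eq!(&values_per_p[p][start..stop], morsel[p].as_slice());
            }
        }
    }
}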