Path: blob/main/crates/polars-stream/src/nodes/io_sinks/partition/by_key.rs
6939 views
use std::cmp::Reverse;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};

use futures::StreamExt;
use futures::stream::FuturesUnordered;
use polars_core::config;
use polars_core::frame::DataFrame;
use polars_core::prelude::{Column, PlHashSet, PlIndexMap, row_encode};
use polars_core::schema::SchemaRef;
use polars_core::utils::arrow::buffer::Buffer;
use polars_error::PolarsResult;
use polars_plan::dsl::{PartitionTargetCallback, SinkFinishCallback, SinkOptions};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::plpath::PlPath;
use polars_utils::priority::Priority;

use super::{CreateNewSinkFn, PerPartitionSortBy};
use crate::async_executor::{AbortOnDropHandle, spawn};
use crate::async_primitives::connector::connector;
use crate::execute::StreamingExecutionState;
use crate::morsel::SourceToken;
use crate::nodes::io_sinks::metrics::WriteMetrics;
use crate::nodes::io_sinks::partition::{SinkSender, open_new_sink};
use crate::nodes::io_sinks::phase::PhaseOutcome;
use crate::nodes::io_sinks::{SinkInputPort, SinkNode, parallelize_receive_task};
use crate::nodes::{JoinHandle, Morsel, MorselSeq, TaskPriority};

/// One linearized item, ordered by morsel sequence number (`Reverse` for
/// min-first priority). The payload carries, per partition of the morsel:
/// the row-encoded key bytes, the key columns (one row each) and the
/// partition's data.
type Linearized =
    Priority<Reverse<MorselSeq>, (SourceToken, Vec<(Buffer<u8>, Vec<Column>, DataFrame)>)>;

/// Sink node that writes the input to one output file per unique key value.
pub struct PartitionByKeySinkNode {
    input_schema: SchemaRef,
    // This is not necessarily the same as `input_schema`; e.g. when
    // `include_key == false` this will not include the key columns.
    sink_input_schema: SchemaRef,

    key_cols: Arc<[PlSmallStr]>,

    // Above this number of simultaneously open partition sinks, further
    // partitions are buffered in memory and written afterwards.
    max_open_partitions: usize,
    include_key: bool,

    base_path: Arc<PlPath>,
    file_path_cb: Option<PartitionTargetCallback>,
    create_new: CreateNewSinkFn,
    ext: PlSmallStr,

    sink_options: SinkOptions,

    per_partition_sort_by: Option<PerPartitionSortBy>,
    // Set once at the end of execution; consumed by `finalize`.
    written_partitions: Arc<OnceLock<DataFrame>>,
    finish_callback: Option<SinkFinishCallback>,
}

impl PartitionByKeySinkNode {
    /// Create a new by-key partition sink.
    ///
    /// Panics if `key_cols` is empty. The maximum number of simultaneously
    /// open partitions defaults to 128 and can be overridden through the
    /// `POLARS_MAX_OPEN_PARTITIONS` environment variable.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        input_schema: SchemaRef,
        key_cols: Arc<[PlSmallStr]>,
        base_path: Arc<PlPath>,
        file_path_cb: Option<PartitionTargetCallback>,
        create_new: CreateNewSinkFn,
        ext: PlSmallStr,
        sink_options: SinkOptions,
        include_key: bool,
        per_partition_sort_by: Option<PerPartitionSortBy>,
        finish_callback: Option<SinkFinishCallback>,
    ) -> Self {
        assert!(!key_cols.is_empty());

        // The schema handed to the per-partition sinks: either the full input
        // schema, or the input schema with the key columns projected away.
        let sink_input_schema = if include_key {
            input_schema.clone()
        } else {
            let key_names = PlHashSet::from_iter(key_cols.iter().map(|s| s.as_str()));
            let kept_names = input_schema
                .iter_names()
                .filter(|n| !key_names.contains(n.as_str()))
                .cloned();
            Arc::new(input_schema.try_project(kept_names).unwrap())
        };

        const DEFAULT_MAX_OPEN_PARTITIONS: usize = 128;
        let max_open_partitions = match std::env::var("POLARS_MAX_OPEN_PARTITIONS") {
            Ok(v) => v
                .parse::<usize>()
                .expect("unable to parse POLARS_MAX_OPEN_PARTITIONS"),
            Err(_) => DEFAULT_MAX_OPEN_PARTITIONS,
        };

        Self {
            input_schema,
            sink_input_schema,
            key_cols,
            max_open_partitions,
            include_key,
            base_path,
            file_path_cb,
            create_new,
            ext,
            sink_options,
            per_partition_sort_by,
            written_partitions: Arc::new(OnceLock::new()),
            finish_callback,
        }
    }
}

impl SinkNode for PartitionByKeySinkNode {
    fn name(&self) -> &str {
        "partition-by-key-sink"
    }

    fn is_sink_input_parallel(&self) -> bool {
        true
    }

    fn do_maintain_order(&self) -> bool
{
        self.sink_options.maintain_order
    }

    // No pre-execution setup is needed for this node.
    fn initialize(&mut self, _state: &StreamingExecutionState) -> PolarsResult<()> {
        Ok(())
    }

    // Spawns the sink pipeline in two stages:
    //
    // 1. One task per pipeline splits each incoming morsel into per-key
    //    partitions, row-encodes the keys, and pushes the result into a
    //    linearizer (ordered by morsel sequence number).
    // 2. A single IO task pops from the linearizer and routes each partition
    //    to the sink for its key, opening new sinks on demand. Once
    //    `max_open_partitions` sinks are open, further partitions are
    //    buffered in memory and written after the streaming phase ends.
    fn spawn_sink(
        &mut self,
        recv_port_rx: crate::async_primitives::connector::Receiver<(PhaseOutcome, SinkInputPort)>,
        state: &StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<polars_error::PolarsResult<()>>>,
    ) {
        let (io_tx, mut io_rx) = connector();
        let pass_rxs = parallelize_receive_task::<Linearized>(
            join_handles,
            recv_port_rx,
            state.num_pipelines,
            self.sink_options.maintain_order,
            io_tx,
        );

        // Stage 1: per-pipeline partitioning tasks.
        join_handles.extend(pass_rxs.into_iter().map(|mut pass_rx| {
            let key_cols = self.key_cols.clone();
            let stable = self.sink_options.maintain_order;
            let include_key = self.include_key;

            spawn(TaskPriority::High, async move {
                while let Ok((mut rx, mut lin_tx)) = pass_rx.recv().await {
                    while let Ok(morsel) = rx.recv().await {
                        let (df, seq, source_token, consume_token) = morsel.into_inner();

                        let partition_include_key = true; // We need the keys to send to the
                        // appropriate sink.
                        let parallel = false; // We handle parallel processing in the streaming
                        // engine.
                        let partitions = df._partition_by_impl(
                            &key_cols,
                            stable,
                            partition_include_key,
                            parallel,
                        )?;

                        // For every partition: take the first row of the key columns
                        // (all rows share the same key) and row-encode it so it can be
                        // used as a hash-map key by the IO task.
                        let partitions = partitions
                            .into_iter()
                            .map(|mut df| {
                                let keys = df.select_columns(key_cols.iter().cloned())?;
                                let keys = keys
                                    .into_iter()
                                    .map(|c| c.head(Some(1)))
                                    .collect::<Vec<_>>();

                                let row_encoded = row_encode::encode_rows_unordered(&keys)?
                                    .downcast_into_iter()
                                    .next()
                                    .unwrap();
                                let row_encoded = row_encoded.into_inner().2;

                                // Only strip the key columns now, after they were
                                // captured above.
                                if !include_key {
                                    df = df.drop_many(key_cols.iter().cloned());
                                }

                                PolarsResult::Ok((row_encoded, keys, df))
                            })
                            .collect::<PolarsResult<Vec<(Buffer<u8>, Vec<Column>, DataFrame)>>>()?;

                        if lin_tx
                            .insert(Priority(Reverse(seq), (source_token, partitions)))
                            .await
                            .is_err()
                        {
                            return Ok(());
                        }
                        // It is important that we don't pass the consume
                        // token to the sinks, because that leads to
                        // deadlocks.
                        drop(consume_token);
                    }
                }

                Ok(())
            })
        }));

        // Stage 2: the single IO / routing task.
        let state = state.clone();
        let input_schema = self.input_schema.clone();
        let key_cols = self.key_cols.clone();
        let sink_input_schema = self.sink_input_schema.clone();
        let max_open_partitions = self.max_open_partitions;
        let base_path = self.base_path.clone();
        let file_path_cb = self.file_path_cb.clone();
        let create_new_sink = self.create_new.clone();
        let ext = self.ext.clone();
        let per_partition_sort_by = self.per_partition_sort_by.clone();
        let output_written_partitions = self.written_partitions.clone();
        join_handles.push(spawn(TaskPriority::High, async move {
            // State per known key: either a live sink, or an in-memory buffer
            // used once the open-sink limit is reached.
            enum OpenPartition {
                Sink {
                    sender: SinkSender,
                    join_handles: FuturesUnordered<AbortOnDropHandle<PolarsResult<()>>>,
                    node: Box<dyn SinkNode + Send>,
                    keys: Vec<Column>,
                },
                Buffer {
                    buffered: Vec<DataFrame>,
                    keys: Vec<Column>,
                },
            }

            let verbose = config::verbose();
            let mut file_idx = 0;
            let mut open_partitions: PlIndexMap<Buffer<u8>, OpenPartition> = PlIndexMap::default();

            // Wrap this in a closure so that a failure to send (which signifies a failure) can be
            // caught while waiting for tasks.
            let mut receive_and_pass = async || {
                while let Ok(mut lin_rx) = io_rx.recv().await {
                    while let Some(Priority(Reverse(seq), (source_token, partitions))) =
                        lin_rx.get().await
                    {
                        for (row_encoded, keys, partition) in partitions {
                            let num_open_partitions = open_partitions.len();
                            let open_partition = match open_partitions.get_mut(&row_encoded) {
                                // Unknown key and no capacity left: fall back to buffering.
                                None if num_open_partitions >= max_open_partitions => {
                                    // Only warn once, when the limit is first hit.
                                    if num_open_partitions == max_open_partitions && verbose {
                                        eprintln!(
                                            "[partition[by-key]]: Reached maximum open partitions. Buffering the rest to memory before writing.",
                                        );
                                    }

                                    let (idx, previous) = open_partitions.insert_full(
                                        row_encoded,
                                        OpenPartition::Buffer { buffered: Vec::new(), keys },
                                    );
                                    debug_assert!(previous.is_none());
                                    open_partitions.get_index_mut(idx).unwrap().1
                                },
                                // Unknown key with capacity: open a fresh sink for it.
                                None => {
                                    let result = open_new_sink(
                                        base_path.as_ref().as_ref(),
                                        file_path_cb.as_ref(),
                                        super::default_by_key_file_path_cb,
                                        file_idx,
                                        file_idx,
                                        0,
                                        Some(keys.as_slice()),
                                        &create_new_sink,
                                        sink_input_schema.clone(),
                                        "by-key",
                                        ext.as_str(),
                                        verbose,
                                        &state,
                                        per_partition_sort_by.as_ref(),
                                    ).await?;
                                    file_idx += 1;

                                    // NOTE(review): `None` here is treated as "stop
                                    // quietly" — presumably the sink side signalled
                                    // shutdown; confirm against `open_new_sink`.
                                    let Some((join_handles, sender, node)) = result else {
                                        return Ok(());
                                    };

                                    let (idx, previous) = open_partitions.insert_full(
                                        row_encoded,
                                        OpenPartition::Sink { sender, join_handles, node, keys },
                                    );
                                    debug_assert!(previous.is_none());
                                    open_partitions.get_index_mut(idx).unwrap().1
                                },
                                Some(open_partition) => open_partition,
                            };

                            match open_partition {
                                OpenPartition::Sink { sender, .. } => {
                                    let morsel = Morsel::new(partition, seq, source_token.clone());
                                    if sender.send(morsel).await.is_err() {
                                        return Ok(());
                                    }
                                },
                                OpenPartition::Buffer { buffered, .. } => buffered.push(partition),
                            }
                        }
                    }
                }

                PolarsResult::Ok(())
            };
            receive_and_pass().await?;

            let mut partition_metrics = Vec::with_capacity(file_idx);

            // At this point, we need to wait for all sinks to finish writing and close them. Also,
            // sinks that ended up buffering need to output their data.
            for open_partition in open_partitions.into_values() {
                let (sender, mut join_handles, mut node, keys) = match open_partition {
                    OpenPartition::Sink { sender, join_handles, node, keys } => (sender, join_handles, node, keys),
                    OpenPartition::Buffer { buffered, keys } => {
                        // Open the deferred sink now and replay the buffered frames.
                        let result = open_new_sink(
                            base_path.as_ref().as_ref(),
                            file_path_cb.as_ref(),
                            super::default_by_key_file_path_cb,
                            file_idx,
                            file_idx,
                            0,
                            Some(keys.as_slice()),
                            &create_new_sink,
                            sink_input_schema.clone(),
                            "by-key",
                            ext.as_str(),
                            verbose,
                            &state,
                            per_partition_sort_by.as_ref(),
                        ).await?;
                        file_idx += 1;
                        let Some((join_handles, mut sender, node)) = result else {
                            return Ok(());
                        };

                        let source_token = SourceToken::new();
                        let mut seq = MorselSeq::default();
                        for df in buffered {
                            let morsel = Morsel::new(df, seq, source_token.clone());
                            if sender.send(morsel).await.is_err() {
                                return Ok(());
                            }
                            seq = seq.successor();
                        }

                        (sender, join_handles, node, keys)
                    },
                };

                drop(sender); // Signal to the sink that nothing more is coming.
                while let Some(res) = join_handles.next().await {
                    res?;
                }

                // Attach the (single-row) key values to this partition's metrics.
                if let Some(mut metrics) = node.get_metrics()? {
                    metrics.keys = Some(keys.into_iter().map(|c| c.get(0).unwrap().into_static()).collect());
                    partition_metrics.push(metrics);
                }
                if let Some(finalize) = node.finalize(&state) {
                    finalize.await?;
                }
            }

            // Collapse all per-partition write metrics into one DataFrame and
            // publish it for `finalize` / the finish callback.
            let df = WriteMetrics::collapse_to_df(partition_metrics, &sink_input_schema, Some(&input_schema.try_project(key_cols.iter()).unwrap()));
            output_written_partitions.set(df).unwrap();
            Ok(())
        }));
    }

    // Invoke the user-provided finish callback (if any) with the DataFrame of
    // written partitions produced by the IO task above.
    fn finalize(
        &mut self,
        _state: &StreamingExecutionState,
    ) -> Option<Pin<Box<dyn Future<Output = PolarsResult<()>> + Send>>> {
        let finish_callback = self.finish_callback.clone();
        let written_partitions = self.written_partitions.clone();

        Some(Box::pin(async move {
            if let Some(finish_callback) = &finish_callback {
                let df = written_partitions.get().unwrap();
                finish_callback.call(df.clone())?;
            }
            Ok(())
        }))
    }
}