Path: blob/main/crates/polars-stream/src/nodes/io_sinks/mod.rs

use std::pin::Pin;
use std::sync::{Arc, LazyLock, Mutex};

use futures::StreamExt;
use futures::stream::FuturesUnordered;
use polars_core::config;
use polars_core::frame::DataFrame;
use polars_core::prelude::Column;
use polars_core::schema::SchemaRef;
use polars_error::PolarsResult;

use self::metrics::WriteMetrics;
use super::{ComputeNode, JoinHandle, Morsel, PortState, RecvPort, SendPort, TaskScope};
use crate::async_executor::{AbortOnDropHandle, spawn};
use crate::async_primitives::connector::{Receiver, Sender, connector};
use crate::async_primitives::distributor_channel;
use crate::async_primitives::linearizer::{Inserter, Linearizer};
use crate::async_primitives::wait_group::WaitGroup;
use crate::execute::StreamingExecutionState;
use crate::nodes::TaskPriority;

mod metrics;
mod phase;
use phase::PhaseOutcome;

#[cfg(feature = "csv")]
pub mod csv;
#[cfg(feature = "ipc")]
pub mod ipc;
#[cfg(feature = "json")]
pub mod json;
#[cfg(feature = "parquet")]
pub mod parquet;
pub mod partition;

// This needs to be low to increase the backpressure.
static DEFAULT_SINK_LINEARIZER_BUFFER_SIZE: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("POLARS_DEFAULT_SINK_LINEARIZER_BUFFER_SIZE")
        .map(|x| x.parse().unwrap())
        .unwrap_or(1)
});

static DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE: LazyLock<usize> = LazyLock::new(|| {
    std::env::var("POLARS_DEFAULT_SINK_DISTRIBUTOR_BUFFER_SIZE")
        .map(|x| x.parse().unwrap())
        .unwrap_or(1)
});

pub enum SinkInputPort {
    Serial(Receiver<Morsel>),
    Parallel(Vec<Receiver<Morsel>>),
}

impl SinkInputPort {
    pub fn serial(self) -> Receiver<Morsel> {
        match self {
            Self::Serial(s) => s,
            _ => panic!(),
        }
    }

    pub fn parallel(self) -> Vec<Receiver<Morsel>> {
        match self {
            Self::Parallel(s) => s,
            _ => panic!(),
        }
    }
}

/// Spawn a task that linearizes and buffers morsels until a given maximum chunk size is reached
/// and then distributes the columns amongst worker tasks.
fn buffer_and_distribute_columns_task(
    mut recv_port_rx: Receiver<(PhaseOutcome, SinkInputPort)>,
    mut dist_tx: distributor_channel::Sender<(usize, usize, Column)>,
    chunk_size: usize,
    schema: SchemaRef,
    metrics: Arc<Mutex<Option<WriteMetrics>>>,
) -> JoinHandle<PolarsResult<()>> {
    spawn(TaskPriority::High, async move {
        let mut seq = 0usize;
        let mut buffer = DataFrame::empty_with_schema(schema.as_ref());

        let mut metrics_ = metrics.lock().unwrap().take();
        while let Ok((outcome, rx)) = recv_port_rx.recv().await {
            let mut rx = rx.serial();
            while let Ok(morsel) = rx.recv().await {
                let (df, _, _, consume_token) = morsel.into_inner();

                if let Some(metrics) = metrics_.as_mut() {
                    metrics.append(&df)?;
                }

                // @NOTE: This also performs schema validation.
                buffer.vstack_mut(&df)?;

                while buffer.height() >= chunk_size {
                    let df;
                    (df, buffer) = buffer.split_at(buffer.height().min(chunk_size) as i64);

                    for (i, column) in df.take_columns().into_iter().enumerate() {
                        if dist_tx.send((seq, i, column)).await.is_err() {
                            return Ok(());
                        }
                    }
                    seq += 1;
                }
                drop(consume_token); // Increase the backpressure. Only free up a pipeline when the
                                     // morsel has started encoding in its entirety. This still
                                     // allows for parallelism of Morsels, but prevents large
                                     // bunches of Morsels from stacking up here.
            }

            outcome.stopped();
        }
        if let Some(metrics_) = metrics_ {
            *metrics.lock().unwrap() = Some(metrics_);
        }

        // Don't write an empty row group at the end.
        if buffer.is_empty() {
            return Ok(());
        }

        // Flush the remaining rows.
        assert!(buffer.height() <= chunk_size);
        for (i, column) in buffer.take_columns().into_iter().enumerate() {
            if dist_tx.send((seq, i, column)).await.is_err() {
                return Ok(());
            }
        }

        PolarsResult::Ok(())
    })
}
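
// The distributor above hands encode workers `(seq, i, column)` triples: `seq`
// numbers the output chunk and `i` the column within it, so workers can encode
// columns independently while a linearizer restores chunk order downstream. A
// minimal sketch of a consumer loop (hypothetical; `dist_rx` is illustrative,
// the real encode tasks live in the per-format modules such as `parquet`):
//
//     while let Ok((seq, col_idx, column)) = dist_rx.recv().await {
//         // Encode `column` as column `col_idx` of chunk `seq`, then insert
//         // the encoded buffer into the linearizer keyed by `seq` so that the
//         // IO task writes chunks back in order.
//     }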

#[allow(clippy::type_complexity)]
pub fn parallelize_receive_task<T: Ord + Send + 'static>(
    join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    mut recv_port_rx: Receiver<(PhaseOutcome, SinkInputPort)>,
    num_pipelines: usize,
    maintain_order: bool,
    mut io_tx: Sender<Linearizer<T>>,
) -> Vec<Receiver<(Receiver<Morsel>, Inserter<T>)>> {
    // Phase Handling Task -> Encode Tasks.
    let (mut pass_txs, pass_rxs) = (0..num_pipelines)
        .map(|_| connector())
        .collect::<(Vec<_>, Vec<_>)>();

    join_handles.push(spawn(TaskPriority::High, async move {
        while let Ok((outcome, port_rxs)) = recv_port_rx.recv().await {
            let port_rxs = port_rxs.parallel();
            let (lin_rx, lin_txs) = Linearizer::<T>::new_with_maintain_order(
                num_pipelines,
                *DEFAULT_SINK_LINEARIZER_BUFFER_SIZE,
                maintain_order,
            );

            for ((pass_tx, port_rx), lin_tx) in pass_txs.iter_mut().zip(port_rxs).zip(lin_txs) {
                if pass_tx.send((port_rx, lin_tx)).await.is_err() {
                    return Ok(());
                }
            }
            if io_tx.send(lin_rx).await.is_err() {
                return Ok(());
            }

            outcome.stopped();
        }

        Ok(())
    }));

    pass_rxs
}

pub trait SinkNode {
    fn name(&self) -> &str;

    fn is_sink_input_parallel(&self) -> bool;

    fn do_maintain_order(&self) -> bool {
        true
    }

    fn spawn_sink(
        &mut self,
        recv_ports_recv: Receiver<(PhaseOutcome, SinkInputPort)>,
        state: &StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    );

    /// Callback that gets called once before the sink is spawned.
    fn initialize(&mut self, state: &StreamingExecutionState) -> PolarsResult<()> {
        _ = state;
        Ok(())
    }

    /// Callback for when the query has finished successfully.
    ///
    /// This should only be called when the writing is finished and all the join handles have been
    /// awaited.
    fn finalize(
        &mut self,
        state: &StreamingExecutionState,
    ) -> Option<Pin<Box<dyn Future<Output = PolarsResult<()>> + Send>>> {
        _ = state;
        None
    }

    /// Fetch metrics for a specific sink.
    ///
    /// This should only be called when the writing is finished and all the join handles have been
    /// awaited.
    fn get_metrics(&self) -> PolarsResult<Option<WriteMetrics>> {
        Ok(None)
    }
}
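
// A minimal sketch of a `SinkNode` implementation (a hypothetical `NullSink`
// that discards every morsel; the name and body are illustrative only, not part
// of this module). It shows the expected shape: drain each phase's input port,
// call `outcome.stopped()` when the phase ends, and register worker tasks in
// `join_handles`:
//
//     struct NullSink;
//
//     impl SinkNode for NullSink {
//         fn name(&self) -> &str {
//             "null-sink"
//         }
//         fn is_sink_input_parallel(&self) -> bool {
//             false
//         }
//         fn spawn_sink(
//             &mut self,
//             mut recv_ports_recv: Receiver<(PhaseOutcome, SinkInputPort)>,
//             _state: &StreamingExecutionState,
//             join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
//         ) {
//             join_handles.push(spawn(TaskPriority::High, async move {
//                 while let Ok((outcome, port)) = recv_ports_recv.recv().await {
//                     let mut rx = port.serial();
//                     while let Ok(morsel) = rx.recv().await {
//                         drop(morsel); // A real sink would encode and write here.
//                     }
//                     outcome.stopped();
//                 }
//                 Ok(())
//             }));
//         }
//     }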

/// The state needed to manage a spawned [`SinkNode`].
struct StartedSinkComputeNode {
    input_send: Sender<(PhaseOutcome, SinkInputPort)>,
    join_handles: FuturesUnordered<AbortOnDropHandle<PolarsResult<()>>>,
}

/// A [`ComputeNode`] to wrap a [`SinkNode`].
pub struct SinkComputeNode {
    sink: Box<dyn SinkNode + Send>,
    started: Option<StartedSinkComputeNode>,
    state: SinkState,
}

enum SinkState {
    /// Initial state of a [`SinkComputeNode`].
    ///
    /// This still requires `sink.initialize` to be called on the `SinkNode`.
    Uninitialized,

    /// Active state of a [`SinkComputeNode`].
    ///
    /// When finished, the `sink.finalize` method should be called.
    Initialized,

    /// Final state for the [`SinkComputeNode`].
    ///
    /// Receive port is Done and [`SinkNode`] is finalized.
    Finished,
}

impl SinkComputeNode {
    pub fn new(sink: Box<dyn SinkNode + Send>) -> Self {
        Self {
            sink,
            started: None,
            state: SinkState::Uninitialized,
        }
    }
}

impl<T: SinkNode + Send + 'static> From<T> for SinkComputeNode {
    fn from(value: T) -> Self {
        Self::new(Box::new(value))
    }
}
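
// With the blanket `From` impl above, any `SinkNode` lifts directly into a
// compute node, e.g. (hypothetical, reusing the `NullSink` sketch above):
//
//     let node = SinkComputeNode::from(NullSink);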

impl ComputeNode for SinkComputeNode {
    fn name(&self) -> &str {
        self.sink.name()
    }

    fn update_state(
        &mut self,
        recv: &mut [PortState],
        _send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        // Ensure that initialize is only called once.
        if matches!(self.state, SinkState::Uninitialized) {
            self.sink.initialize(state)?;
            self.state = SinkState::Initialized;
        }

        if recv[0] != PortState::Done {
            recv[0] = PortState::Ready;
        }

        if recv[0] == PortState::Done && !matches!(self.state, SinkState::Finished) {
            let started = self.started.take();
            let finalize = self.sink.finalize(state);

            state.spawn_subphase_task(async move {
                // We need to join on all started tasks before finalizing the node because the
                // unfinished tasks might still need access to the node.
                //
                // Note that if the sink never received any data, this `started` might be None.
                // However, we still need to finalize the node, otherwise no file will be
                // created.
                if let Some(mut started) = started {
                    drop(started.input_send);
                    // Either the task finished or some error occurred.
                    while let Some(ret) = started.join_handles.next().await {
                        ret?;
                    }
                }

                if let Some(finalize) = finalize {
                    finalize.await?;
                }

                PolarsResult::Ok(())
            });

            self.state = SinkState::Finished;
        }

        Ok(())
    }

    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert_eq!(recv_ports.len(), 1);
        assert!(send_ports.is_empty());

        let name = self.name().to_string();
        let started = self.started.get_or_insert_with(|| {
            let (tx, rx) = connector();
            let mut join_handles = Vec::new();

            self.sink.spawn_sink(rx, state, &mut join_handles);
            // One of the tasks might throw an error, in which case we need to cancel all
            // handles and find the error.
            let join_handles: FuturesUnordered<_> =
                join_handles.drain(..).map(AbortOnDropHandle::new).collect();

            StartedSinkComputeNode {
                input_send: tx,
                join_handles,
            }
        });

        let wait_group = WaitGroup::default();
        let recv = recv_ports[0].take().unwrap();
        let sink_input = if self.sink.is_sink_input_parallel() {
            SinkInputPort::Parallel(recv.parallel())
        } else {
            SinkInputPort::Serial(recv.serial_with_maintain_order(self.sink.do_maintain_order()))
        };
        join_handles.push(scope.spawn_task(TaskPriority::High, async move {
            let (token, outcome) = PhaseOutcome::new_shared_wait(wait_group.token());
            if started.input_send.send((outcome, sink_input)).await.is_ok() {
                // Wait for the phase to finish.
                wait_group.wait().await;
                if !token.did_finish() {
                    return Ok(());
                }

                if config::verbose() {
                    eprintln!("[{name}]: Last data sent.");
                }
            }

            // Either the task finished or some error occurred.
            while let Some(ret) = started.join_handles.next().await {
                ret?;
            }

            Ok(())
        }));
    }

    fn get_output(&mut self) -> PolarsResult<Option<DataFrame>> {
        Ok(None)
    }
}
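
// Lifecycle summary: `update_state` calls `SinkNode::initialize` exactly once,
// each graph phase then triggers `spawn`, which lazily spawns the sink's tasks
// and forwards one `(PhaseOutcome, SinkInputPort)` pair per phase. Once the
// receive port reports `Done`, `update_state` drops the input channel, awaits
// every join handle, and runs the optional `finalize` future before entering
// `SinkState::Finished`.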