Path: blob/main/crates/polars-stream/src/nodes/callback_sink.rs
7884 views
use std::num::NonZeroUsize;12use polars_core::frame::DataFrame;3use polars_error::PolarsResult;4use polars_plan::prelude::PlanCallback;56use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};7use crate::execute::StreamingExecutionState;8use crate::graph::PortState;9use crate::nodes::ComputeNode;10use crate::pipe::{RecvPort, SendPort};1112pub struct CallbackSinkNode {13function: PlanCallback<DataFrame, bool>,14maintain_order: bool,1516buffer: DataFrame,17chunk_size: Option<NonZeroUsize>,18is_done: bool,19}2021impl CallbackSinkNode {22pub fn new(23function: PlanCallback<DataFrame, bool>,24maintain_order: bool,25chunk_size: Option<NonZeroUsize>,26) -> Self {27Self {28function,29maintain_order,3031buffer: DataFrame::empty(),32chunk_size,33is_done: false,34}35}36}3738impl ComputeNode for CallbackSinkNode {39fn name(&self) -> &str {40"sink_batches"41}4243fn update_state(44&mut self,45recv: &mut [PortState],46send: &mut [PortState],47state: &StreamingExecutionState,48) -> PolarsResult<()> {49assert!(recv.len() == 1 && send.is_empty());5051if self.is_done || recv[0] == PortState::Done {52recv[0] = PortState::Done;5354// Flush the last buffer55if !self.buffer.is_empty() && !self.is_done {56let function = self.function.clone();57let df = std::mem::take(&mut self.buffer);5859assert!(60self.chunk_size61.is_some_and(|chunk_size| self.buffer.height() <= chunk_size.into())62);63state.spawn_subphase_task(async move {64polars_io::pl_async::get_runtime()65.spawn_blocking(move || function.call(df))66.await67.unwrap()?;68Ok(())69});70return Ok(());71}72} else {73recv[0] = PortState::Ready;74}7576Ok(())77}7879fn spawn<'env, 's>(80&'env mut self,81scope: &'s TaskScope<'s, 'env>,82recv_ports: &mut [Option<RecvPort<'_>>],83send_ports: &mut [Option<SendPort<'_>>],84_state: &'s StreamingExecutionState,85join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,86) {87assert!(recv_ports.len() == 1 && send_ports.is_empty());88let mut recv = recv_ports[0]89.take()90.unwrap()91.serial_with_maintain_order(self.maintain_order);9293join_handles.push(scope.spawn_task(TaskPriority::High, async move {94while !self.is_done95&& let Ok(m) = recv.recv().await96{97let (df, _, _, consume_token) = m.into_inner();9899// @NOTE: This also performs schema validation.100self.buffer.vstack_mut(&df)?;101102while !self.buffer.is_empty()103&& self104.chunk_size105.is_none_or(|chunk_size| self.buffer.height() >= chunk_size.into())106{107let chunk_size = self.chunk_size.map_or(usize::MAX, Into::into);108109let df;110(df, self.buffer) = self111.buffer112.split_at(self.buffer.height().min(chunk_size) as i64);113114let function = self.function.clone();115let should_stop = polars_io::pl_async::get_runtime()116.spawn_blocking(move || function.call(df))117.await118.unwrap()?;119120if should_stop {121self.is_done = true;122break;123}124}125drop(consume_token);126// Increase the backpressure. Only free up a pipeline when the morsel has been127// processed in its entirety.128}129130Ok(())131}));132}133}134135136