// Source: crates/polars-stream/src/nodes/io_sources/batch.rs (polars, main branch)
//! Reads batches from a `dyn Fn`12use async_trait::async_trait;3use polars_core::frame::DataFrame;4use polars_core::schema::SchemaRef;5use polars_error::{PolarsResult, polars_err};6use polars_utils::IdxSize;7use polars_utils::pl_str::PlSmallStr;89use crate::async_executor::{JoinHandle, TaskPriority, spawn};10use crate::execute::StreamingExecutionState;11use crate::morsel::{Morsel, MorselSeq, SourceToken};12use crate::nodes::io_sources::multi_scan::reader_interface::output::{13FileReaderOutputRecv, FileReaderOutputSend,14};15use crate::nodes::io_sources::multi_scan::reader_interface::{16BeginReadArgs, FileReader, FileReaderCallbacks,17};1819pub mod builder {20use std::sync::{Arc, Mutex};2122use polars_utils::pl_str::PlSmallStr;2324use super::BatchFnReader;25use crate::execute::StreamingExecutionState;26use crate::nodes::io_sources::multi_scan::reader_interface::FileReader;27use crate::nodes::io_sources::multi_scan::reader_interface::builder::FileReaderBuilder;28use crate::nodes::io_sources::multi_scan::reader_interface::capabilities::ReaderCapabilities;2930pub struct BatchFnReaderBuilder {31pub name: PlSmallStr,32pub reader: Mutex<Option<BatchFnReader>>,33pub execution_state: Mutex<Option<StreamingExecutionState>>,34}3536impl FileReaderBuilder for BatchFnReaderBuilder {37fn reader_name(&self) -> &str {38&self.name39}4041fn reader_capabilities(&self) -> ReaderCapabilities {42ReaderCapabilities::empty()43}4445fn set_execution_state(&self, execution_state: &StreamingExecutionState) {46*self.execution_state.lock().unwrap() = Some(execution_state.clone());47}4849fn build_file_reader(50&self,51_source: polars_plan::prelude::ScanSource,52_cloud_options: Option<Arc<polars_io::cloud::CloudOptions>>,53scan_source_idx: usize,54) -> Box<dyn FileReader> {55assert_eq!(scan_source_idx, 0);5657let mut reader = self58.reader59.try_lock()60.unwrap()61.take()62.expect("BatchFnReaderBuilder called more than once");6364reader.execution_state = 
Some(self.execution_state.lock().unwrap().clone().unwrap());6566Box::new(reader) as Box<dyn FileReader>67}68}6970impl std::fmt::Debug for BatchFnReaderBuilder {71fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {72f.write_str("BatchFnReaderBuilder: name: ")?;73f.write_str(&self.name)?;7475Ok(())76}77}78}7980pub type GetBatchFn =81Box<dyn Fn(&StreamingExecutionState) -> PolarsResult<Option<DataFrame>> + Send + Sync>;8283pub use get_batch_state::GetBatchState;8485mod get_batch_state {86use polars_io::pl_async::get_runtime;8788use super::{DataFrame, GetBatchFn, PolarsResult, StreamingExecutionState};8990/// Wraps `GetBatchFn` to support peeking.91pub struct GetBatchState {92func: GetBatchFn,93peek: Option<DataFrame>,94}9596impl GetBatchState {97pub async fn next(98mut slf: Self,99execution_state: StreamingExecutionState,100) -> PolarsResult<(Self, Option<DataFrame>)> {101get_runtime()102.spawn_blocking({103move || unsafe { slf.next_impl(&execution_state).map(|x| (slf, x)) }104})105.await106.unwrap()107}108109pub async fn peek(110mut slf: Self,111execution_state: StreamingExecutionState,112) -> PolarsResult<(Self, Option<DataFrame>)> {113get_runtime()114.spawn_blocking({115move || unsafe { slf.peek_impl(&execution_state).map(|x| (slf, x)) }116})117.await118.unwrap()119}120121/// # Safety122/// This may deadlock if the caller is an async executor thread, as the `GetBatchFn` may123/// be a Python function that re-enters the streaming engine before returning.124pub unsafe fn peek_impl(125&mut self,126state: &StreamingExecutionState,127) -> PolarsResult<Option<DataFrame>> {128if self.peek.is_none() {129self.peek = (self.func)(state)?;130}131132Ok(self.peek.clone())133}134135/// # Safety136/// This may deadlock if the caller is an async executor thread, as the `GetBatchFn` may137/// be a Python function that re-enters the streaming engine before returning.138unsafe fn next_impl(139&mut self,140state: &StreamingExecutionState,141) -> 
PolarsResult<Option<DataFrame>> {142if let Some(df) = self.peek.take() {143Ok(Some(df))144} else {145(self.func)(state)146}147}148}149150impl From<GetBatchFn> for GetBatchState {151fn from(func: GetBatchFn) -> Self {152Self { func, peek: None }153}154}155}156157pub struct BatchFnReader {158pub name: PlSmallStr,159pub output_schema: Option<SchemaRef>,160pub get_batch_state: Option<GetBatchState>,161pub execution_state: Option<StreamingExecutionState>,162pub verbose: bool,163}164165#[async_trait]166impl FileReader for BatchFnReader {167async fn initialize(&mut self) -> PolarsResult<()> {168Ok(())169}170171fn begin_read(172&mut self,173args: BeginReadArgs,174) -> PolarsResult<(FileReaderOutputRecv, JoinHandle<PolarsResult<()>>)> {175let BeginReadArgs {176projection: _,177row_index: None,178pre_slice: None,179predicate: None,180cast_columns_policy: _,181num_pipelines: _,182callbacks:183FileReaderCallbacks {184mut file_schema_tx,185n_rows_in_file_tx,186row_position_on_end_tx,187},188} = args189else {190panic!("unsupported args: {:?}", &args)191};192193let execution_state = self.execution_state().clone();194195if file_schema_tx.is_some() && self.output_schema.is_some() {196_ = file_schema_tx197.take()198.unwrap()199.try_send(self.output_schema.clone().unwrap());200}201202let mut get_batch_state = self203.get_batch_state204.take()205// If this is ever needed we can buffer206.expect("unimplemented: BatchFnReader called more than once");207208let verbose = self.verbose;209210if verbose {211eprintln!("[BatchFnReader]: name: {}", self.name);212}213214let (mut morsel_sender, morsel_rx) = FileReaderOutputSend::new_serial();215216let handle = spawn(TaskPriority::Low, async move {217if let Some(mut file_schema_tx) = file_schema_tx {218let opt_df;219220(get_batch_state, opt_df) =221GetBatchState::peek(get_batch_state, execution_state.clone()).await?;222223_ = file_schema_tx224.try_send(opt_df.map(|df| df.schema().clone()).unwrap_or_default())225}226227let mut seq: u64 = 0;228// 
Note: We don't use this (it is handled by the bridge). But morsels require a source token.229let source_token = SourceToken::new();230231let mut n_rows_seen: usize = 0;232233loop {234let opt_df;235236(get_batch_state, opt_df) =237GetBatchState::next(get_batch_state, execution_state.clone()).await?;238239let Some(df) = opt_df else {240break;241};242243n_rows_seen = n_rows_seen.saturating_add(df.height());244245if morsel_sender246.send_morsel(Morsel::new(df, MorselSeq::new(seq), source_token.clone()))247.await248.is_err()249{250break;251};252seq = seq.saturating_add(1);253}254255if let Some(mut row_position_on_end_tx) = row_position_on_end_tx {256let n_rows_seen = IdxSize::try_from(n_rows_seen)257.map_err(|_| polars_err!(bigidx, ctx = "batch reader", size = n_rows_seen))?;258259_ = row_position_on_end_tx.try_send(n_rows_seen)260}261262if let Some(mut n_rows_in_file_tx) = n_rows_in_file_tx {263if verbose {264eprintln!("[BatchFnReader]: read to end for full row count");265}266267loop {268let opt_df;269270(get_batch_state, opt_df) =271GetBatchState::next(get_batch_state, execution_state.clone()).await?;272273let Some(df) = opt_df else {274break;275};276277n_rows_seen = n_rows_seen.saturating_add(df.height());278}279280let n_rows_seen = IdxSize::try_from(n_rows_seen)281.map_err(|_| polars_err!(bigidx, ctx = "batch reader", size = n_rows_seen))?;282283_ = n_rows_in_file_tx.try_send(n_rows_seen)284}285286Ok(())287});288289Ok((morsel_rx, handle))290}291}292293impl BatchFnReader {294/// # Panics295/// Panics if `self.execution_state` is `None`.296fn execution_state(&self) -> &StreamingExecutionState {297self.execution_state.as_ref().unwrap()298}299}300301302