Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/io_sources/batch.rs
6939 views
1
//! Reads batches from a `dyn Fn`
2
3
use async_trait::async_trait;
4
use polars_core::frame::DataFrame;
5
use polars_core::schema::SchemaRef;
6
use polars_error::{PolarsResult, polars_err};
7
use polars_utils::IdxSize;
8
use polars_utils::pl_str::PlSmallStr;
9
10
use crate::async_executor::{JoinHandle, TaskPriority, spawn};
11
use crate::execute::StreamingExecutionState;
12
use crate::morsel::{Morsel, MorselSeq, SourceToken};
13
use crate::nodes::io_sources::multi_scan::reader_interface::output::{
14
FileReaderOutputRecv, FileReaderOutputSend,
15
};
16
use crate::nodes::io_sources::multi_scan::reader_interface::{
17
BeginReadArgs, FileReader, FileReaderCallbacks,
18
};
19
20
pub mod builder {
21
use std::sync::{Arc, Mutex};
22
23
use polars_utils::pl_str::PlSmallStr;
24
25
use super::BatchFnReader;
26
use crate::execute::StreamingExecutionState;
27
use crate::nodes::io_sources::multi_scan::reader_interface::FileReader;
28
use crate::nodes::io_sources::multi_scan::reader_interface::builder::FileReaderBuilder;
29
use crate::nodes::io_sources::multi_scan::reader_interface::capabilities::ReaderCapabilities;
30
31
pub struct BatchFnReaderBuilder {
32
pub name: PlSmallStr,
33
pub reader: Mutex<Option<BatchFnReader>>,
34
pub execution_state: Mutex<Option<StreamingExecutionState>>,
35
}
36
37
impl FileReaderBuilder for BatchFnReaderBuilder {
38
fn reader_name(&self) -> &str {
39
&self.name
40
}
41
42
fn reader_capabilities(&self) -> ReaderCapabilities {
43
ReaderCapabilities::empty()
44
}
45
46
fn set_execution_state(&self, execution_state: &StreamingExecutionState) {
47
*self.execution_state.lock().unwrap() = Some(execution_state.clone());
48
}
49
50
fn build_file_reader(
51
&self,
52
_source: polars_plan::prelude::ScanSource,
53
_cloud_options: Option<Arc<polars_io::cloud::CloudOptions>>,
54
scan_source_idx: usize,
55
) -> Box<dyn FileReader> {
56
assert_eq!(scan_source_idx, 0);
57
58
let mut reader = self
59
.reader
60
.try_lock()
61
.unwrap()
62
.take()
63
.expect("BatchFnReaderBuilder called more than once");
64
65
reader.execution_state = Some(self.execution_state.lock().unwrap().clone().unwrap());
66
67
Box::new(reader) as Box<dyn FileReader>
68
}
69
}
70
71
impl std::fmt::Debug for BatchFnReaderBuilder {
72
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73
f.write_str("BatchFnReaderBuilder: name: ")?;
74
f.write_str(&self.name)?;
75
76
Ok(())
77
}
78
}
79
}
80
81
pub type GetBatchFn =
82
Box<dyn Fn(&StreamingExecutionState) -> PolarsResult<Option<DataFrame>> + Send + Sync>;
83
84
pub use get_batch_state::GetBatchState;
85
86
mod get_batch_state {
87
use polars_io::pl_async::get_runtime;
88
89
use super::{DataFrame, GetBatchFn, PolarsResult, StreamingExecutionState};
90
91
/// Wraps `GetBatchFn` to support peeking.
92
pub struct GetBatchState {
93
func: GetBatchFn,
94
peek: Option<DataFrame>,
95
}
96
97
impl GetBatchState {
98
pub async fn next(
99
mut slf: Self,
100
execution_state: StreamingExecutionState,
101
) -> PolarsResult<(Self, Option<DataFrame>)> {
102
get_runtime()
103
.spawn_blocking({
104
move || unsafe { slf.next_impl(&execution_state).map(|x| (slf, x)) }
105
})
106
.await
107
.unwrap()
108
}
109
110
pub async fn peek(
111
mut slf: Self,
112
execution_state: StreamingExecutionState,
113
) -> PolarsResult<(Self, Option<DataFrame>)> {
114
get_runtime()
115
.spawn_blocking({
116
move || unsafe { slf.peek_impl(&execution_state).map(|x| (slf, x)) }
117
})
118
.await
119
.unwrap()
120
}
121
122
/// # Safety
123
/// This may deadlock if the caller is an async executor thread, as the `GetBatchFn` may
124
/// be a Python function that re-enters the streaming engine before returning.
125
pub unsafe fn peek_impl(
126
&mut self,
127
state: &StreamingExecutionState,
128
) -> PolarsResult<Option<DataFrame>> {
129
if self.peek.is_none() {
130
self.peek = (self.func)(state)?;
131
}
132
133
Ok(self.peek.clone())
134
}
135
136
/// # Safety
137
/// This may deadlock if the caller is an async executor thread, as the `GetBatchFn` may
138
/// be a Python function that re-enters the streaming engine before returning.
139
unsafe fn next_impl(
140
&mut self,
141
state: &StreamingExecutionState,
142
) -> PolarsResult<Option<DataFrame>> {
143
if let Some(df) = self.peek.take() {
144
Ok(Some(df))
145
} else {
146
(self.func)(state)
147
}
148
}
149
}
150
151
impl From<GetBatchFn> for GetBatchState {
152
fn from(func: GetBatchFn) -> Self {
153
Self { func, peek: None }
154
}
155
}
156
}
157
158
pub struct BatchFnReader {
159
pub name: PlSmallStr,
160
pub output_schema: Option<SchemaRef>,
161
pub get_batch_state: Option<GetBatchState>,
162
pub execution_state: Option<StreamingExecutionState>,
163
pub verbose: bool,
164
}
165
166
#[async_trait]
167
impl FileReader for BatchFnReader {
168
async fn initialize(&mut self) -> PolarsResult<()> {
169
Ok(())
170
}
171
172
fn begin_read(
173
&mut self,
174
args: BeginReadArgs,
175
) -> PolarsResult<(FileReaderOutputRecv, JoinHandle<PolarsResult<()>>)> {
176
let BeginReadArgs {
177
projection: _,
178
row_index: None,
179
pre_slice: None,
180
predicate: None,
181
cast_columns_policy: _,
182
num_pipelines: _,
183
callbacks:
184
FileReaderCallbacks {
185
mut file_schema_tx,
186
n_rows_in_file_tx,
187
row_position_on_end_tx,
188
},
189
} = args
190
else {
191
panic!("unsupported args: {:?}", &args)
192
};
193
194
let execution_state = self.execution_state().clone();
195
196
if file_schema_tx.is_some() && self.output_schema.is_some() {
197
_ = file_schema_tx
198
.take()
199
.unwrap()
200
.try_send(self.output_schema.clone().unwrap());
201
}
202
203
let mut get_batch_state = self
204
.get_batch_state
205
.take()
206
// If this is ever needed we can buffer
207
.expect("unimplemented: BatchFnReader called more than once");
208
209
let verbose = self.verbose;
210
211
if verbose {
212
eprintln!("[BatchFnReader]: name: {}", self.name);
213
}
214
215
let (mut morsel_sender, morsel_rx) = FileReaderOutputSend::new_serial();
216
217
let handle = spawn(TaskPriority::Low, async move {
218
if let Some(mut file_schema_tx) = file_schema_tx {
219
let opt_df;
220
221
(get_batch_state, opt_df) =
222
GetBatchState::peek(get_batch_state, execution_state.clone()).await?;
223
224
_ = file_schema_tx
225
.try_send(opt_df.map(|df| df.schema().clone()).unwrap_or_default())
226
}
227
228
let mut seq: u64 = 0;
229
// Note: We don't use this (it is handled by the bridge). But morsels require a source token.
230
let source_token = SourceToken::new();
231
232
let mut n_rows_seen: usize = 0;
233
234
loop {
235
let opt_df;
236
237
(get_batch_state, opt_df) =
238
GetBatchState::next(get_batch_state, execution_state.clone()).await?;
239
240
let Some(df) = opt_df else {
241
break;
242
};
243
244
n_rows_seen = n_rows_seen.saturating_add(df.height());
245
246
if morsel_sender
247
.send_morsel(Morsel::new(df, MorselSeq::new(seq), source_token.clone()))
248
.await
249
.is_err()
250
{
251
break;
252
};
253
seq = seq.saturating_add(1);
254
}
255
256
if let Some(mut row_position_on_end_tx) = row_position_on_end_tx {
257
let n_rows_seen = IdxSize::try_from(n_rows_seen)
258
.map_err(|_| polars_err!(bigidx, ctx = "batch reader", size = n_rows_seen))?;
259
260
_ = row_position_on_end_tx.try_send(n_rows_seen)
261
}
262
263
if let Some(mut n_rows_in_file_tx) = n_rows_in_file_tx {
264
if verbose {
265
eprintln!("[BatchFnReader]: read to end for full row count");
266
}
267
268
loop {
269
let opt_df;
270
271
(get_batch_state, opt_df) =
272
GetBatchState::next(get_batch_state, execution_state.clone()).await?;
273
274
let Some(df) = opt_df else {
275
break;
276
};
277
278
n_rows_seen = n_rows_seen.saturating_add(df.height());
279
}
280
281
let n_rows_seen = IdxSize::try_from(n_rows_seen)
282
.map_err(|_| polars_err!(bigidx, ctx = "batch reader", size = n_rows_seen))?;
283
284
_ = n_rows_in_file_tx.try_send(n_rows_seen)
285
}
286
287
Ok(())
288
});
289
290
Ok((morsel_rx, handle))
291
}
292
}
293
294
impl BatchFnReader {
295
/// # Panics
296
/// Panics if `self.execution_state` is `None`.
297
fn execution_state(&self) -> &StreamingExecutionState {
298
self.execution_state.as_ref().unwrap()
299
}
300
}
301
302