Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/callback_sink.rs
7884 views
1
use std::num::NonZeroUsize;
2
3
use polars_core::frame::DataFrame;
4
use polars_error::PolarsResult;
5
use polars_plan::prelude::PlanCallback;
6
7
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
8
use crate::execute::StreamingExecutionState;
9
use crate::graph::PortState;
10
use crate::nodes::ComputeNode;
11
use crate::pipe::{RecvPort, SendPort};
12
13
pub struct CallbackSinkNode {
14
function: PlanCallback<DataFrame, bool>,
15
maintain_order: bool,
16
17
buffer: DataFrame,
18
chunk_size: Option<NonZeroUsize>,
19
is_done: bool,
20
}
21
22
impl CallbackSinkNode {
23
pub fn new(
24
function: PlanCallback<DataFrame, bool>,
25
maintain_order: bool,
26
chunk_size: Option<NonZeroUsize>,
27
) -> Self {
28
Self {
29
function,
30
maintain_order,
31
32
buffer: DataFrame::empty(),
33
chunk_size,
34
is_done: false,
35
}
36
}
37
}
38
39
impl ComputeNode for CallbackSinkNode {
40
fn name(&self) -> &str {
41
"sink_batches"
42
}
43
44
fn update_state(
45
&mut self,
46
recv: &mut [PortState],
47
send: &mut [PortState],
48
state: &StreamingExecutionState,
49
) -> PolarsResult<()> {
50
assert!(recv.len() == 1 && send.is_empty());
51
52
if self.is_done || recv[0] == PortState::Done {
53
recv[0] = PortState::Done;
54
55
// Flush the last buffer
56
if !self.buffer.is_empty() && !self.is_done {
57
let function = self.function.clone();
58
let df = std::mem::take(&mut self.buffer);
59
60
assert!(
61
self.chunk_size
62
.is_some_and(|chunk_size| self.buffer.height() <= chunk_size.into())
63
);
64
state.spawn_subphase_task(async move {
65
polars_io::pl_async::get_runtime()
66
.spawn_blocking(move || function.call(df))
67
.await
68
.unwrap()?;
69
Ok(())
70
});
71
return Ok(());
72
}
73
} else {
74
recv[0] = PortState::Ready;
75
}
76
77
Ok(())
78
}
79
80
fn spawn<'env, 's>(
81
&'env mut self,
82
scope: &'s TaskScope<'s, 'env>,
83
recv_ports: &mut [Option<RecvPort<'_>>],
84
send_ports: &mut [Option<SendPort<'_>>],
85
_state: &'s StreamingExecutionState,
86
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
87
) {
88
assert!(recv_ports.len() == 1 && send_ports.is_empty());
89
let mut recv = recv_ports[0]
90
.take()
91
.unwrap()
92
.serial_with_maintain_order(self.maintain_order);
93
94
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
95
while !self.is_done
96
&& let Ok(m) = recv.recv().await
97
{
98
let (df, _, _, consume_token) = m.into_inner();
99
100
// @NOTE: This also performs schema validation.
101
self.buffer.vstack_mut(&df)?;
102
103
while !self.buffer.is_empty()
104
&& self
105
.chunk_size
106
.is_none_or(|chunk_size| self.buffer.height() >= chunk_size.into())
107
{
108
let chunk_size = self.chunk_size.map_or(usize::MAX, Into::into);
109
110
let df;
111
(df, self.buffer) = self
112
.buffer
113
.split_at(self.buffer.height().min(chunk_size) as i64);
114
115
let function = self.function.clone();
116
let should_stop = polars_io::pl_async::get_runtime()
117
.spawn_blocking(move || function.call(df))
118
.await
119
.unwrap()?;
120
121
if should_stop {
122
self.is_done = true;
123
break;
124
}
125
}
126
drop(consume_token);
127
// Increase the backpressure. Only free up a pipeline when the morsel has been
128
// processed in its entirety.
129
}
130
131
Ok(())
132
}));
133
}
134
}
135
136