Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/reduce.rs
6939 views
1
use std::sync::Arc;
2
3
use polars_core::frame::column::ScalarColumn;
4
use polars_core::prelude::Column;
5
use polars_core::schema::{Schema, SchemaExt};
6
use polars_expr::reduce::GroupedReduction;
7
use polars_utils::itertools::Itertools;
8
9
use super::compute_node_prelude::*;
10
use crate::expression::StreamExpr;
11
use crate::morsel::SourceToken;
12
13
enum ReduceState {
14
Sink {
15
selectors: Vec<StreamExpr>,
16
reductions: Vec<Box<dyn GroupedReduction>>,
17
},
18
Source(Option<DataFrame>),
19
Done,
20
}
21
22
pub struct ReduceNode {
23
state: ReduceState,
24
output_schema: Arc<Schema>,
25
}
26
27
impl ReduceNode {
28
pub fn new(
29
selectors: Vec<StreamExpr>,
30
reductions: Vec<Box<dyn GroupedReduction>>,
31
output_schema: Arc<Schema>,
32
) -> Self {
33
Self {
34
state: ReduceState::Sink {
35
selectors,
36
reductions,
37
},
38
output_schema,
39
}
40
}
41
42
fn spawn_sink<'env, 's>(
43
selectors: &'env [StreamExpr],
44
reductions: &'env mut [Box<dyn GroupedReduction>],
45
scope: &'s TaskScope<'s, 'env>,
46
recv: RecvPort<'_>,
47
state: &'s StreamingExecutionState,
48
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
49
) {
50
let parallel_tasks: Vec<_> = recv
51
.parallel()
52
.into_iter()
53
.map(|mut recv| {
54
let mut local_reducers: Vec<_> = reductions
55
.iter()
56
.map(|d| {
57
let mut r = d.new_empty();
58
r.resize(1);
59
r
60
})
61
.collect();
62
63
scope.spawn_task(TaskPriority::High, async move {
64
while let Ok(morsel) = recv.recv().await {
65
for (reducer, selector) in local_reducers.iter_mut().zip(selectors) {
66
let input = selector
67
.evaluate(morsel.df(), &state.in_memory_exec_state)
68
.await?;
69
reducer.update_group(&input, 0, morsel.seq().to_u64())?;
70
}
71
}
72
73
PolarsResult::Ok(local_reducers)
74
})
75
})
76
.collect();
77
78
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
79
for task in parallel_tasks {
80
let local_reducers = task.await?;
81
for (r1, r2) in reductions.iter_mut().zip(local_reducers) {
82
r1.resize(1);
83
unsafe {
84
r1.combine_subset(&*r2, &[0], &[0])?;
85
}
86
}
87
}
88
89
Ok(())
90
}));
91
}
92
93
fn spawn_source<'env, 's>(
94
df: &'env mut Option<DataFrame>,
95
scope: &'s TaskScope<'s, 'env>,
96
send: SendPort<'_>,
97
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
98
) {
99
let mut send = send.serial();
100
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
101
let morsel = Morsel::new(df.take().unwrap(), MorselSeq::new(0), SourceToken::new());
102
let _ = send.send(morsel).await;
103
Ok(())
104
}));
105
}
106
}
107
108
impl ComputeNode for ReduceNode {
109
fn name(&self) -> &str {
110
"reduce"
111
}
112
113
fn update_state(
114
&mut self,
115
recv: &mut [PortState],
116
send: &mut [PortState],
117
_state: &StreamingExecutionState,
118
) -> PolarsResult<()> {
119
assert!(recv.len() == 1 && send.len() == 1);
120
121
// State transitions.
122
match &mut self.state {
123
// If the output doesn't want any more data, transition to being done.
124
_ if send[0] == PortState::Done => {
125
self.state = ReduceState::Done;
126
},
127
// Input is done, transition to being a source.
128
ReduceState::Sink { reductions, .. } if matches!(recv[0], PortState::Done) => {
129
let columns = reductions
130
.iter_mut()
131
.zip(self.output_schema.iter_fields())
132
.map(|(r, field)| {
133
r.resize(1);
134
r.finalize().map(|s| {
135
let s = s.with_name(field.name.clone()).cast(&field.dtype).unwrap();
136
Column::Scalar(ScalarColumn::unit_scalar_from_series(s))
137
})
138
})
139
.try_collect_vec()?;
140
let out = DataFrame::new(columns).unwrap();
141
142
self.state = ReduceState::Source(Some(out));
143
},
144
// We have sent the reduced dataframe, we are done.
145
ReduceState::Source(df) if df.is_none() => {
146
self.state = ReduceState::Done;
147
},
148
// Nothing to change.
149
ReduceState::Done | ReduceState::Sink { .. } | ReduceState::Source(_) => {},
150
}
151
152
// Communicate our state.
153
match &self.state {
154
ReduceState::Sink { .. } => {
155
send[0] = PortState::Blocked;
156
recv[0] = PortState::Ready;
157
},
158
ReduceState::Source(..) => {
159
recv[0] = PortState::Done;
160
send[0] = PortState::Ready;
161
},
162
ReduceState::Done => {
163
recv[0] = PortState::Done;
164
send[0] = PortState::Done;
165
},
166
}
167
Ok(())
168
}
169
170
fn spawn<'env, 's>(
171
&'env mut self,
172
scope: &'s TaskScope<'s, 'env>,
173
recv_ports: &mut [Option<RecvPort<'_>>],
174
send_ports: &mut [Option<SendPort<'_>>],
175
state: &'s StreamingExecutionState,
176
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
177
) {
178
assert!(send_ports.len() == 1 && recv_ports.len() == 1);
179
match &mut self.state {
180
ReduceState::Sink {
181
selectors,
182
reductions,
183
} => {
184
assert!(send_ports[0].is_none());
185
let recv_port = recv_ports[0].take().unwrap();
186
Self::spawn_sink(selectors, reductions, scope, recv_port, state, join_handles)
187
},
188
ReduceState::Source(df) => {
189
assert!(recv_ports[0].is_none());
190
let send_port = send_ports[0].take().unwrap();
191
Self::spawn_source(df, scope, send_port, join_handles)
192
},
193
ReduceState::Done => unreachable!(),
194
}
195
}
196
}
197
198