Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/reduce.rs
8430 views
1
use std::sync::Arc;
2
3
use polars_core::frame::column::ScalarColumn;
4
use polars_core::prelude::Column;
5
use polars_core::schema::{Schema, SchemaExt};
6
use polars_expr::reduce::GroupedReduction;
7
use polars_utils::itertools::Itertools;
8
9
use super::compute_node_prelude::*;
10
use crate::expression::StreamExpr;
11
use crate::morsel::SourceToken;
12
13
enum ReduceState {
14
Sink {
15
selectors: Vec<Vec<StreamExpr>>,
16
reductions: Vec<Box<dyn GroupedReduction>>,
17
},
18
Source(Option<DataFrame>),
19
Done,
20
}
21
22
pub struct ReduceNode {
23
state: ReduceState,
24
output_schema: Arc<Schema>,
25
}
26
27
impl ReduceNode {
28
pub fn new(
29
selectors: Vec<Vec<StreamExpr>>,
30
reductions: Vec<Box<dyn GroupedReduction>>,
31
output_schema: Arc<Schema>,
32
) -> Self {
33
Self {
34
state: ReduceState::Sink {
35
selectors,
36
reductions,
37
},
38
output_schema,
39
}
40
}
41
42
fn spawn_sink<'env, 's>(
43
selectors: &'env [Vec<StreamExpr>],
44
reductions: &'env mut [Box<dyn GroupedReduction>],
45
scope: &'s TaskScope<'s, 'env>,
46
recv: RecvPort<'_>,
47
state: &'s StreamingExecutionState,
48
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
49
) {
50
let parallel_tasks: Vec<_> = recv
51
.parallel()
52
.into_iter()
53
.map(|mut recv| {
54
let mut local_reducers: Vec<_> = reductions
55
.iter()
56
.map(|d| {
57
let mut r = d.new_empty();
58
r.resize(1);
59
r
60
})
61
.collect();
62
63
scope.spawn_task(TaskPriority::High, async move {
64
let mut in_columns = Vec::new();
65
let mut in_column_refs = Vec::new();
66
while let Ok(morsel) = recv.recv().await {
67
for (reducer, selector_set) in local_reducers.iter_mut().zip(selectors) {
68
for selector in selector_set {
69
let col = selector
70
.evaluate(morsel.df(), &state.in_memory_exec_state)
71
.await?;
72
in_columns.push(col);
73
}
74
for c in in_columns.iter() {
75
in_column_refs.push(c);
76
}
77
reducer.update_group(&in_column_refs, 0, morsel.seq().to_u64())?;
78
in_column_refs.clear();
79
in_column_refs =
80
in_column_refs.into_iter().map(|_| unreachable!()).collect(); // Clear lifetimes.
81
in_columns.clear();
82
}
83
}
84
85
PolarsResult::Ok(local_reducers)
86
})
87
})
88
.collect();
89
90
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
91
for task in parallel_tasks {
92
let local_reducers = task.await?;
93
for (r1, r2) in reductions.iter_mut().zip(local_reducers) {
94
r1.resize(1);
95
unsafe {
96
r1.combine_subset(&*r2, &[0], &[0])?;
97
}
98
}
99
}
100
101
Ok(())
102
}));
103
}
104
105
fn spawn_source<'env, 's>(
106
df: &'env mut Option<DataFrame>,
107
scope: &'s TaskScope<'s, 'env>,
108
send: SendPort<'_>,
109
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
110
) {
111
let mut send = send.serial();
112
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
113
let morsel = Morsel::new(df.take().unwrap(), MorselSeq::new(0), SourceToken::new());
114
let _ = send.send(morsel).await;
115
Ok(())
116
}));
117
}
118
}
119
120
impl ComputeNode for ReduceNode {
121
fn name(&self) -> &str {
122
"reduce"
123
}
124
125
fn update_state(
126
&mut self,
127
recv: &mut [PortState],
128
send: &mut [PortState],
129
_state: &StreamingExecutionState,
130
) -> PolarsResult<()> {
131
assert!(recv.len() == 1 && send.len() == 1);
132
133
// State transitions.
134
match &mut self.state {
135
// If the output doesn't want any more data, transition to being done.
136
_ if send[0] == PortState::Done => {
137
self.state = ReduceState::Done;
138
},
139
// Input is done, transition to being a source.
140
ReduceState::Sink { reductions, .. } if matches!(recv[0], PortState::Done) => {
141
let columns = reductions
142
.iter_mut()
143
.zip(self.output_schema.iter_fields())
144
.map(|(r, field)| {
145
r.resize(1);
146
r.finalize().map(|s| {
147
let s = s.with_name(field.name.clone());
148
Column::Scalar(ScalarColumn::unit_scalar_from_series(s))
149
})
150
})
151
.try_collect_vec()?;
152
let out = unsafe { DataFrame::new_unchecked(1, columns) };
153
154
self.state = ReduceState::Source(Some(out));
155
},
156
// We have sent the reduced dataframe, we are done.
157
ReduceState::Source(df) if df.is_none() => {
158
self.state = ReduceState::Done;
159
},
160
// Nothing to change.
161
ReduceState::Done | ReduceState::Sink { .. } | ReduceState::Source(_) => {},
162
}
163
164
// Communicate our state.
165
match &self.state {
166
ReduceState::Sink { .. } => {
167
send[0] = PortState::Blocked;
168
recv[0] = PortState::Ready;
169
},
170
ReduceState::Source(..) => {
171
recv[0] = PortState::Done;
172
send[0] = PortState::Ready;
173
},
174
ReduceState::Done => {
175
recv[0] = PortState::Done;
176
send[0] = PortState::Done;
177
},
178
}
179
Ok(())
180
}
181
182
fn spawn<'env, 's>(
183
&'env mut self,
184
scope: &'s TaskScope<'s, 'env>,
185
recv_ports: &mut [Option<RecvPort<'_>>],
186
send_ports: &mut [Option<SendPort<'_>>],
187
state: &'s StreamingExecutionState,
188
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
189
) {
190
assert!(send_ports.len() == 1 && recv_ports.len() == 1);
191
match &mut self.state {
192
ReduceState::Sink {
193
selectors,
194
reductions,
195
} => {
196
assert!(send_ports[0].is_none());
197
let recv_port = recv_ports[0].take().unwrap();
198
Self::spawn_sink(selectors, reductions, scope, recv_port, state, join_handles)
199
},
200
ReduceState::Source(df) => {
201
assert!(recv_ports[0].is_none());
202
let send_port = send_ports[0].take().unwrap();
203
Self::spawn_source(df, scope, send_port, join_handles)
204
},
205
ReduceState::Done => unreachable!(),
206
}
207
}
208
}
209
210