Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/peak_minmax.rs
6939 views
1
use polars_core::frame::DataFrame;
2
use polars_core::prelude::{AnyValue, Column, IntoColumn};
3
use polars_error::PolarsResult;
4
use polars_ops::prelude::peaks;
5
6
use super::ComputeNode;
7
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
8
use crate::async_primitives::wait_group::WaitGroup;
9
use crate::execute::StreamingExecutionState;
10
use crate::graph::PortState;
11
use crate::morsel::{Morsel, MorselSeq, SourceToken};
12
use crate::pipe::{RecvPort, SendPort};
13
14
enum State {
15
/// No morsels seen yet.
16
Start,
17
/// We have seen one morsel. Wait until 1 more to start streaming out data.
18
One(MorselSeq, Column),
19
/// We have seen two morsels. We have saved the last value of 2 morsels ago and the last
20
/// morsel.
21
Two(AnyValue<'static>, MorselSeq, Column),
22
/// No more morsels will be received.
23
Done,
24
}
25
26
pub struct PeakMinMaxNode {
27
state: State,
28
29
/// Is the node the `peak_max`?
30
is_peak_max: bool,
31
}
32
33
impl PeakMinMaxNode {
34
pub fn new(is_peak_max: bool) -> Self {
35
Self {
36
state: State::Start,
37
is_peak_max,
38
}
39
}
40
}
41
42
impl ComputeNode for PeakMinMaxNode {
43
fn name(&self) -> &str {
44
if self.is_peak_max {
45
"peaks_max"
46
} else {
47
"peaks_min"
48
}
49
}
50
51
fn update_state(
52
&mut self,
53
recv: &mut [PortState],
54
send: &mut [PortState],
55
_state: &StreamingExecutionState,
56
) -> PolarsResult<()> {
57
assert!(recv.len() == 1 && send.len() == 1);
58
59
if matches!(self.state, State::Done) {
60
send[0] = PortState::Done;
61
recv[0] = PortState::Done;
62
} else if send[0] == PortState::Done {
63
recv[0] = PortState::Done;
64
self.state = State::Done;
65
} else if recv[0] == PortState::Done {
66
if matches!(self.state, State::Start) {
67
send[0] = PortState::Done;
68
} else {
69
send[0] = PortState::Ready;
70
}
71
} else {
72
recv.swap_with_slice(send);
73
}
74
75
Ok(())
76
}
77
78
fn spawn<'env, 's>(
79
&'env mut self,
80
scope: &'s TaskScope<'s, 'env>,
81
recv_ports: &mut [Option<RecvPort<'_>>],
82
send_ports: &mut [Option<SendPort<'_>>],
83
_state: &'s StreamingExecutionState,
84
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
85
) {
86
assert_eq!(recv_ports.len(), 1);
87
assert_eq!(send_ports.len(), 1);
88
89
let recv = recv_ports[0].take();
90
let mut send = send_ports[0].take().unwrap().serial();
91
92
match recv {
93
// No more morsels to receive. Flush out the remaining data.
94
None => {
95
if matches!(self.state, State::Start) {
96
return;
97
}
98
99
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
100
let (start, seq, prev_column) = match &self.state {
101
State::Start => unreachable!(),
102
State::One(seq, df) => (&AnyValue::Int8(0), *seq, df),
103
State::Two(av, seq, df) => (av, *seq, df),
104
State::Done => unreachable!(),
105
};
106
107
let column = peaks::peak_min_max(
108
prev_column,
109
start,
110
&AnyValue::Int8(0),
111
self.is_peak_max,
112
)?
113
.into_column();
114
let df = DataFrame::new(vec![column]).unwrap();
115
_ = send.send(Morsel::new(df, seq, SourceToken::new())).await;
116
117
self.state = State::Done;
118
Ok(())
119
}));
120
},
121
122
Some(recv) => {
123
let mut recv = recv.serial();
124
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
125
let source_token = SourceToken::new();
126
127
while let Ok(m) = recv.recv().await {
128
let (df, seq, in_source_token, in_wait_token) = m.into_inner();
129
drop(in_wait_token);
130
if df.height() == 0 {
131
continue;
132
}
133
134
assert_eq!(df.width(), 1);
135
let column = &df[0];
136
137
let (start, prev_seq, prev_column) = match &self.state {
138
State::Start => {
139
self.state = State::One(seq, column.clone());
140
continue;
141
},
142
State::One(prev_seq, prev_column) => {
143
(&AnyValue::Int8(0), *prev_seq, prev_column)
144
},
145
State::Two(prev_start, prev_seq, prev_column) => {
146
(prev_start, *prev_seq, prev_column)
147
},
148
State::Done => unreachable!(),
149
};
150
let end = &column.get(0).unwrap();
151
let out = peaks::peak_min_max(prev_column, start, end, self.is_peak_max)?
152
.into_column();
153
154
let wg = WaitGroup::default();
155
let mut m = Morsel::new(
156
DataFrame::new(vec![out]).unwrap(),
157
prev_seq,
158
source_token.clone(),
159
);
160
m.set_consume_token(wg.token());
161
162
if send.send(m).await.is_err() {
163
self.state = State::Done;
164
break;
165
}
166
167
wg.wait().await;
168
if source_token.stop_requested() {
169
in_source_token.stop();
170
}
171
172
let prev_end = prev_column
173
.get(prev_column.len() - 1)
174
.unwrap()
175
.to_physical()
176
.into_static();
177
self.state = State::Two(prev_end, seq, column.clone());
178
}
179
Ok(())
180
}));
181
},
182
}
183
}
184
}
185
186