Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/shift.rs
8420 views
1
use std::collections::VecDeque;
2
use std::sync::Arc;
3
4
use polars_core::prelude::*;
5
use polars_core::schema::Schema;
6
7
use super::compute_node_prelude::*;
8
use crate::async_primitives::wait_group::WaitGroup;
9
use crate::morsel::{SourceToken, get_ideal_morsel_size};
10
use crate::nodes::in_memory_sink::InMemorySinkNode;
11
use crate::pipe::PortReceiver;
12
13
#[allow(private_interfaces)]
14
pub enum ShiftNode {
15
GatheringParams {
16
offset: InMemorySinkNode,
17
fill: Option<InMemorySinkNode>,
18
output_schema: Arc<Schema>,
19
},
20
Shifting(ShiftState),
21
Done,
22
}
23
24
struct ShiftState {
25
offset: i64,
26
rows_received: usize,
27
rows_sent: usize,
28
buffer: VecDeque<DataFrame>,
29
fill: DataFrame,
30
seq: MorselSeq,
31
}
32
33
impl ShiftState {
34
async fn shift_positive(
35
&mut self,
36
mut recv: Option<PortReceiver>,
37
mut send: PortSender,
38
) -> PolarsResult<()> {
39
let mut source_token = SourceToken::new();
40
let wait_group = WaitGroup::default();
41
42
while recv.is_some() || self.rows_received != self.rows_sent {
43
// Try to get more data if necessary.
44
if self.rows_received == self.rows_sent {
45
if let Some(r) = &mut recv {
46
let Ok(morsel) = r.recv().await else { break };
47
source_token = morsel.source_token().clone();
48
if morsel.df().height() == 0 {
49
continue;
50
}
51
self.rows_received += morsel.df().height();
52
self.buffer.push_back(morsel.into_df());
53
}
54
}
55
56
// Send along a morsel.
57
let df;
58
if self.rows_sent < self.offset as usize {
59
let len = self.rows_received.min(self.offset as usize) - self.rows_sent;
60
df = self.fill.new_from_index(0, len);
61
} else {
62
let src = self.buffer.front_mut().unwrap();
63
let len = self.rows_received - self.rows_sent;
64
(df, *src) = src.split_at(len as i64);
65
if src.height() == 0 {
66
self.buffer.pop_front();
67
}
68
};
69
self.rows_sent += df.height();
70
71
let mut morsel = Morsel::new(df, self.seq, source_token.clone());
72
self.seq = self.seq.successor();
73
morsel.set_consume_token(wait_group.token());
74
if send.send(morsel).await.is_err() {
75
break;
76
}
77
wait_group.wait().await;
78
}
79
80
Ok(())
81
}
82
83
async fn shift_negative(
84
&mut self,
85
mut recv: PortReceiver,
86
mut send: PortSender,
87
) -> PolarsResult<()> {
88
let shift = self.offset.unsigned_abs() as usize;
89
90
while let Ok(mut morsel) = recv.recv().await {
91
let shift_needed = shift.saturating_sub(self.rows_received);
92
self.rows_received += morsel.df().height();
93
if shift_needed > 0 {
94
morsel =
95
morsel.map(|df| df.slice(shift_needed.min(df.height()) as i64, df.height()));
96
}
97
if morsel.df().height() == 0 {
98
continue;
99
}
100
101
morsel.set_seq(self.seq);
102
self.seq = self.seq.successor();
103
self.rows_sent += morsel.df().height();
104
if send.send(morsel).await.is_err() {
105
break;
106
}
107
}
108
109
Ok(())
110
}
111
112
async fn flush_negative(
113
&mut self,
114
mut send: PortSender,
115
state: &StreamingExecutionState,
116
) -> PolarsResult<()> {
117
let source_token = SourceToken::new();
118
let wait_group = WaitGroup::default();
119
120
let total_len = self.rows_received - self.rows_sent;
121
let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1);
122
let morsel_count = ideal_morsel_count.next_multiple_of(state.num_pipelines);
123
let morsel_size = total_len.div_ceil(morsel_count).max(1);
124
125
while self.rows_sent != self.rows_received {
126
let len = morsel_size.min(self.rows_received - self.rows_sent);
127
let df = self.fill.new_from_index(0, len);
128
self.rows_sent += len;
129
130
let mut morsel = Morsel::new(df, self.seq, source_token.clone());
131
self.seq = self.seq.successor();
132
morsel.set_consume_token(wait_group.token());
133
if send.send(morsel).await.is_err() {
134
break;
135
}
136
wait_group.wait().await;
137
}
138
139
Ok(())
140
}
141
}
142
143
impl ShiftNode {
144
pub fn new(output_schema: Arc<Schema>, offset_schema: Arc<Schema>, with_fill: bool) -> Self {
145
assert!(offset_schema.len() == 1);
146
Self::GatheringParams {
147
offset: InMemorySinkNode::new(offset_schema),
148
fill: with_fill.then(|| InMemorySinkNode::new(output_schema.clone())),
149
output_schema,
150
}
151
}
152
}
153
154
impl ComputeNode for ShiftNode {
155
fn name(&self) -> &str {
156
"shift"
157
}
158
159
fn update_state(
160
&mut self,
161
recv: &mut [PortState],
162
send: &mut [PortState],
163
state: &StreamingExecutionState,
164
) -> PolarsResult<()> {
165
assert!(recv.len() <= 3 && send.len() == 1);
166
167
// Are we done?
168
if send[0] == PortState::Done {
169
*self = Self::Done;
170
} else if recv[0] == PortState::Done {
171
if let Self::Shifting(shift_state) = self {
172
if shift_state.rows_sent == shift_state.rows_received {
173
*self = Self::Done;
174
}
175
}
176
}
177
178
// Do we have the parameters to start shifting?
179
if recv[1..].iter().all(|p| *p == PortState::Done) {
180
if let Self::GatheringParams {
181
offset,
182
fill,
183
output_schema,
184
} = self
185
{
186
let offset_frame = offset.get_output()?.unwrap();
187
polars_ensure!(offset_frame.height() == 1, ComputeError: "got more than one value for 'n' in shift");
188
let offset_item = offset_frame.columns()[0].get(0)?;
189
let offset = if offset_item.is_null() {
190
polars_warn!(
191
Deprecation, // @2.0
192
"shift value 'n' is null, which currently returns a column of null values. This will become an error in the future.",
193
);
194
// @2.0: Currently we still require the entire output to become null
195
// if the shift is null, simulate this with an infinite negative shift.
196
*fill = None;
197
i64::MIN
198
} else {
199
offset_item.extract::<i64>().ok_or_else(
200
|| polars_err!(ComputeError: "invalid value of 'n' in shift: {:?}", offset_item),
201
)?
202
};
203
204
let fill_frame = if let Some(fill) = fill {
205
fill.get_output()?.unwrap()
206
} else {
207
DataFrame::empty_with_schema(output_schema)
208
};
209
210
*self = Self::Shifting(ShiftState {
211
offset,
212
rows_received: 0,
213
rows_sent: 0,
214
buffer: VecDeque::new(),
215
fill: fill_frame,
216
seq: MorselSeq::default(),
217
})
218
}
219
}
220
221
match self {
222
Self::GatheringParams { offset, fill, .. } => {
223
offset.update_state(&mut recv[1..2], &mut [], state)?;
224
if let Some(fill) = fill {
225
fill.update_state(&mut recv[2..3], &mut [], state)?;
226
}
227
recv[0] = PortState::Blocked;
228
send[0] = PortState::Blocked;
229
},
230
Self::Shifting(shift_state) => {
231
if recv[0] == PortState::Done && shift_state.rows_sent < shift_state.rows_received {
232
send[0] = PortState::Ready;
233
} else {
234
recv[..1].swap_with_slice(send);
235
}
236
recv[1..].fill(PortState::Done);
237
},
238
Self::Done => {
239
recv.fill(PortState::Done);
240
send[0] = PortState::Done;
241
},
242
}
243
Ok(())
244
}
245
246
fn spawn<'env, 's>(
247
&'env mut self,
248
scope: &'s TaskScope<'s, 'env>,
249
recv_ports: &mut [Option<RecvPort<'_>>],
250
send_ports: &mut [Option<SendPort<'_>>],
251
state: &'s StreamingExecutionState,
252
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
253
) {
254
assert!(recv_ports.len() <= 3 && send_ports.len() == 1);
255
match self {
256
Self::GatheringParams {
257
offset,
258
fill,
259
output_schema: _,
260
} => {
261
assert!(recv_ports[0].is_none());
262
assert!(send_ports[0].is_none());
263
if recv_ports[1].is_some() {
264
offset.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles);
265
}
266
if matches!(recv_ports.get(2), Some(Some(_))) {
267
fill.as_mut().unwrap().spawn(
268
scope,
269
&mut recv_ports[2..3],
270
&mut [],
271
state,
272
join_handles,
273
);
274
}
275
},
276
Self::Shifting(shift_state) => {
277
assert!(recv_ports[1..].iter().all(|p| p.is_none()));
278
let recv = recv_ports[0].take().map(|p| p.serial());
279
let send = send_ports[0].take().unwrap().serial();
280
join_handles.push(if shift_state.offset >= 0 {
281
scope.spawn_task(TaskPriority::High, shift_state.shift_positive(recv, send))
282
} else if let Some(r) = recv {
283
scope.spawn_task(TaskPriority::High, shift_state.shift_negative(r, send))
284
} else {
285
scope.spawn_task(TaskPriority::High, shift_state.flush_negative(send, state))
286
});
287
},
288
Self::Done => unreachable!(),
289
}
290
}
291
}
292
293