Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/shift.rs
6939 views
1
use std::collections::VecDeque;
2
use std::sync::Arc;
3
4
use polars_core::prelude::*;
5
use polars_core::schema::Schema;
6
7
use super::compute_node_prelude::*;
8
use crate::async_primitives::connector::{Receiver, Sender};
9
use crate::async_primitives::wait_group::WaitGroup;
10
use crate::morsel::{SourceToken, get_ideal_morsel_size};
11
use crate::nodes::in_memory_sink::InMemorySinkNode;
12
13
#[allow(private_interfaces)]
14
pub enum ShiftNode {
15
GatheringParams {
16
offset: InMemorySinkNode,
17
fill: Option<InMemorySinkNode>,
18
output_schema: Arc<Schema>,
19
},
20
Shifting(ShiftState),
21
Done,
22
}
23
24
struct ShiftState {
25
offset: i64,
26
rows_received: usize,
27
rows_sent: usize,
28
buffer: VecDeque<DataFrame>,
29
fill: DataFrame,
30
seq: MorselSeq,
31
}
32
33
impl ShiftState {
34
async fn shift_positive(
35
&mut self,
36
mut recv: Option<Receiver<Morsel>>,
37
mut send: Sender<Morsel>,
38
) -> PolarsResult<()> {
39
let mut source_token = SourceToken::new();
40
let wait_group = WaitGroup::default();
41
42
while recv.is_some() || self.rows_received != self.rows_sent {
43
// Try to get more data if necessary.
44
if self.rows_received == self.rows_sent {
45
if let Some(r) = &mut recv {
46
let Ok(morsel) = r.recv().await else { break };
47
source_token = morsel.source_token().clone();
48
if morsel.df().is_empty() {
49
continue;
50
}
51
self.rows_received += morsel.df().height();
52
self.buffer.push_back(morsel.into_df());
53
}
54
}
55
56
// Send along a morsel.
57
let df;
58
if self.rows_sent < self.offset as usize {
59
let len = self.rows_received.min(self.offset as usize) - self.rows_sent;
60
df = self.fill.new_from_index(0, len);
61
} else {
62
let src = self.buffer.front_mut().unwrap();
63
let len = self.rows_received - self.rows_sent;
64
(df, *src) = src.split_at(len as i64);
65
if src.is_empty() {
66
self.buffer.pop_front();
67
}
68
};
69
self.rows_sent += df.height();
70
71
let mut morsel = Morsel::new(df, self.seq, source_token.clone());
72
self.seq = self.seq.successor();
73
morsel.set_consume_token(wait_group.token());
74
if send.send(morsel).await.is_err() {
75
break;
76
}
77
wait_group.wait().await;
78
if source_token.stop_requested() {
79
break;
80
}
81
}
82
83
Ok(())
84
}
85
86
async fn shift_negative(
87
&mut self,
88
mut recv: Receiver<Morsel>,
89
mut send: Sender<Morsel>,
90
) -> PolarsResult<()> {
91
let shift = self.offset.unsigned_abs() as usize;
92
93
while let Ok(mut morsel) = recv.recv().await {
94
let shift_needed = shift.saturating_sub(self.rows_received);
95
self.rows_received += morsel.df().height();
96
if shift_needed > 0 {
97
morsel =
98
morsel.map(|df| df.slice(shift_needed.min(df.height()) as i64, df.height()));
99
}
100
if morsel.df().is_empty() {
101
continue;
102
}
103
104
morsel.set_seq(self.seq);
105
self.seq = self.seq.successor();
106
self.rows_sent += morsel.df().height();
107
if send.send(morsel).await.is_err() {
108
break;
109
}
110
}
111
112
Ok(())
113
}
114
115
async fn flush_negative(
116
&mut self,
117
mut send: Sender<Morsel>,
118
state: &StreamingExecutionState,
119
) -> PolarsResult<()> {
120
let source_token = SourceToken::new();
121
let wait_group = WaitGroup::default();
122
123
let total_len = self.rows_received - self.rows_sent;
124
let ideal_morsel_count = (total_len / get_ideal_morsel_size()).max(1);
125
let morsel_count = ideal_morsel_count.next_multiple_of(state.num_pipelines);
126
let morsel_size = total_len.div_ceil(morsel_count).max(1);
127
128
while self.rows_sent != self.rows_received {
129
let len = morsel_size.min(self.rows_received - self.rows_sent);
130
let df = self.fill.new_from_index(0, len);
131
self.rows_sent += len;
132
133
let mut morsel = Morsel::new(df, self.seq, source_token.clone());
134
self.seq = self.seq.successor();
135
morsel.set_consume_token(wait_group.token());
136
if send.send(morsel).await.is_err() {
137
break;
138
}
139
wait_group.wait().await;
140
if source_token.stop_requested() {
141
break;
142
}
143
}
144
145
Ok(())
146
}
147
}
148
149
impl ShiftNode {
150
pub fn new(output_schema: Arc<Schema>, offset_schema: Arc<Schema>, with_fill: bool) -> Self {
151
assert!(offset_schema.len() == 1);
152
Self::GatheringParams {
153
offset: InMemorySinkNode::new(offset_schema),
154
fill: with_fill.then(|| InMemorySinkNode::new(output_schema.clone())),
155
output_schema,
156
}
157
}
158
}
159
160
impl ComputeNode for ShiftNode {
161
fn name(&self) -> &str {
162
"shift"
163
}
164
165
fn update_state(
166
&mut self,
167
recv: &mut [PortState],
168
send: &mut [PortState],
169
state: &StreamingExecutionState,
170
) -> PolarsResult<()> {
171
assert!(recv.len() <= 3 && send.len() == 1);
172
173
// Are we done?
174
if recv[0] == PortState::Done {
175
if let Self::Shifting(shift_state) = self {
176
if shift_state.rows_sent == shift_state.rows_received {
177
*self = Self::Done;
178
}
179
}
180
}
181
182
// Do we have the parameters to start shifting?
183
if recv[1..].iter().all(|p| *p == PortState::Done) {
184
if let Self::GatheringParams {
185
offset,
186
fill,
187
output_schema,
188
} = self
189
{
190
let offset_frame = offset.get_output()?.unwrap();
191
polars_ensure!(offset_frame.height() == 1, ComputeError: "got more than one value for 'n' in shift");
192
let offset_item = offset_frame.get_columns()[0].get(0)?;
193
let offset = if offset_item.is_null() {
194
polars_warn!(
195
Deprecation, // @2.0
196
"shift value 'n' is null, which currently returns a column of null values. This will become an error in the future.",
197
);
198
// @2.0: Currently we still require the entire output to become null
199
// if the shift is null, simulate this with an infinite negative shift.
200
*fill = None;
201
i64::MIN
202
} else {
203
offset_item.extract::<i64>().ok_or_else(
204
|| polars_err!(ComputeError: "invalid value of 'n' in shift: {:?}", offset_item),
205
)?
206
};
207
208
let fill_frame = if let Some(fill) = fill {
209
fill.get_output()?.unwrap()
210
} else {
211
DataFrame::empty_with_schema(output_schema)
212
};
213
214
*self = Self::Shifting(ShiftState {
215
offset,
216
rows_received: 0,
217
rows_sent: 0,
218
buffer: VecDeque::new(),
219
fill: fill_frame,
220
seq: MorselSeq::default(),
221
})
222
}
223
}
224
225
match self {
226
Self::GatheringParams { offset, fill, .. } => {
227
offset.update_state(&mut recv[1..2], &mut [], state)?;
228
if let Some(fill) = fill {
229
fill.update_state(&mut recv[2..3], &mut [], state)?;
230
}
231
recv[0] = PortState::Blocked;
232
send[0] = PortState::Blocked;
233
},
234
Self::Shifting(shift_state) => {
235
if recv[0] == PortState::Done && shift_state.rows_sent < shift_state.rows_received {
236
send[0] = PortState::Ready;
237
} else {
238
recv[..1].swap_with_slice(send);
239
}
240
recv[1..].fill(PortState::Done);
241
},
242
Self::Done => {
243
recv.fill(PortState::Done);
244
send[0] = PortState::Done;
245
},
246
}
247
Ok(())
248
}
249
250
fn spawn<'env, 's>(
251
&'env mut self,
252
scope: &'s TaskScope<'s, 'env>,
253
recv_ports: &mut [Option<RecvPort<'_>>],
254
send_ports: &mut [Option<SendPort<'_>>],
255
state: &'s StreamingExecutionState,
256
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
257
) {
258
assert!(recv_ports.len() <= 3 && send_ports.len() == 1);
259
match self {
260
Self::GatheringParams {
261
offset,
262
fill,
263
output_schema: _,
264
} => {
265
assert!(recv_ports[0].is_none());
266
assert!(send_ports[0].is_none());
267
if recv_ports[1].is_some() {
268
offset.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles);
269
}
270
if matches!(recv_ports.get(2), Some(Some(_))) {
271
fill.as_mut().unwrap().spawn(
272
scope,
273
&mut recv_ports[2..3],
274
&mut [],
275
state,
276
join_handles,
277
);
278
}
279
},
280
Self::Shifting(shift_state) => {
281
assert!(recv_ports[1..].iter().all(|p| p.is_none()));
282
let recv = recv_ports[0].take().map(|p| p.serial());
283
let send = send_ports[0].take().unwrap().serial();
284
join_handles.push(if shift_state.offset >= 0 {
285
scope.spawn_task(TaskPriority::High, shift_state.shift_positive(recv, send))
286
} else if let Some(r) = recv {
287
scope.spawn_task(TaskPriority::High, shift_state.shift_negative(r, send))
288
} else {
289
scope.spawn_task(TaskPriority::High, shift_state.flush_negative(send, state))
290
});
291
},
292
Self::Done => unreachable!(),
293
}
294
}
295
}
296
297