Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/negative_slice.rs
6939 views
1
use std::collections::VecDeque;
2
use std::sync::Arc;
3
4
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
5
6
use super::compute_node_prelude::*;
7
use crate::nodes::in_memory_source::InMemorySourceNode;
8
9
/// A node that will pass-through up to length rows, starting at start_offset.
10
/// Since start_offset must be non-negative this can be done in a streaming
11
/// manner.
12
enum NegativeSliceState {
13
Buffering(Buffer),
14
Source(InMemorySourceNode),
15
Done,
16
}
17
18
#[derive(Default)]
19
struct Buffer {
20
frames: VecDeque<DataFrame>,
21
total_len: usize,
22
}
23
24
pub struct NegativeSliceNode {
25
state: NegativeSliceState,
26
slice_offset: i64,
27
length: usize,
28
}
29
30
impl NegativeSliceNode {
31
pub fn new(slice_offset: i64, length: usize) -> Self {
32
assert!(slice_offset < 0);
33
Self {
34
state: NegativeSliceState::Buffering(Buffer::default()),
35
slice_offset,
36
length,
37
}
38
}
39
}
40
41
impl ComputeNode for NegativeSliceNode {
42
fn name(&self) -> &str {
43
"negative-slice"
44
}
45
46
fn update_state(
47
&mut self,
48
recv: &mut [PortState],
49
send: &mut [PortState],
50
state: &StreamingExecutionState,
51
) -> PolarsResult<()> {
52
use NegativeSliceState::*;
53
54
if send[0] == PortState::Done || self.length == 0 {
55
self.state = Done;
56
}
57
58
if recv[0] == PortState::Done {
59
if let Buffering(buffer) = &mut self.state {
60
// These offsets are relative to the start of buffer.
61
let mut signed_start_offset = buffer.total_len as i64 + self.slice_offset;
62
let signed_stop_offset =
63
signed_start_offset.saturating_add_unsigned(self.length as u64);
64
65
// Trim the frames in the buffer to just those that are relevant.
66
while buffer.total_len > 0
67
&& signed_start_offset >= buffer.frames.front().unwrap().height() as i64
68
{
69
let len = buffer.frames.pop_front().unwrap().height();
70
buffer.total_len -= len;
71
signed_start_offset -= len as i64;
72
}
73
74
while !buffer.frames.is_empty()
75
&& buffer.total_len as i64 - buffer.frames.back().unwrap().height() as i64
76
> signed_stop_offset
77
{
78
buffer.total_len -= buffer.frames.pop_back().unwrap().height();
79
}
80
81
if buffer.total_len == 0 {
82
self.state = Done;
83
} else {
84
let mut df = accumulate_dataframes_vertical_unchecked(buffer.frames.drain(..));
85
let clamped_start = signed_start_offset.max(0);
86
let len = (signed_stop_offset - clamped_start).max(0) as usize;
87
df = df.slice(clamped_start, len);
88
self.state =
89
Source(InMemorySourceNode::new(Arc::new(df), MorselSeq::default()));
90
}
91
}
92
}
93
94
match &mut self.state {
95
Buffering(_) => {
96
recv[0] = PortState::Ready;
97
send[0] = PortState::Blocked;
98
},
99
Source(node) => {
100
recv[0] = PortState::Done;
101
node.update_state(&mut [], send, state)?;
102
},
103
Done => {
104
recv[0] = PortState::Done;
105
send[0] = PortState::Done;
106
},
107
}
108
Ok(())
109
}
110
111
fn spawn<'env, 's>(
112
&'env mut self,
113
scope: &'s TaskScope<'s, 'env>,
114
recv_ports: &mut [Option<RecvPort<'_>>],
115
send_ports: &mut [Option<SendPort<'_>>],
116
state: &'s StreamingExecutionState,
117
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
118
) {
119
assert!(recv_ports.len() == 1 && send_ports.len() == 1);
120
match &mut self.state {
121
NegativeSliceState::Buffering(buffer) => {
122
let mut recv = recv_ports[0].take().unwrap().serial();
123
assert!(send_ports[0].is_none());
124
let max_buffer_needed = self.slice_offset.unsigned_abs() as usize;
125
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
126
while let Ok(morsel) = recv.recv().await {
127
buffer.total_len += morsel.df().height();
128
buffer.frames.push_back(morsel.into_df());
129
130
if buffer.total_len - buffer.frames.front().unwrap().height()
131
>= max_buffer_needed
132
{
133
buffer.total_len -= buffer.frames.pop_front().unwrap().height();
134
}
135
}
136
137
Ok(())
138
}));
139
},
140
NegativeSliceState::Source(in_memory_source_node) => {
141
assert!(recv_ports[0].is_none());
142
in_memory_source_node.spawn(scope, &mut [], send_ports, state, join_handles);
143
},
144
NegativeSliceState::Done => unreachable!(),
145
}
146
}
147
}
148
149