Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/dynamic_slice.rs
6939 views
1
use std::sync::Arc;
2
3
use polars_core::schema::Schema;
4
5
use super::compute_node_prelude::*;
6
use crate::nodes::in_memory_sink::InMemorySinkNode;
7
use crate::nodes::negative_slice::NegativeSliceNode;
8
use crate::nodes::streaming_slice::StreamingSliceNode;
9
10
/// A node that will dispatch either to StreamingSlice or NegativeSlice
11
/// depending on the offset which is dynamically dispatched.
12
pub enum DynamicSliceNode {
13
GatheringParams {
14
offset: InMemorySinkNode,
15
length: InMemorySinkNode,
16
},
17
Streaming(StreamingSliceNode),
18
Negative(NegativeSliceNode),
19
}
20
21
impl DynamicSliceNode {
22
pub fn new(offset_schema: Arc<Schema>, length_schema: Arc<Schema>) -> Self {
23
assert!(offset_schema.len() == 1);
24
assert!(length_schema.len() == 1);
25
Self::GatheringParams {
26
offset: InMemorySinkNode::new(offset_schema),
27
length: InMemorySinkNode::new(length_schema),
28
}
29
}
30
}
31
32
impl ComputeNode for DynamicSliceNode {
33
fn name(&self) -> &str {
34
"dynamic-slice"
35
}
36
37
fn update_state(
38
&mut self,
39
recv: &mut [PortState],
40
send: &mut [PortState],
41
state: &StreamingExecutionState,
42
) -> PolarsResult<()> {
43
assert!(recv.len() == 3 && send.len() == 1);
44
45
if recv[1] == PortState::Done && recv[2] == PortState::Done {
46
if let Self::GatheringParams { offset, length } = self {
47
let offset = offset.get_output()?.unwrap();
48
let length = length.get_output()?.unwrap();
49
let offset_item = offset.get_columns()[0].get(0)?;
50
let length_item = length.get_columns()[0].get(0)?;
51
let offset = offset_item.extract::<i64>().unwrap_or(0);
52
let length = length_item.extract::<usize>().unwrap_or(usize::MAX);
53
if let Ok(non_neg_offset) = offset.try_into() {
54
*self = Self::Streaming(StreamingSliceNode::new(non_neg_offset, length));
55
} else {
56
*self = Self::Negative(NegativeSliceNode::new(offset, length));
57
}
58
}
59
}
60
61
match self {
62
Self::GatheringParams { offset, length } => {
63
offset.update_state(&mut recv[1..2], &mut [], state)?;
64
length.update_state(&mut recv[2..3], &mut [], state)?;
65
recv[0] = PortState::Blocked;
66
send[0] = PortState::Blocked;
67
},
68
Self::Streaming(node) => {
69
node.update_state(&mut recv[0..1], send, state)?;
70
recv[1] = PortState::Done;
71
recv[2] = PortState::Done;
72
},
73
Self::Negative(node) => {
74
node.update_state(&mut recv[0..1], send, state)?;
75
recv[1] = PortState::Done;
76
recv[2] = PortState::Done;
77
},
78
}
79
Ok(())
80
}
81
82
fn spawn<'env, 's>(
83
&'env mut self,
84
scope: &'s TaskScope<'s, 'env>,
85
recv_ports: &mut [Option<RecvPort<'_>>],
86
send_ports: &mut [Option<SendPort<'_>>],
87
state: &'s StreamingExecutionState,
88
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
89
) {
90
assert!(recv_ports.len() == 3 && send_ports.len() == 1);
91
match self {
92
Self::GatheringParams { offset, length } => {
93
assert!(recv_ports[0].is_none());
94
assert!(send_ports[0].is_none());
95
if recv_ports[1].is_some() {
96
offset.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles);
97
}
98
if recv_ports[2].is_some() {
99
length.spawn(scope, &mut recv_ports[2..3], &mut [], state, join_handles);
100
}
101
},
102
Self::Streaming(node) => {
103
node.spawn(
104
scope,
105
&mut recv_ports[0..1],
106
send_ports,
107
state,
108
join_handles,
109
);
110
assert!(recv_ports[1].is_none());
111
assert!(recv_ports[2].is_none());
112
},
113
Self::Negative(node) => {
114
node.spawn(
115
scope,
116
&mut recv_ports[0..1],
117
send_ports,
118
state,
119
join_handles,
120
);
121
assert!(recv_ports[1].is_none());
122
assert!(recv_ports[2].is_none());
123
},
124
}
125
}
126
}
127
128