Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/joins/in_memory.rs
6939 views
1
use std::sync::Arc;
2
3
use polars_core::schema::Schema;
4
5
use crate::nodes::compute_node_prelude::*;
6
use crate::nodes::in_memory_sink::InMemorySinkNode;
7
use crate::nodes::in_memory_source::InMemorySourceNode;
8
9
enum InMemoryJoinState {
10
Sink {
11
left: InMemorySinkNode,
12
right: InMemorySinkNode,
13
},
14
Source(InMemorySourceNode),
15
Done,
16
}
17
18
pub struct InMemoryJoinNode {
19
state: InMemoryJoinState,
20
joiner: Arc<dyn Fn(DataFrame, DataFrame) -> PolarsResult<DataFrame> + Send + Sync>,
21
}
22
23
impl InMemoryJoinNode {
24
pub fn new(
25
left_input_schema: Arc<Schema>,
26
right_input_schema: Arc<Schema>,
27
joiner: Arc<dyn Fn(DataFrame, DataFrame) -> PolarsResult<DataFrame> + Send + Sync>,
28
) -> Self {
29
Self {
30
state: InMemoryJoinState::Sink {
31
left: InMemorySinkNode::new(left_input_schema),
32
right: InMemorySinkNode::new(right_input_schema),
33
},
34
joiner,
35
}
36
}
37
}
38
39
impl ComputeNode for InMemoryJoinNode {
40
fn name(&self) -> &str {
41
"in-memory-join"
42
}
43
44
fn update_state(
45
&mut self,
46
recv: &mut [PortState],
47
send: &mut [PortState],
48
state: &StreamingExecutionState,
49
) -> PolarsResult<()> {
50
assert!(recv.len() == 2 && send.len() == 1);
51
52
// If the output doesn't want any more data, transition to being done.
53
if send[0] == PortState::Done && !matches!(self.state, InMemoryJoinState::Done) {
54
self.state = InMemoryJoinState::Done;
55
}
56
57
// If the input is done, transition to being a source.
58
if let InMemoryJoinState::Sink { left, right } = &mut self.state {
59
if recv[0] == PortState::Done && recv[1] == PortState::Done {
60
let left_df = left.get_output()?.unwrap();
61
let right_df = right.get_output()?.unwrap();
62
let source_node = InMemorySourceNode::new(
63
Arc::new((self.joiner)(left_df, right_df)?),
64
MorselSeq::default(),
65
);
66
self.state = InMemoryJoinState::Source(source_node);
67
}
68
}
69
70
match &mut self.state {
71
InMemoryJoinState::Sink { left, right, .. } => {
72
left.update_state(&mut recv[0..1], &mut [], state)?;
73
right.update_state(&mut recv[1..2], &mut [], state)?;
74
send[0] = PortState::Blocked;
75
},
76
InMemoryJoinState::Source(source_node) => {
77
recv[0] = PortState::Done;
78
recv[1] = PortState::Done;
79
source_node.update_state(&mut [], send, state)?;
80
},
81
InMemoryJoinState::Done => {
82
recv[0] = PortState::Done;
83
recv[1] = PortState::Done;
84
send[0] = PortState::Done;
85
},
86
}
87
Ok(())
88
}
89
90
fn is_memory_intensive_pipeline_blocker(&self) -> bool {
91
matches!(self.state, InMemoryJoinState::Sink { .. })
92
}
93
94
fn spawn<'env, 's>(
95
&'env mut self,
96
scope: &'s TaskScope<'s, 'env>,
97
recv_ports: &mut [Option<RecvPort<'_>>],
98
send_ports: &mut [Option<SendPort<'_>>],
99
state: &'s StreamingExecutionState,
100
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
101
) {
102
assert!(recv_ports.len() == 2);
103
assert!(send_ports.len() == 1);
104
match &mut self.state {
105
InMemoryJoinState::Sink { left, right, .. } => {
106
if recv_ports[0].is_some() {
107
left.spawn(scope, &mut recv_ports[0..1], &mut [], state, join_handles);
108
}
109
if recv_ports[1].is_some() {
110
right.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles);
111
}
112
},
113
InMemoryJoinState::Source(source) => {
114
source.spawn(scope, &mut [], send_ports, state, join_handles)
115
},
116
InMemoryJoinState::Done => unreachable!(),
117
}
118
}
119
}
120
121