Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/in_memory_source.rs
8420 views
1
use std::sync::Arc;
2
use std::sync::atomic::{AtomicU64, Ordering};
3
4
use super::compute_node_prelude::*;
5
use crate::async_primitives::wait_group::WaitGroup;
6
use crate::morsel::{MorselSeq, SourceToken, get_ideal_morsel_size};
7
8
#[derive(Debug)]
9
pub struct InMemorySourceNode {
10
source: Option<Arc<DataFrame>>,
11
morsel_size: usize,
12
seq: AtomicU64,
13
seq_offset: MorselSeq,
14
}
15
16
impl InMemorySourceNode {
17
pub fn new(source: Arc<DataFrame>, seq_offset: MorselSeq) -> Self {
18
InMemorySourceNode {
19
source: Some(source),
20
morsel_size: 0,
21
seq: AtomicU64::new(0),
22
seq_offset,
23
}
24
}
25
26
pub fn new_no_morsel_split(source: Arc<DataFrame>, seq_offset: MorselSeq) -> Self {
27
let morsel_size = source.height();
28
29
InMemorySourceNode {
30
source: Some(source),
31
morsel_size,
32
seq: AtomicU64::new(0),
33
seq_offset,
34
}
35
}
36
}
37
38
impl ComputeNode for InMemorySourceNode {
39
fn name(&self) -> &str {
40
"in-memory-source"
41
}
42
43
fn update_state(
44
&mut self,
45
recv: &mut [PortState],
46
send: &mut [PortState],
47
state: &StreamingExecutionState,
48
) -> PolarsResult<()> {
49
assert!(recv.is_empty());
50
assert!(send.len() == 1);
51
52
if self.morsel_size == 0 {
53
let len = self.source.as_ref().unwrap().height();
54
let ideal_morsel_count = (len / get_ideal_morsel_size()).max(1);
55
let morsel_count = ideal_morsel_count.next_multiple_of(state.num_pipelines);
56
self.morsel_size = len.div_ceil(morsel_count).max(1);
57
self.seq = AtomicU64::new(0);
58
}
59
60
// As a temporary hack for some nodes (like the FunctionIR::FastCount)
61
// node that rely on an empty input, always ensure we send at least one
62
// morsel.
63
// TODO: remove this hack.
64
let exhausted = if let Some(src) = &self.source {
65
let seq = self.seq.load(Ordering::Relaxed);
66
seq > 0 && seq * self.morsel_size as u64 >= src.height() as u64
67
} else {
68
true
69
};
70
if send[0] == PortState::Done || exhausted {
71
send[0] = PortState::Done;
72
self.source = None;
73
} else {
74
send[0] = PortState::Ready;
75
}
76
Ok(())
77
}
78
79
fn spawn<'env, 's>(
80
&'env mut self,
81
scope: &'s TaskScope<'s, 'env>,
82
recv_ports: &mut [Option<RecvPort<'_>>],
83
send_ports: &mut [Option<SendPort<'_>>],
84
_state: &'s StreamingExecutionState,
85
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
86
) {
87
assert!(recv_ports.is_empty() && send_ports.len() == 1);
88
let senders = send_ports[0].take().unwrap().parallel();
89
let source = self.source.as_ref().unwrap();
90
91
// TODO: can this just be serial, using the work distributor?
92
let source_token = SourceToken::new();
93
for mut send in senders {
94
let slf = &*self;
95
let source_token = source_token.clone();
96
join_handles.push(scope.spawn_task(TaskPriority::Low, async move {
97
let wait_group = WaitGroup::default();
98
loop {
99
let seq = slf.seq.fetch_add(1, Ordering::Relaxed);
100
let offset = (seq as usize * slf.morsel_size) as i64;
101
let df = source.slice(offset, slf.morsel_size);
102
103
// TODO: remove this 'always sent at least one morsel'
104
// condition, see update_state.
105
if df.height() == 0 && seq > 0 {
106
break;
107
}
108
109
let morsel_seq = MorselSeq::new(seq).offset_by(slf.seq_offset);
110
let mut morsel = Morsel::new(df, morsel_seq, source_token.clone());
111
morsel.set_consume_token(wait_group.token());
112
if send.send(morsel).await.is_err() {
113
break;
114
}
115
116
wait_group.wait().await;
117
if source_token.stop_requested() {
118
break;
119
}
120
}
121
122
Ok(())
123
}));
124
}
125
}
126
}
127
128