Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/in_memory_source.rs
6939 views
1
use std::sync::Arc;
2
use std::sync::atomic::{AtomicU64, Ordering};
3
4
use super::compute_node_prelude::*;
5
use crate::async_primitives::wait_group::WaitGroup;
6
use crate::morsel::{MorselSeq, SourceToken, get_ideal_morsel_size};
7
8
pub struct InMemorySourceNode {
9
source: Option<Arc<DataFrame>>,
10
morsel_size: usize,
11
seq: AtomicU64,
12
seq_offset: MorselSeq,
13
}
14
15
impl InMemorySourceNode {
16
pub fn new(source: Arc<DataFrame>, seq_offset: MorselSeq) -> Self {
17
InMemorySourceNode {
18
source: Some(source),
19
morsel_size: 0,
20
seq: AtomicU64::new(0),
21
seq_offset,
22
}
23
}
24
}
25
26
impl ComputeNode for InMemorySourceNode {
27
fn name(&self) -> &str {
28
"in-memory-source"
29
}
30
31
fn update_state(
32
&mut self,
33
recv: &mut [PortState],
34
send: &mut [PortState],
35
state: &StreamingExecutionState,
36
) -> PolarsResult<()> {
37
assert!(recv.is_empty());
38
assert!(send.len() == 1);
39
40
if self.morsel_size == 0 {
41
let len = self.source.as_ref().unwrap().height();
42
let ideal_morsel_count = (len / get_ideal_morsel_size()).max(1);
43
let morsel_count = ideal_morsel_count.next_multiple_of(state.num_pipelines);
44
self.morsel_size = len.div_ceil(morsel_count).max(1);
45
self.seq = AtomicU64::new(0);
46
}
47
48
// As a temporary hack for some nodes (like the FunctionIR::FastCount)
49
// node that rely on an empty input, always ensure we send at least one
50
// morsel.
51
// TODO: remove this hack.
52
let exhausted = if let Some(src) = &self.source {
53
let seq = self.seq.load(Ordering::Relaxed);
54
seq > 0 && seq * self.morsel_size as u64 >= src.height() as u64
55
} else {
56
true
57
};
58
if send[0] == PortState::Done || exhausted {
59
send[0] = PortState::Done;
60
self.source = None;
61
} else {
62
send[0] = PortState::Ready;
63
}
64
Ok(())
65
}
66
67
fn spawn<'env, 's>(
68
&'env mut self,
69
scope: &'s TaskScope<'s, 'env>,
70
recv_ports: &mut [Option<RecvPort<'_>>],
71
send_ports: &mut [Option<SendPort<'_>>],
72
_state: &'s StreamingExecutionState,
73
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
74
) {
75
assert!(recv_ports.is_empty() && send_ports.len() == 1);
76
let senders = send_ports[0].take().unwrap().parallel();
77
let source = self.source.as_ref().unwrap();
78
79
// TODO: can this just be serial, using the work distributor?
80
let source_token = SourceToken::new();
81
for mut send in senders {
82
let slf = &*self;
83
let source_token = source_token.clone();
84
join_handles.push(scope.spawn_task(TaskPriority::Low, async move {
85
let wait_group = WaitGroup::default();
86
loop {
87
let seq = slf.seq.fetch_add(1, Ordering::Relaxed);
88
let offset = (seq as usize * slf.morsel_size) as i64;
89
let df = source.slice(offset, slf.morsel_size);
90
91
// TODO: remove this 'always sent at least one morsel'
92
// condition, see update_state.
93
if df.height() == 0 && seq > 0 {
94
break;
95
}
96
97
let morsel_seq = MorselSeq::new(seq).offset_by(slf.seq_offset);
98
let mut morsel = Morsel::new(df, morsel_seq, source_token.clone());
99
morsel.set_consume_token(wait_group.token());
100
if send.send(morsel).await.is_err() {
101
break;
102
}
103
104
wait_group.wait().await;
105
if source_token.stop_requested() {
106
break;
107
}
108
}
109
110
Ok(())
111
}));
112
}
113
}
114
}
115
116