Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/repeat.rs
6939 views
1
use std::sync::Arc;
2
3
use polars_core::schema::Schema;
4
5
use super::compute_node_prelude::*;
6
use crate::async_primitives::wait_group::WaitGroup;
7
use crate::morsel::{SourceToken, get_ideal_morsel_size};
8
use crate::nodes::in_memory_sink::InMemorySinkNode;
9
pub enum RepeatNode {
10
GatheringParams {
11
value: InMemorySinkNode,
12
repeats: InMemorySinkNode,
13
},
14
Repeating {
15
value: DataFrame,
16
seq: MorselSeq,
17
repeats_left: usize,
18
},
19
}
20
21
impl RepeatNode {
22
pub fn new(value_schema: Arc<Schema>, repeats_schema: Arc<Schema>) -> Self {
23
assert!(value_schema.len() == 1);
24
assert!(repeats_schema.len() == 1);
25
Self::GatheringParams {
26
value: InMemorySinkNode::new(value_schema),
27
repeats: InMemorySinkNode::new(repeats_schema),
28
}
29
}
30
}
31
32
impl ComputeNode for RepeatNode {
33
fn name(&self) -> &str {
34
"repeat"
35
}
36
37
fn update_state(
38
&mut self,
39
recv: &mut [PortState],
40
send: &mut [PortState],
41
state: &StreamingExecutionState,
42
) -> PolarsResult<()> {
43
assert!(recv.len() == 2 && send.len() == 1);
44
45
if recv[0] == PortState::Done && recv[1] == PortState::Done {
46
if let Self::GatheringParams { value, repeats } = self {
47
let repeats = repeats.get_output()?.unwrap();
48
let repeats_item = repeats.get_columns()[0].get(0)?;
49
let repeats_left = repeats_item.extract::<usize>().unwrap();
50
51
let value = value.get_output()?.unwrap();
52
let seq = MorselSeq::default();
53
*self = Self::Repeating {
54
value,
55
seq,
56
repeats_left,
57
};
58
}
59
}
60
61
match self {
62
Self::GatheringParams { value, repeats } => {
63
value.update_state(&mut recv[0..1], &mut [], state)?;
64
repeats.update_state(&mut recv[1..2], &mut [], state)?;
65
send[0] = PortState::Blocked;
66
},
67
Self::Repeating { repeats_left, .. } => {
68
recv[0] = PortState::Done;
69
recv[1] = PortState::Done;
70
send[0] = if *repeats_left > 0 {
71
PortState::Ready
72
} else {
73
PortState::Done
74
};
75
},
76
}
77
Ok(())
78
}
79
80
fn spawn<'env, 's>(
81
&'env mut self,
82
scope: &'s TaskScope<'s, 'env>,
83
recv_ports: &mut [Option<RecvPort<'_>>],
84
send_ports: &mut [Option<SendPort<'_>>],
85
state: &'s StreamingExecutionState,
86
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
87
) {
88
assert!(recv_ports.len() == 2 && send_ports.len() == 1);
89
match self {
90
Self::GatheringParams { value, repeats } => {
91
assert!(send_ports[0].is_none());
92
if recv_ports[0].is_some() {
93
value.spawn(scope, &mut recv_ports[0..1], &mut [], state, join_handles);
94
}
95
if recv_ports[1].is_some() {
96
repeats.spawn(scope, &mut recv_ports[1..2], &mut [], state, join_handles);
97
}
98
},
99
Self::Repeating {
100
value,
101
seq,
102
repeats_left,
103
} => {
104
assert!(recv_ports[0].is_none());
105
assert!(recv_ports[1].is_none());
106
107
let mut send = send_ports[0].take().unwrap().serial();
108
109
let ideal_morsel_count = (*repeats_left / get_ideal_morsel_size()).max(1);
110
let morsel_count = ideal_morsel_count.next_multiple_of(state.num_pipelines);
111
let morsel_size = repeats_left.div_ceil(morsel_count).max(1);
112
113
join_handles.push(scope.spawn_task(TaskPriority::Low, async move {
114
let source_token = SourceToken::new();
115
116
let wait_group = WaitGroup::default();
117
while *repeats_left > 0 && !source_token.stop_requested() {
118
let height = morsel_size.min(*repeats_left);
119
let df = value.new_from_index(0, height);
120
let mut morsel = Morsel::new(df, *seq, source_token.clone());
121
morsel.set_consume_token(wait_group.token());
122
123
*seq = seq.successor();
124
*repeats_left -= height;
125
126
if send.send(morsel).await.is_err() {
127
break;
128
}
129
wait_group.wait().await;
130
}
131
132
Ok(())
133
}));
134
},
135
}
136
}
137
}
138
139