Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/gather_every.rs
7884 views
1
use polars_error::polars_ensure;
2
3
use super::compute_node_prelude::*;
4
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;
5
use crate::async_primitives::distributor_channel::distributor_channel;
6
use crate::async_primitives::wait_group::WaitGroup;
7
8
/// Streaming node implementing `gather_every`: keeps every `n`-th row, starting
/// at row `offset`.
pub struct GatherEveryNode {
    /// Position of the next row to gather, relative to the start of the next
    /// morsel that arrives; updated as morsels stream through the node.
    offset: usize,
    /// Stride between gathered rows (validated to be positive by `new`).
    n: usize,
}
12
13
impl GatherEveryNode {
14
pub fn new(n: usize, offset: usize) -> PolarsResult<Self> {
15
polars_ensure!(n > 0, InvalidOperation: "gather_every(n): n should be positive");
16
17
assert!(i64::try_from(n).unwrap() > 0);
18
assert!(i64::try_from(offset).unwrap() >= 0);
19
20
Ok(Self { n, offset })
21
}
22
}
23
24
impl ComputeNode for GatherEveryNode {
25
fn name(&self) -> &str {
26
"gather_every"
27
}
28
29
fn update_state(
30
&mut self,
31
recv: &mut [PortState],
32
send: &mut [PortState],
33
_state: &StreamingExecutionState,
34
) -> PolarsResult<()> {
35
assert!(recv.len() == 1 && send.len() == 1);
36
recv.swap_with_slice(send);
37
Ok(())
38
}
39
40
fn spawn<'env, 's>(
41
&'env mut self,
42
scope: &'s TaskScope<'s, 'env>,
43
recv_ports: &mut [Option<RecvPort<'_>>],
44
send_ports: &mut [Option<SendPort<'_>>],
45
_state: &'s StreamingExecutionState,
46
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
47
) {
48
assert!(recv_ports.len() == 1 && send_ports.len() == 1);
49
let mut receiver = recv_ports[0].take().unwrap().serial();
50
let senders = send_ports[0].take().unwrap().parallel();
51
52
let (mut distributor, distr_receivers) =
53
distributor_channel(senders.len(), *DEFAULT_DISTRIBUTOR_BUFFER_SIZE);
54
55
let n = self.n;
56
57
// To figure out the correct offsets we need to be serial.
58
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
59
while let Ok(morsel) = receiver.recv().await {
60
let height = morsel.df().height();
61
if self.offset >= height {
62
self.offset -= height;
63
continue;
64
}
65
66
if distributor.send((morsel, self.offset)).await.is_err() {
67
break;
68
}
69
70
// Calculates `offset = (offset - height) mod n` without under- and overflow.
71
self.offset += height.next_multiple_of(self.n) - height;
72
self.offset %= self.n;
73
}
74
75
Ok(())
76
}));
77
78
// But gathering the column can be done in parallel.
79
for (mut send, mut recv) in senders.into_iter().zip(distr_receivers) {
80
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
81
let wait_group = WaitGroup::default();
82
while let Ok((morsel, offset)) = recv.recv().await {
83
let mut morsel = morsel.try_map(|mut df| {
84
let column = &df.get_columns()[0];
85
let out = column
86
.gather_every(n, offset)?
87
.with_name(column.name().clone());
88
unsafe { df.get_columns_mut()[0] = out };
89
PolarsResult::Ok(df)
90
})?;
91
morsel.set_consume_token(wait_group.token());
92
if send.send(morsel).await.is_err() {
93
break;
94
}
95
wait_group.wait().await;
96
}
97
98
Ok(())
99
}));
100
}
101
}
102
}
103
104