Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/cum_agg.rs
6939 views
1
use polars_core::prelude::{AnyValue, IntoColumn};
2
use polars_core::utils::last_non_null;
3
use polars_error::PolarsResult;
4
use polars_ops::series::{
5
cum_count_with_init, cum_max_with_init, cum_min_with_init, cum_prod_with_init,
6
cum_sum_with_init,
7
};
8
9
use super::ComputeNode;
10
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
11
use crate::execute::StreamingExecutionState;
12
use crate::graph::PortState;
13
use crate::pipe::{RecvPort, SendPort};
14
15
pub struct CumAggNode {
16
state: AnyValue<'static>,
17
kind: CumAggKind,
18
}
19
20
/// The cumulative aggregation performed by a `CumAggNode`.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be
/// compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CumAggKind {
    /// Cumulative minimum (`cum_min`).
    Min,
    /// Cumulative maximum (`cum_max`).
    Max,
    /// Cumulative sum (`cum_sum`).
    Sum,
    /// Cumulative count (`cum_count`).
    Count,
    /// Cumulative product (`cum_prod`).
    Prod,
}
28
29
impl CumAggNode {
30
pub fn new(kind: CumAggKind) -> Self {
31
Self {
32
state: AnyValue::Null,
33
kind,
34
}
35
}
36
}
37
38
impl ComputeNode for CumAggNode {
39
fn name(&self) -> &str {
40
match self.kind {
41
CumAggKind::Min => "cum_min",
42
CumAggKind::Max => "cum_max",
43
CumAggKind::Sum => "cum_sum",
44
CumAggKind::Count => "cum_count",
45
CumAggKind::Prod => "cum_prod",
46
}
47
}
48
49
fn update_state(
50
&mut self,
51
recv: &mut [PortState],
52
send: &mut [PortState],
53
_state: &StreamingExecutionState,
54
) -> PolarsResult<()> {
55
assert!(recv.len() == 1 && send.len() == 1);
56
57
recv.swap_with_slice(send);
58
Ok(())
59
}
60
61
fn spawn<'env, 's>(
62
&'env mut self,
63
scope: &'s TaskScope<'s, 'env>,
64
recv_ports: &mut [Option<RecvPort<'_>>],
65
send_ports: &mut [Option<SendPort<'_>>],
66
_state: &'s StreamingExecutionState,
67
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
68
) {
69
assert_eq!(recv_ports.len(), 1);
70
assert_eq!(send_ports.len(), 1);
71
72
let mut recv = recv_ports[0].take().unwrap().serial();
73
let mut send = send_ports[0].take().unwrap().serial();
74
75
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
76
while let Ok(mut m) = recv.recv().await {
77
assert_eq!(m.df().width(), 1);
78
if m.df().height() == 0 {
79
continue;
80
}
81
82
let s = m.df()[0].as_materialized_series();
83
let out = match self.kind {
84
CumAggKind::Min => cum_min_with_init(s, false, &self.state),
85
CumAggKind::Max => cum_max_with_init(s, false, &self.state),
86
CumAggKind::Sum => cum_sum_with_init(s, false, &self.state),
87
CumAggKind::Count => {
88
cum_count_with_init(s, false, self.state.extract().unwrap_or_default())
89
},
90
CumAggKind::Prod => cum_prod_with_init(s, false, &self.state),
91
}?;
92
93
// Find the last non-null value and set that as the state.
94
let last_non_null_idx = if out.has_nulls() {
95
last_non_null(out.chunks().iter().map(|c| c.validity()), out.len())
96
} else {
97
Some(out.len() - 1)
98
};
99
if let Some(idx) = last_non_null_idx {
100
self.state = out.get(idx).unwrap().into_static();
101
}
102
*m.df_mut() = out.into_column().into_frame();
103
104
if send.send(m).await.is_err() {
105
break;
106
}
107
}
108
109
Ok(())
110
}));
111
}
112
}
113
114