Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/ewm.rs
7884 views
1
use polars_compute::ewm::EwmStateUpdate;
2
use polars_core::prelude::IntoColumn;
3
use polars_core::series::Series;
4
use polars_error::PolarsResult;
5
6
use super::ComputeNode;
7
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
8
use crate::execute::StreamingExecutionState;
9
use crate::graph::PortState;
10
use crate::pipe::{RecvPort, SendPort};
11
12
pub struct EwmNode {
13
name: &'static str,
14
state: Box<dyn EwmStateUpdate + Send>,
15
}
16
17
impl EwmNode {
18
pub fn new(name: &'static str, state: Box<dyn EwmStateUpdate + Send>) -> Self {
19
Self { name, state }
20
}
21
}
22
23
impl ComputeNode for EwmNode {
24
fn name(&self) -> &str {
25
self.name
26
}
27
28
fn update_state(
29
&mut self,
30
recv: &mut [PortState],
31
send: &mut [PortState],
32
_state: &StreamingExecutionState,
33
) -> PolarsResult<()> {
34
assert!(recv.len() == 1 && send.len() == 1);
35
recv.swap_with_slice(send);
36
Ok(())
37
}
38
39
fn spawn<'env, 's>(
40
&'env mut self,
41
scope: &'s TaskScope<'s, 'env>,
42
recv_ports: &mut [Option<RecvPort<'_>>],
43
send_ports: &mut [Option<SendPort<'_>>],
44
_state: &'s StreamingExecutionState,
45
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
46
) {
47
assert_eq!(recv_ports.len(), 1);
48
assert_eq!(send_ports.len(), 1);
49
50
let mut recv = recv_ports[0].take().unwrap().serial();
51
let mut send = send_ports[0].take().unwrap().serial();
52
53
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
54
while let Ok(mut morsel) = recv.recv().await {
55
let df = morsel.df_mut();
56
57
debug_assert_eq!(df.width(), 1);
58
59
unsafe {
60
let c = df.get_columns_mut().get_mut(0).unwrap();
61
62
*c = Series::from_chunks_and_dtype_unchecked(
63
c.name().clone(),
64
vec![self.state.ewm_state_update(
65
c.as_materialized_series().rechunk().chunks()[0].as_ref(),
66
)],
67
c.dtype(),
68
)
69
.into_column()
70
}
71
72
if send.send(morsel).await.is_err() {
73
break;
74
}
75
}
76
77
Ok(())
78
}));
79
}
80
}
81
82