Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/rolling_group_by.rs
7884 views
1
use std::sync::Arc;
2
3
use chrono_tz::Tz;
4
use polars_core::frame::DataFrame;
5
use polars_core::prelude::{Column, DataType, GroupsType, TimeUnit};
6
use polars_core::schema::Schema;
7
use polars_error::{PolarsError, PolarsResult, polars_bail, polars_ensure};
8
use polars_expr::state::ExecutionState;
9
use polars_time::prelude::{RollingWindower, ensure_duration_matches_dtype};
10
use polars_time::{ClosedWindow, Duration};
11
use polars_utils::IdxSize;
12
use polars_utils::pl_str::PlSmallStr;
13
14
use super::ComputeNode;
15
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;
16
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
17
use crate::async_primitives::distributor_channel::distributor_channel;
18
use crate::async_primitives::wait_group::WaitGroup;
19
use crate::execute::StreamingExecutionState;
20
use crate::expression::StreamExpr;
21
use crate::graph::PortState;
22
use crate::morsel::{Morsel, MorselSeq, SourceToken};
23
use crate::pipe::{RecvPort, SendPort};
24
25
/// Result of [`RollingGroupBy::next_windows`]: the `[start, len]` window slices
/// (row offsets into the returned [`DataFrame`]), the data those windows cover,
/// and the key column holding one index value per window.
type NextWindows = (Vec<[IdxSize; 2]>, DataFrame, Column);
26
27
/// Streaming compute node that performs a rolling (windowed) group-by over an
/// index column, evaluating aggregation expressions per window.
pub struct RollingGroupBy {
    /// Buffered input rows not yet fully consumed by emitted windows.
    buf_df: DataFrame,
    /// How many `buf_df` rows did we discard of already?
    buf_df_offset: IdxSize,
    /// Casted index column, which may need to keep around old values.
    buf_index_column: Column,
    /// Uncasted index column.
    buf_key_column: Column,

    /// Sequence number of the last morsel received; used when emitting the
    /// final morsel during finalization.
    seq: MorselSeq,

    /// Remaining output rows to skip (slice pushdown); decremented as windows
    /// are dropped.
    slice_offset: IdxSize,
    /// Remaining output rows to emit; `0` means the node is done.
    slice_length: IdxSize,

    /// Name of the column that the rolling windows are computed over.
    index_column: PlSmallStr,
    /// Incremental window generator (period / offset / closedness aware).
    windower: RollingWindower,
    /// Output name and expression for each aggregation.
    aggs: Arc<[(PlSmallStr, StreamExpr)]>,
}
45
impl RollingGroupBy {
    /// Validate the rolling parameters and construct the node state.
    ///
    /// # Errors
    /// - `period` must be strictly positive.
    /// - `period` and `offset` must match the dtype of `index_column`.
    /// - The index column must be Date, Datetime, Int32, Int64, UInt32 or
    ///   UInt64.
    pub fn new(
        schema: Arc<Schema>,
        index_column: PlSmallStr,
        period: Duration,
        offset: Duration,
        closed: ClosedWindow,
        slice: Option<(IdxSize, IdxSize)>,
        aggs: Arc<[(PlSmallStr, StreamExpr)]>,
    ) -> PolarsResult<Self> {
        polars_ensure!(
            !period.is_zero() && !period.negative(),
            ComputeError: "rolling window period should be strictly positive",
        );

        let key_dtype = schema.get(&index_column).unwrap();
        ensure_duration_matches_dtype(period, key_dtype, "period")?;
        ensure_duration_matches_dtype(offset, key_dtype, "offset")?;

        use DataType as DT;
        // Windowing always runs over a Datetime view of the index column;
        // pick the time unit / timezone that view will use. The same mapping
        // is applied per-morsel in `spawn` when casting incoming data.
        let (tu, tz) = match key_dtype {
            DT::Datetime(tu, tz) => (*tu, tz.clone()),
            DT::Date => (TimeUnit::Microseconds, None),
            DT::UInt32 | DT::UInt64 | DT::Int64 | DT::Int32 => (TimeUnit::Nanoseconds, None),
            dt => polars_bail!(
                ComputeError:
                "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64, UInt32, UInt64 }}, got {}",
                dt
            ),
        };

        let buf_df = DataFrame::empty_with_arc_schema(schema.clone());
        let buf_key_column = Column::new_empty(index_column.clone(), key_dtype);
        let buf_index_column =
            Column::new_empty(index_column.clone(), &DT::Datetime(tu, tz.clone()));

        // @NOTE: This is a bit strange since it ignores errors, but it mirrors the in-memory
        // engine.
        let tz = tz.and_then(|tz| tz.parse::<Tz>().ok());
        let windower = RollingWindower::new(period, offset, closed, tu, tz);

        // No slice pushed down means: skip 0 rows, emit up to IdxSize::MAX.
        let (slice_offset, slice_length) = slice.unwrap_or((0, IdxSize::MAX));

        Ok(Self {
            buf_df,
            buf_df_offset: 0,
            buf_index_column,
            buf_key_column,
            seq: MorselSeq::default(),
            slice_offset,
            slice_length,
            index_column,
            windower,
            aggs,
        })
    }

    /// Evaluate every aggregation in `aggs` over the `windows` of `df`,
    /// producing one output row per window with `key` as the first column.
    async fn evaluate_one(
        windows: Vec<[IdxSize; 2]>,
        key: Column,
        aggs: &[(PlSmallStr, StreamExpr)],
        state: &ExecutionState,
        mut df: DataFrame,
    ) -> PolarsResult<DataFrame> {
        // One key value — and thus one output row — per window.
        assert_eq!(windows.len(), key.len());

        let groups = GroupsType::new_slice(windows, true, true).into_sliceable();

        // @NOTE:
        // Rechunk so we can use specialized rolling kernels.
        //
        // This can be removed if / when the rolling kernels are chunking aware.
        df.rechunk_mut();

        let mut columns = Vec::with_capacity(1 + aggs.len());
        let height = key.len();
        columns.push(key);
        for (name, agg) in aggs.iter() {
            let mut agg = agg.evaluate_on_groups(&df, &groups, state).await?;
            let agg = agg.finalize();
            columns.push(agg.with_name(name.clone()));
        }

        // SAFETY: `key` has `height` rows by construction, and each grouped
        // aggregation is expected to yield one value per window.
        Ok(unsafe { DataFrame::new_no_checks(height, columns) })
    }

    /// Progress the state and get the next available evaluation windows, data and key.
    ///
    /// With `finalize == true` all buffered rows are flushed into windows and
    /// retired; otherwise only windows that can no longer change are emitted.
    /// Returns `Ok(None)` when nothing is ready or the slice discarded
    /// everything that was ready.
    fn next_windows(&mut self, finalize: bool) -> PolarsResult<Option<NextWindows>> {
        // Gather the physical timestamp values of the buffered (casted)
        // index column, one slice per chunk.
        let buf_index_col_dt = self.buf_index_column.datetime()?;
        let mut time = Vec::new();
        time.extend(
            buf_index_col_dt
                .physical()
                .downcast_iter()
                .map(|arr| arr.values().as_slice()),
        );

        let mut windows = Vec::new();
        let num_retired = if finalize {
            // Flush: every remaining buffered row is retired.
            self.windower.finalize(&time, &mut windows);
            self.buf_key_column.len() as IdxSize
        } else {
            self.windower.insert(&time, &mut windows)?
        };

        if num_retired == 0 && windows.is_empty() {
            return Ok(None);
        }

        // Absolute row offset the buffer started at before this call.
        let start_row_offset = self.buf_df_offset;

        // Drop the retired prefix from the buffers. Note: `data` keeps the
        // pre-trim buffer, since the emitted windows may still reference
        // retired rows.
        self.buf_index_column = self.buf_index_column.slice(num_retired as i64, usize::MAX);
        let new_buf_df = self.buf_df.slice(num_retired as i64, usize::MAX);
        let data = std::mem::replace(&mut self.buf_df, new_buf_df);
        self.buf_df_offset += num_retired;

        if windows.is_empty() {
            return Ok(None);
        }

        // Split off one key row per emitted window; the rest stays buffered.
        // The key is sliced with the slice state *before* it is updated below,
        // matching the drain / truncate applied to `windows`.
        let key;
        (key, self.buf_key_column) = self.buf_key_column.split_at(windows.len() as i64);
        let key = key.slice(self.slice_offset as i64, self.slice_length as usize);

        // Absolute row span covered by the emitted windows, captured before
        // any windows are dropped by the slice below.
        let offset = windows[0][0];
        let end = windows.last().unwrap();
        let end = end[0] + end[1];

        // Consume the remaining slice offset by dropping leading windows.
        if self.slice_offset as usize > windows.len() {
            self.slice_offset -= windows.len() as IdxSize;
            windows.clear();
        } else if self.slice_offset > 0 {
            let offset = self.slice_offset as usize;
            self.slice_offset = self.slice_offset.saturating_sub(windows.len() as IdxSize);
            windows.drain(..offset);
        }

        // Consume the remaining slice length by dropping trailing windows.
        windows.truncate(windows.len().min(self.slice_length as usize));
        self.slice_length -= windows.len() as IdxSize;

        if windows.is_empty() {
            return Ok(None);
        }

        // Prune the data that is not covered by the windows and update the windows accordingly.
        windows.iter_mut().for_each(|[s, _]| *s -= offset);
        let data = data.slice(
            (offset - start_row_offset) as i64,
            (end - start_row_offset) as usize,
        );

        Ok(Some((windows, data, key)))
    }
}
199
200
impl ComputeNode for RollingGroupBy {
    fn name(&self) -> &str {
        "rolling-group-by"
    }

    /// Propagate port readiness between the input and output ports, and drop
    /// the buffered data as soon as no more output can be produced.
    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        _state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 1 && send.len() == 1);

        // Slice exhausted: nothing more can ever be emitted, so shut both
        // ports down and release the buffer.
        if self.slice_length == 0 {
            recv[0] = PortState::Done;
            send[0] = PortState::Done;
            std::mem::take(&mut self.buf_df);
            return Ok(());
        }

        if send[0] == PortState::Done {
            // Downstream no longer consumes; stop receiving and free the buffer.
            recv[0] = PortState::Done;
            std::mem::take(&mut self.buf_df);
        } else if recv[0] == PortState::Done {
            // Input exhausted, but buffered rows may still have to be flushed
            // as final windows (handled by the finalize path in `spawn`).
            if self.buf_df.is_empty() {
                send[0] = PortState::Done;
            } else {
                send[0] = PortState::Ready;
            }
        } else {
            recv.swap_with_slice(send);
        }

        Ok(())
    }

    /// Spawn the pipeline tasks: a serial distributor that buffers morsels
    /// and computes window boundaries, plus parallel workers that evaluate
    /// the aggregations. When the receive port is gone, spawns a single
    /// finalize task instead.
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv_ports.len() == 1 && send_ports.len() == 1);

        let Some(recv) = recv_ports[0].take() else {
            // We no longer have to receive data. Finalize and send all remaining data.
            assert!(!self.buf_df.is_empty());
            assert!(self.slice_length > 0);
            let mut send = send_ports[0].take().unwrap().serial();
            join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                if let Some((windows, df, key)) = self.next_windows(true)? {
                    let df = Self::evaluate_one(
                        windows,
                        key,
                        &self.aggs,
                        &state.in_memory_exec_state,
                        df,
                    )
                    .await?;

                    // Send errors just mean downstream stopped; ignore them.
                    _ = send
                        .send(Morsel::new(df, self.seq.successor(), SourceToken::new()))
                        .await;
                }

                // All buffered state has been consumed; release it.
                self.buf_df = self.buf_df.clear();
                self.buf_key_column = self.buf_key_column.clear();
                self.buf_index_column = self.buf_index_column.clear();

                Ok(())
            }));
            return;
        };

        let mut recv = recv.serial();
        let send = send_ports[0].take().unwrap().parallel();

        // Channel from the (serial) distributor to the parallel workers:
        // each message is the data morsel, its key column and its windows.
        let (mut distributor, rxs) = distributor_channel::<(Morsel, Column, Vec<[IdxSize; 2]>)>(
            send.len(),
            *DEFAULT_DISTRIBUTOR_BUFFER_SIZE,
        );

        // Worker tasks.
        //
        // These evaluate the aggregations.
        join_handles.extend(rxs.into_iter().zip(send).map(|(mut rx, mut tx)| {
            let wg = WaitGroup::default();
            let aggs = self.aggs.clone();
            let state = state.in_memory_exec_state.split();
            scope.spawn_task(TaskPriority::High, async move {
                while let Ok((mut morsel, key, windows)) = rx.recv().await {
                    morsel = morsel
                        .async_try_map::<PolarsError, _, _>(async |df| {
                            Self::evaluate_one(windows, key, &aggs, &state, df).await
                        })
                        .await?;
                    morsel.set_consume_token(wg.token());

                    if tx.send(morsel).await.is_err() {
                        break;
                    }
                    // Backpressure: wait until the morsel is consumed
                    // downstream before evaluating the next one.
                    wg.wait().await;
                }

                Ok(())
            })
        }));

        // Distributor task.
        //
        // This finds boundaries to distribute to worker threads over.
        join_handles.push(scope.spawn_task(TaskPriority::High, async move {
            while let Ok(morsel) = recv.recv().await
                && self.slice_length > 0
            {
                let (df, seq, source_token, wait_token) = morsel.into_inner();
                self.seq = seq;
                drop(wait_token);

                if df.height() == 0 {
                    continue;
                }

                let morsel_index_column = df.column(&self.index_column)?;
                polars_ensure!(
                    morsel_index_column.null_count() == 0,
                    ComputeError: "null values in `rolling` not supported, fill nulls."
                );

                // Keep the original (uncasted) index values: they become the
                // key column of the output.
                self.buf_key_column.append(morsel_index_column)?;

                use DataType as DT;
                // Cast to the Datetime representation chosen in `new` (same
                // dtype -> (time unit, time zone) mapping).
                let morsel_index_column = match morsel_index_column.dtype() {
                    DT::Datetime(_, _) => morsel_index_column.clone(),
                    DT::Date => {
                        morsel_index_column.cast(&DT::Datetime(TimeUnit::Microseconds, None))?
                    },
                    DT::UInt32 | DT::UInt64 | DT::Int32 => morsel_index_column
                        .cast(&DT::Int64)?
                        .cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?,
                    DT::Int64 => {
                        morsel_index_column.cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?
                    },
                    // `new` already rejected every other dtype.
                    _ => unreachable!(),
                };
                self.buf_index_column.append(&morsel_index_column)?;
                self.buf_df.vstack_mut_owned(df)?;

                // Emit any windows that are now complete.
                if let Some((windows, df, key)) = self.next_windows(false)? {
                    if distributor
                        .send((Morsel::new(df, seq, source_token), key, windows))
                        .await
                        .is_err()
                    {
                        break;
                    }
                }
            }

            Ok(())
        }));
    }
}
365
366