GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/dynamic_group_by.rs
use std::sync::Arc;

use arrow::legacy::time_zone::Tz;
use polars_core::frame::DataFrame;
use polars_core::prelude::{Column, DataType, GroupsType, Int64Chunked, IntoColumn, TimeUnit};
use polars_core::schema::Schema;
use polars_core::series::IsSorted;
use polars_error::{PolarsError, PolarsResult, polars_bail, polars_ensure};
use polars_expr::state::ExecutionState;
use polars_time::prelude::{GroupByDynamicWindower, Label, ensure_duration_matches_dtype};
use polars_time::{DynamicGroupOptions, LB_NAME, UB_NAME};
use polars_utils::IdxSize;
use polars_utils::pl_str::PlSmallStr;

use super::ComputeNode;
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
use crate::async_primitives::distributor_channel::distributor_channel;
use crate::async_primitives::wait_group::WaitGroup;
use crate::execute::StreamingExecutionState;
use crate::expression::StreamExpr;
use crate::graph::PortState;
use crate::morsel::{Morsel, MorselSeq, SourceToken};
use crate::pipe::{RecvPort, SendPort};

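// Returned by `next_windows`: the window `[start, len]` pairs (relative to the returned data
// slice), the physical lower and upper bounds per window, and the slice of buffered data that
// covers those windows.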
type NextWindows = (Vec<[IdxSize; 2]>, Vec<i64>, Vec<i64>, DataFrame);

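/// Streaming node for `group_by_dynamic`: morsels are buffered until the windower can emit
/// windows over the index column, after which the aggregations are evaluated per window.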
pub struct DynamicGroupBy {
    buf_df: DataFrame,
    /// How many `buf_df` rows have we already discarded?
    buf_df_offset: IdxSize,
    buf_index_column: Column,

    seq: MorselSeq,

    slice_offset: IdxSize,
    slice_length: IdxSize,

    group_by: Option<PlSmallStr>,
    index_column: PlSmallStr,
    index_column_idx: usize,
    label: Label,
    include_boundaries: bool,
    windower: GroupByDynamicWindower,
    aggs: Arc<[(PlSmallStr, StreamExpr)]>,
}

impl DynamicGroupBy {
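    /// Validate `options` against the `schema` and set up the windower and the empty buffers.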
    pub fn new(
        schema: Arc<Schema>,
        options: DynamicGroupOptions,
        aggs: Arc<[(PlSmallStr, StreamExpr)]>,
        slice: Option<(IdxSize, IdxSize)>,
    ) -> PolarsResult<Self> {
        let DynamicGroupOptions {
            index_column,
            every,
            period,
            offset,
            label,
            include_boundaries,
            closed_window,
            start_by,
        } = options;

        polars_ensure!(!every.negative(), ComputeError: "'every' argument must be positive");

        let (index_column_idx, _, index_dtype) = schema.get_full(&index_column).unwrap();
        ensure_duration_matches_dtype(every, index_dtype, "every")?;
        ensure_duration_matches_dtype(period, index_dtype, "period")?;
        ensure_duration_matches_dtype(offset, index_dtype, "offset")?;

        use DataType as DT;
        let (tu, tz) = match index_dtype {
            DT::Datetime(tu, tz) => (*tu, tz.clone()),
            DT::Date => (TimeUnit::Microseconds, None),
            DT::Int64 | DT::Int32 => (TimeUnit::Nanoseconds, None),
            dt => polars_bail!(
                ComputeError:
                "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64 }}, got {}",
                dt
            ),
        };

        let buf_df = DataFrame::empty_with_arc_schema(schema.clone());
        let buf_index_column =
            Column::new_empty(index_column.clone(), &DT::Datetime(tu, tz.clone()));

        // @NOTE: This is a bit strange since it ignores errors, but it mirrors the in-memory
        // engine.
        let tz = tz.and_then(|tz| tz.parse::<Tz>().ok());
        let windower = GroupByDynamicWindower::new(
            period,
            offset,
            every,
            start_by,
            closed_window,
            tu,
            tz,
            include_boundaries || matches!(label, Label::Left),
            include_boundaries || matches!(label, Label::Right),
        );

        let (slice_offset, slice_length) = slice.unwrap_or((0, IdxSize::MAX));

        Ok(Self {
            buf_df,

            buf_df_offset: 0,
            buf_index_column,
            seq: MorselSeq::default(),

            slice_offset,
            slice_length,

            group_by: None,
            index_column,
            index_column_idx,
            label,
            include_boundaries,
            windower,
            aggs,
        })
    }

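    /// Evaluate all aggregations for the given windows over `df`, producing one output row per
    /// window with the key column and, optionally, the lower/upper boundary columns.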
    #[expect(clippy::too_many_arguments)]
    async fn evaluate_one(
        windows: Vec<[IdxSize; 2]>,
        lower_bound: Vec<i64>,
        upper_bound: Vec<i64>,
        aggs: &[(PlSmallStr, StreamExpr)],
        state: &ExecutionState,
        mut df: DataFrame,

        group_by: Option<&str>,
        index_column_name: &str,
        index_column_idx: usize,
        label: Label,
        include_boundaries: bool,
    ) -> PolarsResult<DataFrame> {
        let height = windows.len();
        let groups = GroupsType::new_slice(windows, true, true).into_sliceable();

        // @NOTE:
        // Rechunk so we can use specialized rolling/dynamic kernels.
        df.rechunk_mut();

        let mut columns =
            Vec::with_capacity(if include_boundaries { 2 } else { 0 } + 1 + aggs.len());

        // Construct `lower_bound`, `upper_bound` and `key` columns that might be included in the
        // output dataframe.
        {
            let mut lower = Int64Chunked::new_vec(PlSmallStr::from_static(LB_NAME), lower_bound);
            let mut upper = Int64Chunked::new_vec(PlSmallStr::from_static(UB_NAME), upper_bound);
            if group_by.is_none() {
                lower.set_sorted_flag(IsSorted::Ascending);
                upper.set_sorted_flag(IsSorted::Ascending);
            }
            let mut lower = lower.into_column();
            let mut upper = upper.into_column();

            let index_column = &df.get_columns()[index_column_idx];
            let index_dtype = index_column.dtype();
            let mut bound_dtype_physical = index_dtype.to_physical();
            let mut bound_dtype = index_dtype;
            if index_dtype.is_date() {
                bound_dtype = &DataType::Datetime(TimeUnit::Microseconds, None);
                bound_dtype_physical = DataType::Int64;
            }
            lower = lower.cast(&bound_dtype_physical).unwrap();
            upper = upper.cast(&bound_dtype_physical).unwrap();
            (lower, upper) = unsafe {
                (
                    lower.from_physical_unchecked(bound_dtype)?,
                    upper.from_physical_unchecked(bound_dtype)?,
                )
            };

            let key = match label {
                Label::DataPoint => unsafe { index_column.agg_first(&groups) },
                Label::Left => lower
                    .cast(index_dtype)
                    .unwrap()
                    .with_name(index_column_name.into()),
                Label::Right => upper
                    .cast(index_dtype)
                    .unwrap()
                    .with_name(index_column_name.into()),
            };

            if include_boundaries {
                columns.extend([lower, upper]);
            }
            columns.push(key);
        }

        for (name, agg) in aggs.iter() {
            let mut agg = agg.evaluate_on_groups(&df, &groups, state).await?;
            let agg = agg.finalize();
            columns.push(agg.with_name(name.clone()));
        }

        Ok(unsafe { DataFrame::new_no_checks(height, columns) })
    }

    /// Progress the state and get the next available evaluation windows, data and key.
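    ///
    /// With `finalize` set, all remaining windows are flushed and the whole buffer is retired;
    /// otherwise only the windows the windower can already emit for the data seen so far are
    /// returned.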
    fn next_windows(&mut self, finalize: bool) -> PolarsResult<Option<NextWindows>> {
        let mut windows = Vec::new();
        let mut lower_bound = Vec::new();
        let mut upper_bound = Vec::new();

        let num_retired = if finalize {
            self.windower
                .finalize(&mut windows, &mut lower_bound, &mut upper_bound);
            self.buf_df.height() as IdxSize
        } else {
            let mut offset = self.windower.num_seen() - self.buf_df_offset;
            let ca = self.buf_index_column.datetime()?;
            for arr in ca.physical().downcast_iter() {
                let arr_len = arr.len() as IdxSize;
                if offset >= arr_len {
                    offset -= arr_len;
                    continue;
                }

                self.windower.insert(
                    &arr.values().as_slice()[offset as usize..],
                    &mut windows,
                    &mut lower_bound,
                    &mut upper_bound,
                )?;
                offset = offset.saturating_sub(arr_len);
            }
            self.windower.lowest_needed_index() - self.buf_df_offset
        };

        if windows.is_empty() {
            if num_retired > 0 {
                self.buf_df = self.buf_df.slice(num_retired as i64, usize::MAX);
                self.buf_index_column = self.buf_index_column.slice(num_retired as i64, usize::MAX);
                self.buf_df_offset += num_retired;
            }

            return Ok(None);
        }

        // Prune the data that is not covered by the windows and update the windows accordingly.
        let offset = windows[0][0];
        let end = windows.last().unwrap();
        let end = end[0] + end[1];

        if self.slice_offset as usize > windows.len() {
            self.slice_offset -= windows.len() as IdxSize;
            windows.clear();
            lower_bound.clear();
            upper_bound.clear();
        } else if self.slice_offset > 0 {
            let offset = self.slice_offset as usize;
            self.slice_offset = self.slice_offset.saturating_sub(windows.len() as IdxSize);
            windows.drain(..offset);
            lower_bound.drain(..offset.min(lower_bound.len()));
            upper_bound.drain(..offset.min(upper_bound.len()));
        }

        let trunc_length = windows.len().min(self.slice_length as usize);
        windows.truncate(trunc_length);
        lower_bound.truncate(trunc_length);
        upper_bound.truncate(trunc_length);
        self.slice_length -= windows.len() as IdxSize;

        windows.iter_mut().for_each(|[s, _]| *s -= offset);
        let data = self.buf_df.slice(
            (offset - self.buf_df_offset) as i64,
            (end - self.buf_df_offset) as usize,
        );

        self.buf_df = self.buf_df.slice(num_retired as i64, usize::MAX);
        self.buf_index_column = self.buf_index_column.slice(num_retired as i64, usize::MAX);
        self.buf_df_offset += num_retired;

        if windows.is_empty() {
            return Ok(None);
        }

        Ok(Some((windows, lower_bound, upper_bound, data)))
    }
}

impl ComputeNode for DynamicGroupBy {
    fn name(&self) -> &str {
        "dynamic-group-by"
    }

    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        _state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 1 && send.len() == 1);

        if self.slice_length == 0 {
            recv[0] = PortState::Done;
            send[0] = PortState::Done;
            std::mem::take(&mut self.buf_df);
            return Ok(());
        }

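        // Propagate port states. If downstream is done we can stop receiving; if upstream is
        // done we stay alive until the buffered rows have been flushed.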
        if send[0] == PortState::Done {
            recv[0] = PortState::Done;
            std::mem::take(&mut self.buf_df);
        } else if recv[0] == PortState::Done {
            if self.buf_df.is_empty() {
                send[0] = PortState::Done;
            } else {
                send[0] = PortState::Ready;
            }
        } else {
            recv.swap_with_slice(send);
        }

        Ok(())
    }

    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv_ports.len() == 1 && send_ports.len() == 1);

        let Some(recv) = recv_ports[0].take() else {
            // We no longer have to receive data. Finalize and send all remaining data.
            assert!(!self.buf_df.is_empty());
            assert!(self.slice_length > 0);
            let mut send = send_ports[0].take().unwrap().serial();
            join_handles.push(scope.spawn_task(TaskPriority::High, async move {
                if let Some((windows, lower_bound, upper_bound, df)) = self.next_windows(true)? {
                    let df = Self::evaluate_one(
                        windows,
                        lower_bound,
                        upper_bound,
                        &self.aggs,
                        &state.in_memory_exec_state,
                        df,
                        self.group_by.as_deref(),
                        self.index_column.as_str(),
                        self.index_column_idx,
                        self.label,
                        self.include_boundaries,
                    )
                    .await?;

                    _ = send
                        .send(Morsel::new(df, self.seq.successor(), SourceToken::new()))
                        .await;
                }

                self.buf_df = self.buf_df.clear();
                Ok(())
            }));
            return;
        };

        let mut recv = recv.serial();
        let send = send_ports[0].take().unwrap().parallel();

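        // One distributor receiver per parallel output pipe, so the window evaluation can run on
        // all worker tasks concurrently.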
        let (mut distributor, rxs) =
            distributor_channel::<(Morsel, Vec<[IdxSize; 2]>, Vec<i64>, Vec<i64>)>(
                send.len(),
                *DEFAULT_DISTRIBUTOR_BUFFER_SIZE,
            );

        // Worker tasks.
        //
        // These evaluate the aggregations.
        join_handles.extend(rxs.into_iter().zip(send).map(|(mut rx, mut tx)| {
            let wg = WaitGroup::default();
            let aggs = self.aggs.clone();
            let state = state.in_memory_exec_state.split();

            let group_by = self.group_by.clone();
            let index_column = self.index_column.clone();
            let index_column_idx = self.index_column_idx;
            let label = self.label;
            let include_boundaries = self.include_boundaries;

            scope.spawn_task(TaskPriority::High, async move {
                while let Ok((mut morsel, windows, lower_bound, upper_bound)) = rx.recv().await {
                    morsel = morsel
                        .async_try_map::<PolarsError, _, _>(async |df| {
                            Self::evaluate_one(
                                windows,
                                lower_bound,
                                upper_bound,
                                &aggs,
                                &state,
                                df,
                                group_by.as_deref(),
                                index_column.as_str(),
                                index_column_idx,
                                label,
                                include_boundaries,
                            )
                            .await
                        })
                        .await?;
                    morsel.set_consume_token(wg.token());

                    if tx.send(morsel).await.is_err() {
                        break;
                    }
                    wg.wait().await;
                }

                Ok(())
            })
        }));

        // Distributor task.
        //
        // This finds the window boundaries and distributes the work over the worker tasks.
        join_handles.push(scope.spawn_task(TaskPriority::High, async move {
            while let Ok(morsel) = recv.recv().await
                && self.slice_length > 0
            {
                let (df, seq, source_token, wait_token) = morsel.into_inner();
                self.seq = seq;
                drop(wait_token);

                if df.height() == 0 {
                    continue;
                }

                let morsel_index_column = df.column(&self.index_column)?;
                polars_ensure!(
                    morsel_index_column.null_count() == 0,
                    ComputeError: "null values in `group_by_dynamic` not supported, fill nulls."
                );

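                // Normalize the index column to a physical `Datetime` buffer so the windower can
                // consume raw `i64` values, matching the time unit chosen in `new`.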
                use DataType as DT;
                let morsel_index_column = match morsel_index_column.dtype() {
                    DT::Datetime(_, _) => morsel_index_column.clone(),
                    DT::Date => {
                        morsel_index_column.cast(&DT::Datetime(TimeUnit::Microseconds, None))?
                    },
                    DT::Int32 => morsel_index_column
                        .cast(&DT::Int64)?
                        .cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?,
                    DT::Int64 => {
                        morsel_index_column.cast(&DT::Datetime(TimeUnit::Nanoseconds, None))?
                    },
                    _ => unreachable!(),
                };

                self.buf_df.vstack_mut_owned(df)?;
                self.buf_index_column.append_owned(morsel_index_column)?;

                if let Some((windows, lower_bound, upper_bound, df)) = self.next_windows(false)? {
                    if distributor
                        .send((
                            Morsel::new(df, seq, source_token),
                            windows,
                            lower_bound,
                            upper_bound,
                        ))
                        .await
                        .is_err()
                    {
                        break;
                    }
                }
            }

            Ok(())
        }));
    }
}