// Source: pola-rs/polars, crates/polars-time/src/group_by/dynamic.rs
use arrow::legacy::time_zone::Tz;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::flatten::flatten_par;
use polars_ops::series::SeriesMethods;
use polars_utils::itertools::Itertools;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::slice::SortedSlice;
use rayon::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::prelude::*;

#[repr(transparent)]
struct Wrap<T>(pub T);

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct DynamicGroupOptions {
    /// Time or index column.
    pub index_column: PlSmallStr,
    /// Start a window at this interval.
    pub every: Duration,
    /// Window duration.
    pub period: Duration,
    /// Offset window boundaries.
    pub offset: Duration,
    /// Truncate the time column values to the window.
    pub label: Label,
    /// Add the boundaries to the DataFrame.
    pub include_boundaries: bool,
    pub closed_window: ClosedWindow,
    pub start_by: StartBy,
}

impl Default for DynamicGroupOptions {
    fn default() -> Self {
        Self {
            index_column: "".into(),
            every: Duration::new(1),
            period: Duration::new(1),
            offset: Duration::new(1),
            label: Label::Left,
            include_boundaries: false,
            closed_window: ClosedWindow::Left,
            start_by: Default::default(),
        }
    }
}
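
// A minimal configuration sketch (the column name and duration strings are illustrative,
// not taken from this file): hourly windows that are two hours wide, with boundary columns
// included. Remaining fields fall back to `Default`.
//
//     let opts = DynamicGroupOptions {
//         index_column: "time".into(),
//         every: Duration::parse("1h"),
//         period: Duration::parse("2h"),
//         include_boundaries: true,
//         ..Default::default()
//     };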

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingGroupOptions {
    /// Time or index column.
    pub index_column: PlSmallStr,
    /// Window duration.
    pub period: Duration,
    pub offset: Duration,
    pub closed_window: ClosedWindow,
}

impl Default for RollingGroupOptions {
    fn default() -> Self {
        Self {
            index_column: "".into(),
            period: Duration::new(1),
            offset: Duration::new(1),
            closed_window: ClosedWindow::Left,
        }
    }
}

fn check_sortedness_slice(v: &[i64]) -> PolarsResult<()> {
    polars_ensure!(v.is_sorted_ascending(), ComputeError: "input data is not sorted");
    Ok(())
}

pub const LB_NAME: &str = "_lower_boundary";
pub const UB_NAME: &str = "_upper_boundary";

pub trait PolarsTemporalGroupby {
    fn rolling(
        &self,
        group_by: Option<GroupsSlice>,
        options: &RollingGroupOptions,
    ) -> PolarsResult<(Column, GroupPositions)>;

    fn group_by_dynamic(
        &self,
        group_by: Option<GroupsSlice>,
        options: &DynamicGroupOptions,
    ) -> PolarsResult<(Column, Vec<Column>, GroupPositions)>;
}
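
// Usage sketch (the `df` and option values are hypothetical; the calls mirror the tests at
// the bottom of this file). Passing `None` for `group_by` requires the index column itself
// to be sorted ascending; with precomputed group slices, sortedness is checked per group
// when the groups are materialized.
//
//     let (time_key, groups) = df.rolling(
//         None,
//         &RollingGroupOptions {
//             index_column: "time".into(),
//             period: Duration::parse("2d"),
//             offset: Duration::parse("-2d"),
//             closed_window: ClosedWindow::Right,
//         },
//     )?;
//     let (time_key, keys, groups) = df.group_by_dynamic(None, &dynamic_opts)?;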

impl PolarsTemporalGroupby for DataFrame {
    fn rolling(
        &self,
        group_by: Option<GroupsSlice>,
        options: &RollingGroupOptions,
    ) -> PolarsResult<(Column, GroupPositions)> {
        Wrap(self).rolling(group_by, options)
    }

    fn group_by_dynamic(
        &self,
        group_by: Option<GroupsSlice>,
        options: &DynamicGroupOptions,
    ) -> PolarsResult<(Column, Vec<Column>, GroupPositions)> {
        Wrap(self).group_by_dynamic(group_by, options)
    }
}

impl Wrap<&DataFrame> {
    fn rolling(
        &self,
        group_by: Option<GroupsSlice>,
        options: &RollingGroupOptions,
    ) -> PolarsResult<(Column, GroupPositions)> {
        polars_ensure!(
            !options.period.is_zero() && !options.period.negative,
            ComputeError:
            "rolling window period should be strictly positive",
        );
        let time = self.0.column(&options.index_column)?.clone();
        if group_by.is_none() {
            // If 'group_by' is given, the column must be sorted within each group, which we
            // cannot check now; this is checked when the groups are materialized.
            time.as_materialized_series().ensure_sorted_arg("rolling")?;
        }
        let time_type = time.dtype();

        polars_ensure!(time.null_count() == 0, ComputeError: "null values in `rolling` not supported, fill nulls.");
        ensure_duration_matches_dtype(options.period, time_type, "period")?;
        ensure_duration_matches_dtype(options.offset, time_type, "offset")?;

        use DataType::*;
        let (dt, tu, tz): (Column, TimeUnit, Option<TimeZone>) = match time_type {
            Datetime(tu, tz) => (time.clone(), *tu, tz.clone()),
            Date => (
                time.cast(&Datetime(TimeUnit::Microseconds, None))?,
                TimeUnit::Microseconds,
                None,
            ),
            UInt32 | UInt64 | Int32 => {
                let time_type_dt = Datetime(TimeUnit::Nanoseconds, None);
                let dt = time.cast(&Int64).unwrap().cast(&time_type_dt).unwrap();
                let (out, gt) = self.impl_rolling(
                    dt,
                    group_by,
                    options,
                    TimeUnit::Nanoseconds,
                    None,
                    &time_type_dt,
                )?;
                let out = out.cast(&Int64).unwrap().cast(time_type).unwrap();
                return Ok((out, gt));
            },
            Int64 => {
                let time_type = Datetime(TimeUnit::Nanoseconds, None);
                let dt = time.cast(&time_type).unwrap();
                let (out, gt) = self.impl_rolling(
                    dt,
                    group_by,
                    options,
                    TimeUnit::Nanoseconds,
                    None,
                    &time_type,
                )?;
                let out = out.cast(&Int64).unwrap();
                return Ok((out, gt));
            },
            dt => polars_bail!(
                ComputeError:
                "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64, UInt32, UInt64 }}, got {}",
                dt
            ),
        };
        match tz {
            #[cfg(feature = "timezones")]
            Some(tz) => {
                self.impl_rolling(dt, group_by, options, tu, tz.parse::<Tz>().ok(), time_type)
            },
            _ => self.impl_rolling(dt, group_by, options, tu, None, time_type),
        }
    }
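
    // Note on the integer paths above: UInt32/UInt64/Int32/Int64 index columns are
    // reinterpreted as nanosecond timestamps (cast to Int64, then to Datetime(ns)), the
    // rolling windows are computed on that representation, and the returned time-key
    // column is cast back to the original integer dtype.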

    /// Returns: time_keys, keys, groups.
    fn group_by_dynamic(
        &self,
        group_by: Option<GroupsSlice>,
        options: &DynamicGroupOptions,
    ) -> PolarsResult<(Column, Vec<Column>, GroupPositions)> {
        let time = self.0.column(&options.index_column)?.rechunk();
        if group_by.is_none() {
            // If 'group_by' is given, the column must be sorted within each group, which we
            // cannot check now; this is checked when the groups are materialized.
            time.as_materialized_series()
                .ensure_sorted_arg("group_by_dynamic")?;
        }
        let time_type = time.dtype();

        polars_ensure!(time.null_count() == 0, ComputeError: "null values in dynamic group_by not supported, fill nulls.");
        ensure_duration_matches_dtype(options.every, time_type, "every")?;
        ensure_duration_matches_dtype(options.offset, time_type, "offset")?;
        ensure_duration_matches_dtype(options.period, time_type, "period")?;

        use DataType::*;
        let (dt, tu) = match time_type {
            Datetime(tu, _) => (time.clone(), *tu),
            Date => (
                time.cast(&Datetime(TimeUnit::Microseconds, None))?,
                TimeUnit::Microseconds,
            ),
            Int32 => {
                let time_type = Datetime(TimeUnit::Nanoseconds, None);
                let dt = time.cast(&Int64).unwrap().cast(&time_type).unwrap();
                let (out, mut keys, gt) = self.impl_group_by_dynamic(
                    dt,
                    group_by,
                    options,
                    TimeUnit::Nanoseconds,
                    &time_type,
                )?;
                let out = out.cast(&Int64).unwrap().cast(&Int32).unwrap();
                for k in &mut keys {
                    if k.name().as_str() == UB_NAME || k.name().as_str() == LB_NAME {
                        *k = k.cast(&Int64).unwrap().cast(&Int32).unwrap()
                    }
                }
                return Ok((out, keys, gt));
            },
            Int64 => {
                let time_type = Datetime(TimeUnit::Nanoseconds, None);
                let dt = time.cast(&time_type).unwrap();
                let (out, mut keys, gt) = self.impl_group_by_dynamic(
                    dt,
                    group_by,
                    options,
                    TimeUnit::Nanoseconds,
                    &time_type,
                )?;
                let out = out.cast(&Int64).unwrap();
                for k in &mut keys {
                    if k.name().as_str() == UB_NAME || k.name().as_str() == LB_NAME {
                        *k = k.cast(&Int64).unwrap()
                    }
                }
                return Ok((out, keys, gt));
            },
            dt => polars_bail!(
                ComputeError:
                "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64 }}, got {}",
                dt
            ),
        };
        self.impl_group_by_dynamic(dt, group_by, options, tu, time_type)
    }

    fn impl_group_by_dynamic(
        &self,
        mut dt: Column,
        group_by: Option<GroupsSlice>,
        options: &DynamicGroupOptions,
        tu: TimeUnit,
        time_type: &DataType,
    ) -> PolarsResult<(Column, Vec<Column>, GroupPositions)> {
        polars_ensure!(!options.every.negative, ComputeError: "'every' argument must be positive");
        if dt.is_empty() {
            return dt.cast(time_type).map(|s| (s, vec![], Default::default()));
        }

        // Sortedness is a requirement for the index, so set the flag here and downstream
        // code can rely on it.
        dt.set_sorted_flag(IsSorted::Ascending);

        let w = Window::new(options.every, options.period, options.offset);
        let dt = dt.datetime().unwrap();
        let tz = dt.time_zone();

        let mut lower_bound = None;
        let mut upper_bound = None;

        let mut include_lower_bound = false;
        let mut include_upper_bound = false;

        if options.include_boundaries {
            include_lower_bound = true;
            include_upper_bound = true;
        }
        if options.label == Label::Left {
            include_lower_bound = true;
        } else if options.label == Label::Right {
            include_upper_bound = true;
        }

        let mut update_bounds =
            |lower: Vec<i64>, upper: Vec<i64>| match (&mut lower_bound, &mut upper_bound) {
                (None, None) => {
                    lower_bound = Some(lower);
                    upper_bound = Some(upper);
                },
                (Some(lower_bound), Some(upper_bound)) => {
                    lower_bound.extend_from_slice(&lower);
                    upper_bound.extend_from_slice(&upper);
                },
                _ => unreachable!(),
            };

        let overlapping = match options.closed_window {
            ClosedWindow::Both => options.period >= options.every,
            _ => options.period > options.every,
        };

        let groups = if let Some(groups) = group_by.as_ref() {
            let vals = dt.physical().downcast_iter().next().unwrap();
            let ts = vals.values().as_slice();

            let iter = groups.par_iter().map(|[start, len]| {
                let group_offset = *start;
                let start = *start as usize;
                let end = start + *len as usize;
                let values = &ts[start..end];
                check_sortedness_slice(values)?;

                let (groups, lower, upper) = group_by_windows(
                    w,
                    values,
                    options.closed_window,
                    tu,
                    tz,
                    include_lower_bound,
                    include_upper_bound,
                    options.start_by,
                )?;

                PolarsResult::Ok((
                    groups
                        .iter()
                        .map(|[start, len]| [*start + group_offset, *len])
                        .collect_vec(),
                    lower,
                    upper,
                ))
            });

            let res = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
            let groups = res.iter().map(|g| &g.0).collect_vec();
            let lower = res.iter().map(|g| &g.1).collect_vec();
            let upper = res.iter().map(|g| &g.2).collect_vec();

            let ((groups, upper), lower) = POOL.install(|| {
                rayon::join(
                    || rayon::join(|| flatten_par(&groups), || flatten_par(&upper)),
                    || flatten_par(&lower),
                )
            });

            update_bounds(lower, upper);
            PolarsResult::Ok(GroupsType::new_slice(groups, overlapping, true))
        } else {
            let vals = dt.physical().downcast_iter().next().unwrap();
            let ts = vals.values().as_slice();
            let (groups, lower, upper) = group_by_windows(
                w,
                ts,
                options.closed_window,
                tu,
                tz,
                include_lower_bound,
                include_upper_bound,
                options.start_by,
            )?;
            update_bounds(lower, upper);
            PolarsResult::Ok(GroupsType::new_slice(groups, overlapping, true))
        }?;
        // Note that if 'group_by' is None, the index column and the lower and upper boundary
        // columns remain sorted.

        let dt = unsafe { dt.clone().into_series().agg_first(&groups) };
        let mut dt = dt.datetime().unwrap().physical().clone();

        let lower =
            lower_bound.map(|lower| Int64Chunked::new_vec(PlSmallStr::from_static(LB_NAME), lower));
        let upper =
            upper_bound.map(|upper| Int64Chunked::new_vec(PlSmallStr::from_static(UB_NAME), upper));

        if options.label == Label::Left {
            let mut lower = lower.clone().unwrap();
            if group_by.is_none() {
                lower.set_sorted_flag(IsSorted::Ascending)
            }
            dt = lower.with_name(dt.name().clone());
        } else if options.label == Label::Right {
            let mut upper = upper.clone().unwrap();
            if group_by.is_none() {
                upper.set_sorted_flag(IsSorted::Ascending)
            }
            dt = upper.with_name(dt.name().clone());
        }

        let mut bounds = vec![];
        if let (true, Some(mut lower), Some(mut upper)) = (options.include_boundaries, lower, upper)
        {
            if group_by.is_none() {
                lower.set_sorted_flag(IsSorted::Ascending);
                upper.set_sorted_flag(IsSorted::Ascending);
            }
            bounds.push(lower.into_datetime(tu, tz.clone()).into_column());
            bounds.push(upper.into_datetime(tu, tz.clone()).into_column());
        }

        dt.into_datetime(tu, None)
            .into_column()
            .cast(time_type)
            .map(|s| (s, bounds, groups.into_sliceable()))
    }
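
    // Worked sketch of the windowing above (illustrative values, not from this file): with
    // `every = "1h"`, `period = "2h"`, and a non-`Both` closed window, window starts are one
    // hour apart and each window spans two hours, so consecutive windows overlap and
    // `overlapping` is true (period > every). With `label = Label::Left` the time key of each
    // group is the window's lower boundary; with `include_boundaries` the `_lower_boundary`
    // and `_upper_boundary` columns are returned as extra keys.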

    /// Returns: time_keys and groups.
    fn impl_rolling(
        &self,
        dt: Column,
        group_by: Option<GroupsSlice>,
        options: &RollingGroupOptions,
        tu: TimeUnit,
        tz: Option<Tz>,
        time_type: &DataType,
    ) -> PolarsResult<(Column, GroupPositions)> {
        let mut dt = dt.rechunk();

        let groups = if let Some(groups) = group_by {
            let dt = dt.datetime().unwrap();
            let vals = dt.physical().downcast_iter().next().unwrap();
            let ts = vals.values().as_slice();

            let iter = groups.into_par_iter().map(|[start, len]| {
                let group_offset = start;
                let start = start as usize;
                let end = start + len as usize;
                let values = &ts[start..end];
                check_sortedness_slice(values)?;

                let group = group_by_values(
                    options.period,
                    options.offset,
                    values,
                    options.closed_window,
                    tu,
                    tz,
                )?;

                PolarsResult::Ok(
                    group
                        .iter()
                        .map(|[start, len]| [*start + group_offset, *len])
                        .collect_vec(),
                )
            });

            let groups = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
            let groups = POOL.install(|| flatten_par(&groups));
            PolarsResult::Ok(GroupsType::new_slice(groups, true, true))
        } else {
            // Sortedness is a requirement for the index, so set the flag here and
            // downstream code can rely on it.
            dt.set_sorted_flag(IsSorted::Ascending);
            let dt = dt.datetime().unwrap();
            let vals = dt.physical().downcast_iter().next().unwrap();
            let ts = vals.values().as_slice();
            let groups = group_by_values(
                options.period,
                options.offset,
                ts,
                options.closed_window,
                tu,
                tz,
            )?;
            PolarsResult::Ok(GroupsType::new_slice(groups, true, true))
        }?;

        let dt = dt.cast(time_type).unwrap();

        Ok((dt, groups.into_sliceable()))
    }
}

#[cfg(test)]
mod test {
    use polars_compute::rolling::QuantileMethod;
    use polars_ops::prelude::*;

    use super::*;

    #[test]
    fn test_rolling_group_by_tu() -> PolarsResult<()> {
        // Test multiple time units.
        for tu in [
            TimeUnit::Nanoseconds,
            TimeUnit::Microseconds,
            TimeUnit::Milliseconds,
        ] {
            let mut date = StringChunked::new(
                "dt".into(),
                [
                    "2020-01-01 13:45:48",
                    "2020-01-01 16:42:13",
                    "2020-01-01 16:45:09",
                    "2020-01-02 18:12:48",
                    "2020-01-03 19:45:32",
                    "2020-01-08 23:16:43",
                ],
            )
            .as_datetime(
                None,
                tu,
                false,
                false,
                None,
                &StringChunked::from_iter(std::iter::once("raise")),
            )?
            .into_column();
            date.set_sorted_flag(IsSorted::Ascending);
            let a = Column::new("a".into(), [3, 7, 5, 9, 2, 1]);
            let df = DataFrame::new_infer_height(vec![date, a.clone()])?;

            let (_, groups) = df
                .rolling(
                    None,
                    &RollingGroupOptions {
                        index_column: "dt".into(),
                        period: Duration::parse("2d"),
                        offset: Duration::parse("-2d"),
                        closed_window: ClosedWindow::Right,
                    },
                )
                .unwrap();

            let sum = unsafe { a.agg_sum(&groups) };
            let expected = Column::new("".into(), [3, 10, 15, 24, 11, 1]);
            assert_eq!(sum, expected);
        }

        Ok(())
    }

    #[test]
    fn test_rolling_group_by_aggs() -> PolarsResult<()> {
        let mut date = StringChunked::new(
            "dt".into(),
            [
                "2020-01-01 13:45:48",
                "2020-01-01 16:42:13",
                "2020-01-01 16:45:09",
                "2020-01-02 18:12:48",
                "2020-01-03 19:45:32",
                "2020-01-08 23:16:43",
            ],
        )
        .as_datetime(
            None,
            TimeUnit::Milliseconds,
            false,
            false,
            None,
            &StringChunked::from_iter(std::iter::once("raise")),
        )?
        .into_column();
        date.set_sorted_flag(IsSorted::Ascending);

        let a = Column::new("a".into(), [3, 7, 5, 9, 2, 1]);
        let df = DataFrame::new_infer_height(vec![date, a.clone()])?;

        let (_, groups) = df
            .rolling(
                None,
                &RollingGroupOptions {
                    index_column: "dt".into(),
                    period: Duration::parse("2d"),
                    offset: Duration::parse("-2d"),
                    closed_window: ClosedWindow::Right,
                },
            )
            .unwrap();

        let nulls = Series::new(
            "".into(),
            [Some(3), Some(7), None, Some(9), Some(2), Some(1)],
        );

        let min = unsafe { a.as_materialized_series().agg_min(&groups) };
        let expected = Series::new("".into(), [3, 3, 3, 3, 2, 1]);
        assert_eq!(min, expected);

        // Expected for nulls is equality.
        let min = unsafe { nulls.agg_min(&groups) };
        assert_eq!(min, expected);

        let max = unsafe { a.as_materialized_series().agg_max(&groups) };
        let expected = Series::new("".into(), [3, 7, 7, 9, 9, 1]);
        assert_eq!(max, expected);

        let max = unsafe { nulls.agg_max(&groups) };
        assert_eq!(max, expected);

        let var = unsafe { a.as_materialized_series().agg_var(&groups, 1) };
        let expected = Series::new(
            "".into(),
            [0.0, 8.0, 4.000000000000002, 6.666666666666667, 24.5, 0.0],
        );
        assert!(abs(&(var - expected)?).unwrap().lt(1e-12).unwrap().all());

        let var = unsafe { nulls.agg_var(&groups, 1) };
        let expected = Series::new("".into(), [0.0, 8.0, 8.0, 9.333333333333343, 24.5, 0.0]);
        assert!(abs(&(var - expected)?).unwrap().lt(1e-12).unwrap().all());

        let quantile = unsafe {
            a.as_materialized_series()
                .agg_quantile(&groups, 0.5, QuantileMethod::Linear)
        };
        let expected = Series::new("".into(), [3.0, 5.0, 5.0, 6.0, 5.5, 1.0]);
        assert_eq!(quantile, expected);

        let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileMethod::Linear) };
        let expected = Series::new("".into(), [3.0, 5.0, 5.0, 7.0, 5.5, 1.0]);
        assert_eq!(quantile, expected);

        Ok(())
    }
}