GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-time/src/group_by/dynamic.rs

use arrow::legacy::time_zone::Tz;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::flatten::flatten_par;
use polars_ops::series::SeriesMethods;
use polars_utils::itertools::Itertools;
use polars_utils::pl_str::PlSmallStr;
use polars_utils::slice::SortedSlice;
use rayon::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::prelude::*;

#[repr(transparent)]
struct Wrap<T>(pub T);

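/// Options for a dynamic group-by: windows start `every` interval apart, span `period`,
/// and are shifted by `offset`. With `period == every` the windows tile the index
/// (tumbling windows); with `period > every` consecutive windows overlap. `label`
/// selects which timestamp represents each window in the output index column.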
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct DynamicGroupOptions {
    /// Time or index column.
    pub index_column: PlSmallStr,
    /// Start a window at this interval.
    pub every: Duration,
    /// Window duration.
    pub period: Duration,
    /// Offset window boundaries.
    pub offset: Duration,
    /// Which boundary (left, right, or the original datapoint) labels the window in the output.
    pub label: Label,
    /// Add the boundaries to the DataFrame.
    pub include_boundaries: bool,
    /// Which side(s) of each window are inclusive.
    pub closed_window: ClosedWindow,
    /// How the start of the first window is determined.
    pub start_by: StartBy,
}

impl Default for DynamicGroupOptions {
    fn default() -> Self {
        Self {
            index_column: "".into(),
            every: Duration::new(1),
            period: Duration::new(1),
            offset: Duration::new(1),
            label: Label::Left,
            include_boundaries: false,
            closed_window: ClosedWindow::Left,
            start_by: Default::default(),
        }
    }
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingGroupOptions {
    /// Time or index column.
    pub index_column: PlSmallStr,
    /// Window duration.
    pub period: Duration,
    pub offset: Duration,
    pub closed_window: ClosedWindow,
}

impl Default for RollingGroupOptions {
    fn default() -> Self {
        Self {
            index_column: "".into(),
            period: Duration::new(1),
            offset: Duration::new(1),
            closed_window: ClosedWindow::Left,
        }
    }
}

fn check_sortedness_slice(v: &[i64]) -> PolarsResult<()> {
    polars_ensure!(v.is_sorted_ascending(), ComputeError: "input data is not sorted");
    Ok(())
}

const LB_NAME: &str = "_lower_boundary";
const UP_NAME: &str = "_upper_boundary";

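/// Temporal/windowed group-by entry points implemented for `DataFrame`.
///
/// A minimal usage sketch (mirroring the tests at the bottom of this file; the column
/// name and option values are illustrative only):
///
/// ```ignore
/// let (time_key, groups) = df.rolling(
///     None,
///     &RollingGroupOptions {
///         index_column: "dt".into(),
///         period: Duration::parse("2d"),
///         offset: Duration::parse("-2d"),
///         closed_window: ClosedWindow::Right,
///     },
/// )?;
/// // `groups` can then drive the grouped aggregation kernels, e.g. `agg_sum`.
/// ```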
pub trait PolarsTemporalGroupby {
    fn rolling(
        &self,
        group_by: Option<GroupsSlice>,
        options: &RollingGroupOptions,
    ) -> PolarsResult<(Column, GroupPositions)>;

    fn group_by_dynamic(
        &self,
        group_by: Option<GroupsSlice>,
        options: &DynamicGroupOptions,
    ) -> PolarsResult<(Column, Vec<Column>, GroupPositions)>;
}

impl PolarsTemporalGroupby for DataFrame {
    fn rolling(
        &self,
        group_by: Option<GroupsSlice>,
        options: &RollingGroupOptions,
    ) -> PolarsResult<(Column, GroupPositions)> {
        Wrap(self).rolling(group_by, options)
    }

    fn group_by_dynamic(
        &self,
        group_by: Option<GroupsSlice>,
        options: &DynamicGroupOptions,
    ) -> PolarsResult<(Column, Vec<Column>, GroupPositions)> {
        Wrap(self).group_by_dynamic(group_by, options)
    }
}

impl Wrap<&DataFrame> {
    fn rolling(
        &self,
        group_by: Option<GroupsSlice>,
        options: &RollingGroupOptions,
    ) -> PolarsResult<(Column, GroupPositions)> {
        polars_ensure!(
            !options.period.is_zero() && !options.period.negative,
            ComputeError:
            "rolling window period should be strictly positive",
        );
        let time = self.0.column(&options.index_column)?.clone();
        if group_by.is_none() {
            // If `group_by` is given, the column must be sorted within each group, which we
            // cannot check now; this will be checked when the groups are materialized.
            time.as_materialized_series().ensure_sorted_arg("rolling")?;
        }
        let time_type = time.dtype();

        polars_ensure!(time.null_count() == 0, ComputeError: "null values in `rolling` not supported, fill nulls.");
        ensure_duration_matches_dtype(options.period, time_type, "period")?;
        ensure_duration_matches_dtype(options.offset, time_type, "offset")?;

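        // Normalize the index dtype: Date and integer columns are temporarily viewed as
        // Datetime (us/ns) so the same windowing kernels apply, and the result is cast
        // back to the original dtype before returning.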
        use DataType::*;
        let (dt, tu, tz): (Column, TimeUnit, Option<TimeZone>) = match time_type {
            Datetime(tu, tz) => (time.clone(), *tu, tz.clone()),
            Date => (
                time.cast(&Datetime(TimeUnit::Microseconds, None))?,
                TimeUnit::Microseconds,
                None,
            ),
            UInt32 | UInt64 | Int32 => {
                let time_type_dt = Datetime(TimeUnit::Nanoseconds, None);
                let dt = time.cast(&Int64).unwrap().cast(&time_type_dt).unwrap();
                let (out, gt) = self.impl_rolling(
                    dt,
                    group_by,
                    options,
                    TimeUnit::Nanoseconds,
                    None,
                    &time_type_dt,
                )?;
                let out = out.cast(&Int64).unwrap().cast(time_type).unwrap();
                return Ok((out, gt));
            },
            Int64 => {
                let time_type = Datetime(TimeUnit::Nanoseconds, None);
                let dt = time.cast(&time_type).unwrap();
                let (out, gt) = self.impl_rolling(
                    dt,
                    group_by,
                    options,
                    TimeUnit::Nanoseconds,
                    None,
                    &time_type,
                )?;
                let out = out.cast(&Int64).unwrap();
                return Ok((out, gt));
            },
            dt => polars_bail!(
                ComputeError:
                "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64, UInt32, UInt64 }}, got {}",
                dt
            ),
        };
        match tz {
            #[cfg(feature = "timezones")]
            Some(tz) => {
                self.impl_rolling(dt, group_by, options, tu, tz.parse::<Tz>().ok(), time_type)
            },
            _ => self.impl_rolling(dt, group_by, options, tu, None, time_type),
        }
    }

    /// Returns: time keys, keys, group positions.
    fn group_by_dynamic(
        &self,
        group_by: Option<GroupsSlice>,
        options: &DynamicGroupOptions,
    ) -> PolarsResult<(Column, Vec<Column>, GroupPositions)> {
        let time = self.0.column(&options.index_column)?.rechunk();
        if group_by.is_none() {
            // If `group_by` is given, the column must be sorted within each group, which we
            // cannot check now; this will be checked when the groups are materialized.
            time.as_materialized_series()
                .ensure_sorted_arg("group_by_dynamic")?;
        }
        let time_type = time.dtype();

        polars_ensure!(time.null_count() == 0, ComputeError: "null values in dynamic group_by not supported, fill nulls.");
        ensure_duration_matches_dtype(options.every, time_type, "every")?;
        ensure_duration_matches_dtype(options.offset, time_type, "offset")?;
        ensure_duration_matches_dtype(options.period, time_type, "period")?;

        use DataType::*;
        let (dt, tu) = match time_type {
            Datetime(tu, _) => (time.clone(), *tu),
            Date => (
                time.cast(&Datetime(TimeUnit::Microseconds, None))?,
                TimeUnit::Microseconds,
            ),
            Int32 => {
                let time_type = Datetime(TimeUnit::Nanoseconds, None);
                let dt = time.cast(&Int64).unwrap().cast(&time_type).unwrap();
                let (out, mut keys, gt) = self.impl_group_by_dynamic(
                    dt,
                    group_by,
                    options,
                    TimeUnit::Nanoseconds,
                    &time_type,
                )?;
                let out = out.cast(&Int64).unwrap().cast(&Int32).unwrap();
                for k in &mut keys {
                    if k.name().as_str() == UP_NAME || k.name().as_str() == LB_NAME {
                        *k = k.cast(&Int64).unwrap().cast(&Int32).unwrap()
                    }
                }
                return Ok((out, keys, gt));
            },
            Int64 => {
                let time_type = Datetime(TimeUnit::Nanoseconds, None);
                let dt = time.cast(&time_type).unwrap();
                let (out, mut keys, gt) = self.impl_group_by_dynamic(
                    dt,
                    group_by,
                    options,
                    TimeUnit::Nanoseconds,
                    &time_type,
                )?;
                let out = out.cast(&Int64).unwrap();
                for k in &mut keys {
                    if k.name().as_str() == UP_NAME || k.name().as_str() == LB_NAME {
                        *k = k.cast(&Int64).unwrap()
                    }
                }
                return Ok((out, keys, gt));
            },
            dt => polars_bail!(
                ComputeError:
                "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64 }}, got {}",
                dt
            ),
        };
        self.impl_group_by_dynamic(dt, group_by, options, tu, time_type)
    }

    fn impl_group_by_dynamic(
        &self,
        mut dt: Column,
        group_by: Option<GroupsSlice>,
        options: &DynamicGroupOptions,
        tu: TimeUnit,
        time_type: &DataType,
    ) -> PolarsResult<(Column, Vec<Column>, GroupPositions)> {
        polars_ensure!(!options.every.negative, ComputeError: "'every' argument must be positive");
        if dt.is_empty() {
            return dt.cast(time_type).map(|s| (s, vec![], Default::default()));
        }

        // Sortedness is a requirement for the index; set the flag so downstream code has this info.
        dt.set_sorted_flag(IsSorted::Ascending);

        let w = Window::new(options.every, options.period, options.offset);
        let dt = dt.datetime().unwrap();
        let tz = dt.time_zone();

        let mut lower_bound = None;
        let mut upper_bound = None;

        let mut include_lower_bound = false;
        let mut include_upper_bound = false;

        if options.include_boundaries {
            include_lower_bound = true;
            include_upper_bound = true;
        }
        if options.label == Label::Left {
            include_lower_bound = true;
        } else if options.label == Label::Right {
            include_upper_bound = true;
        }

        let mut update_bounds =
            |lower: Vec<i64>, upper: Vec<i64>| match (&mut lower_bound, &mut upper_bound) {
                (None, None) => {
                    lower_bound = Some(lower);
                    upper_bound = Some(upper);
                },
                (Some(lower_bound), Some(upper_bound)) => {
                    lower_bound.extend_from_slice(&lower);
                    upper_bound.extend_from_slice(&upper);
                },
                _ => unreachable!(),
            };

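        // `group_by_windows` returns, per window, a `[first_index, len]` slice into the
        // sorted index column plus (optionally) the window's lower/upper boundaries.
        // Without a `group_by`, it runs once over the whole column; with one, it runs per
        // partition in parallel and the local indices are translated back to frame-global
        // offsets below.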
        let groups = if group_by.is_none() {
            let vals = dt.physical().downcast_iter().next().unwrap();
            let ts = vals.values().as_slice();
            let (groups, lower, upper) = group_by_windows(
                w,
                ts,
                options.closed_window,
                tu,
                tz,
                include_lower_bound,
                include_upper_bound,
                options.start_by,
            )?;
            update_bounds(lower, upper);
            PolarsResult::Ok(GroupsType::Slice {
                groups,
                rolling: false,
            })
        } else {
            let vals = dt.physical().downcast_iter().next().unwrap();
            let ts = vals.values().as_slice();

            let groups = group_by.as_ref().unwrap();

            let iter = groups.par_iter().map(|[start, len]| {
                let group_offset = *start;
                let start = *start as usize;
                let end = start + *len as usize;
                let values = &ts[start..end];
                check_sortedness_slice(values)?;

                let (groups, lower, upper) = group_by_windows(
                    w,
                    values,
                    options.closed_window,
                    tu,
                    tz,
                    include_lower_bound,
                    include_upper_bound,
                    options.start_by,
                )?;

                PolarsResult::Ok((
                    groups
                        .iter()
                        .map(|[start, len]| [*start + group_offset, *len])
                        .collect_vec(),
                    lower,
                    upper,
                ))
            });

            let res = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
            let groups = res.iter().map(|g| &g.0).collect_vec();
            let lower = res.iter().map(|g| &g.1).collect_vec();
            let upper = res.iter().map(|g| &g.2).collect_vec();

            let ((groups, upper), lower) = POOL.install(|| {
                rayon::join(
                    || rayon::join(|| flatten_par(&groups), || flatten_par(&upper)),
                    || flatten_par(&lower),
                )
            });

            update_bounds(lower, upper);
            PolarsResult::Ok(GroupsType::Slice {
                groups,
                rolling: false,
            })
        }?;
        // Note that if 'group_by' is None we can be sure that the index column, the lower
        // column and the upper column remain/are sorted.

        let dt = unsafe { dt.clone().into_series().agg_first(&groups) };
        let mut dt = dt.datetime().unwrap().physical().clone();

        let lower =
            lower_bound.map(|lower| Int64Chunked::new_vec(PlSmallStr::from_static(LB_NAME), lower));
        let upper =
            upper_bound.map(|upper| Int64Chunked::new_vec(PlSmallStr::from_static(UP_NAME), upper));

        if options.label == Label::Left {
            let mut lower = lower.clone().unwrap();
            if group_by.is_none() {
                lower.set_sorted_flag(IsSorted::Ascending)
            }
            dt = lower.with_name(dt.name().clone());
        } else if options.label == Label::Right {
            let mut upper = upper.clone().unwrap();
            if group_by.is_none() {
                upper.set_sorted_flag(IsSorted::Ascending)
            }
            dt = upper.with_name(dt.name().clone());
        }

        let mut bounds = vec![];
        if let (true, Some(mut lower), Some(mut upper)) = (options.include_boundaries, lower, upper)
        {
            if group_by.is_none() {
                lower.set_sorted_flag(IsSorted::Ascending);
                upper.set_sorted_flag(IsSorted::Ascending);
            }
            bounds.push(lower.into_datetime(tu, tz.clone()).into_column());
            bounds.push(upper.into_datetime(tu, tz.clone()).into_column());
        }

        dt.into_datetime(tu, None)
            .into_column()
            .cast(time_type)
            .map(|s| (s, bounds, groups.into_sliceable()))
    }

    /// Returns: time keys, group positions.
    fn impl_rolling(
        &self,
        dt: Column,
        group_by: Option<GroupsSlice>,
        options: &RollingGroupOptions,
        tu: TimeUnit,
        tz: Option<Tz>,
        time_type: &DataType,
    ) -> PolarsResult<(Column, GroupPositions)> {
        let mut dt = dt.rechunk();

        let groups = if group_by.is_none() {
            // Sortedness is a requirement for the index; set the flag so downstream code
            // has this info.
            dt.set_sorted_flag(IsSorted::Ascending);
            let dt = dt.datetime().unwrap();
            let vals = dt.physical().downcast_iter().next().unwrap();
            let ts = vals.values().as_slice();
            PolarsResult::Ok(GroupsType::Slice {
                groups: group_by_values(
                    options.period,
                    options.offset,
                    ts,
                    options.closed_window,
                    tu,
                    tz,
                )?,
                rolling: true,
            })
        } else {
            let dt = dt.datetime().unwrap();
            let vals = dt.physical().downcast_iter().next().unwrap();
            let ts = vals.values().as_slice();

            let groups = group_by.unwrap();

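            // Compute rolling windows per partition in parallel. `group_by_values` returns
            // `[start, len]` windows relative to the partition's slice, so each window is
            // shifted by the partition's offset to restore frame-global indices.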
            let iter = groups.into_par_iter().map(|[start, len]| {
                let group_offset = start;
                let start = start as usize;
                let end = start + len as usize;
                let values = &ts[start..end];
                check_sortedness_slice(values)?;

                let group = group_by_values(
                    options.period,
                    options.offset,
                    values,
                    options.closed_window,
                    tu,
                    tz,
                )?;

                PolarsResult::Ok(
                    group
                        .iter()
                        .map(|[start, len]| [*start + group_offset, *len])
                        .collect_vec(),
                )
            });

            let groups = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
            let groups = POOL.install(|| flatten_par(&groups));
            PolarsResult::Ok(GroupsType::Slice {
                groups,
                rolling: true,
            })
        }?;

        let dt = dt.cast(time_type).unwrap();

        Ok((dt, groups.into_sliceable()))
    }
}

#[cfg(test)]
mod test {
    use polars_compute::rolling::QuantileMethod;
    use polars_ops::prelude::*;

    use super::*;

    #[test]
    fn test_rolling_group_by_tu() -> PolarsResult<()> {
        // test multiple time units
        for tu in [
            TimeUnit::Nanoseconds,
            TimeUnit::Microseconds,
            TimeUnit::Milliseconds,
        ] {
            let mut date = StringChunked::new(
                "dt".into(),
                [
                    "2020-01-01 13:45:48",
                    "2020-01-01 16:42:13",
                    "2020-01-01 16:45:09",
                    "2020-01-02 18:12:48",
                    "2020-01-03 19:45:32",
                    "2020-01-08 23:16:43",
                ],
            )
            .as_datetime(
                None,
                tu,
                false,
                false,
                None,
                &StringChunked::from_iter(std::iter::once("raise")),
            )?
            .into_column();
            date.set_sorted_flag(IsSorted::Ascending);
            let a = Column::new("a".into(), [3, 7, 5, 9, 2, 1]);
            let df = DataFrame::new(vec![date, a.clone()])?;

            let (_, groups) = df
                .rolling(
                    None,
                    &RollingGroupOptions {
                        index_column: "dt".into(),
                        period: Duration::parse("2d"),
                        offset: Duration::parse("-2d"),
                        closed_window: ClosedWindow::Right,
                    },
                )
                .unwrap();

            let sum = unsafe { a.agg_sum(&groups) };
            let expected = Column::new("".into(), [3, 10, 15, 24, 11, 1]);
            assert_eq!(sum, expected);
        }

        Ok(())
    }

    #[test]
    fn test_rolling_group_by_aggs() -> PolarsResult<()> {
        let mut date = StringChunked::new(
            "dt".into(),
            [
                "2020-01-01 13:45:48",
                "2020-01-01 16:42:13",
                "2020-01-01 16:45:09",
                "2020-01-02 18:12:48",
                "2020-01-03 19:45:32",
                "2020-01-08 23:16:43",
            ],
        )
        .as_datetime(
            None,
            TimeUnit::Milliseconds,
            false,
            false,
            None,
            &StringChunked::from_iter(std::iter::once("raise")),
        )?
        .into_column();
        date.set_sorted_flag(IsSorted::Ascending);

        let a = Column::new("a".into(), [3, 7, 5, 9, 2, 1]);
        let df = DataFrame::new(vec![date, a.clone()])?;

        let (_, groups) = df
            .rolling(
                None,
                &RollingGroupOptions {
                    index_column: "dt".into(),
                    period: Duration::parse("2d"),
                    offset: Duration::parse("-2d"),
                    closed_window: ClosedWindow::Right,
                },
            )
            .unwrap();

        let nulls = Series::new(
            "".into(),
            [Some(3), Some(7), None, Some(9), Some(2), Some(1)],
        );

        let min = unsafe { a.as_materialized_series().agg_min(&groups) };
        let expected = Series::new("".into(), [3, 3, 3, 3, 2, 1]);
        assert_eq!(min, expected);

        // Expected for nulls is equality.
        let min = unsafe { nulls.agg_min(&groups) };
        assert_eq!(min, expected);

        let max = unsafe { a.as_materialized_series().agg_max(&groups) };
        let expected = Series::new("".into(), [3, 7, 7, 9, 9, 1]);
        assert_eq!(max, expected);

        let max = unsafe { nulls.agg_max(&groups) };
        assert_eq!(max, expected);

        let var = unsafe { a.as_materialized_series().agg_var(&groups, 1) };
        let expected = Series::new(
            "".into(),
            [0.0, 8.0, 4.000000000000002, 6.666666666666667, 24.5, 0.0],
        );
        assert!(abs(&(var - expected)?).unwrap().lt(1e-12).unwrap().all());

        let var = unsafe { nulls.agg_var(&groups, 1) };
        let expected = Series::new("".into(), [0.0, 8.0, 8.0, 9.333333333333343, 24.5, 0.0]);
        assert!(abs(&(var - expected)?).unwrap().lt(1e-12).unwrap().all());

        let quantile = unsafe {
            a.as_materialized_series()
                .agg_quantile(&groups, 0.5, QuantileMethod::Linear)
        };
        let expected = Series::new("".into(), [3.0, 5.0, 5.0, 6.0, 5.5, 1.0]);
        assert_eq!(quantile, expected);

        let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileMethod::Linear) };
        let expected = Series::new("".into(), [3.0, 5.0, 5.0, 7.0, 5.5, 1.0]);
        assert_eq!(quantile, expected);

        Ok(())
    }
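
    // An illustrative sketch added for exposition (not part of the original test suite):
    // it exercises `group_by_dynamic` with arbitrary option values and only checks the
    // structural invariant that `include_boundaries: true` yields the two boundary columns.
    #[test]
    fn test_dynamic_group_by_boundary_columns_sketch() -> PolarsResult<()> {
        let mut date = StringChunked::new(
            "dt".into(),
            [
                "2020-01-01 13:45:48",
                "2020-01-02 18:12:48",
                "2020-01-03 19:45:32",
            ],
        )
        .as_datetime(
            None,
            TimeUnit::Milliseconds,
            false,
            false,
            None,
            &StringChunked::from_iter(std::iter::once("raise")),
        )?
        .into_column();
        date.set_sorted_flag(IsSorted::Ascending);
        let df = DataFrame::new(vec![date])?;

        let (_time_key, keys, _groups) = df.group_by_dynamic(
            None,
            &DynamicGroupOptions {
                index_column: "dt".into(),
                every: Duration::parse("1d"),
                period: Duration::parse("1d"),
                offset: Duration::new(0),
                label: Label::Left,
                include_boundaries: true,
                closed_window: ClosedWindow::Left,
                start_by: Default::default(),
            },
        )?;
        // `_lower_boundary` and `_upper_boundary` are returned as key columns.
        assert_eq!(keys.len(), 2);

        Ok(())
    }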
}