Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs
8406 views
1
use std::borrow::Cow;
2
3
use arrow::types::NativeType;
4
#[cfg(feature = "dtype-f16")]
5
use num_traits::real::Real;
6
use polars_compute::rolling::no_nulls::RollingAggWindowNoNulls;
7
use polars_compute::rolling::nulls::RollingAggWindowNulls;
8
use polars_compute::rolling::{MeanWindow, SumWindow, no_nulls, nulls};
9
use polars_core::{with_match_physical_float_polars_type, with_match_physical_numeric_polars_type};
10
use polars_ops::series::SeriesMethods;
11
use polars_utils::float::IsFloat;
12
13
use super::*;
14
use crate::prelude::*;
15
use crate::series::AsSeries;
16
17
/// Dispatch a fixed-window rolling aggregation over `ca`.
///
/// `min_periods` must not exceed `window_size`; otherwise an `InvalidOperation` error is
/// returned. An empty input yields an empty Series with the same name and dtype. Otherwise the
/// array is rechunked into a single chunk and handed to `rolling_agg_fn` when it has no nulls,
/// or to `rolling_agg_fn_nulls` when it does. The resulting array is wrapped into a Series that
/// keeps `ca`'s name.
#[cfg(feature = "rolling_window")]
#[allow(clippy::type_complexity)]
fn rolling_agg<T>(
    ca: &ChunkedArray<T>,
    options: RollingOptionsFixedWindow,
    rolling_agg_fn: &dyn Fn(
        &[T::Native],
        usize,
        usize,
        bool,
        Option<&[f64]>,
        Option<RollingFnParams>,
    ) -> PolarsResult<ArrayRef>,
    rolling_agg_fn_nulls: &dyn Fn(
        &PrimitiveArray<T::Native>,
        usize,
        usize,
        bool,
        Option<&[f64]>,
        Option<RollingFnParams>,
    ) -> ArrayRef,
) -> PolarsResult<Series>
where
    T: PolarsNumericType,
{
    polars_ensure!(
        options.min_periods <= options.window_size,
        InvalidOperation: "`min_periods` should be <= `window_size`"
    );
    if ca.is_empty() {
        return Ok(Series::new_empty(ca.name().clone(), ca.dtype()));
    }

    // Both kernels take a single contiguous Arrow array.
    let single = ca.rechunk();
    let chunk = single.downcast_iter().next().unwrap();

    let window_size = options.window_size;
    let min_periods = options.min_periods;
    let center = options.center;
    let weights = options.weights.as_deref();

    let out = if single.null_count() > 0 {
        rolling_agg_fn_nulls(
            chunk,
            window_size,
            min_periods,
            center,
            weights,
            options.fn_params,
        )
    } else {
        rolling_agg_fn(
            chunk.values().as_slice(),
            window_size,
            min_periods,
            center,
            weights,
            options.fn_params,
        )?
    };

    Series::try_from((single.name().clone(), out))
}
69
70
/// Dispatch a rolling aggregation whose window is defined over the `by` column.
///
/// `NoNullsAgg` is the kernel used when `ca` has no nulls; `NullsAgg` is used otherwise.
///
/// Validation, in order:
/// * an empty `ca` returns an empty Series with `ca`'s name and dtype;
/// * `by` must have the same length as `ca`;
/// * `window_size` must match `by`'s dtype and be strictly positive;
/// * `by` must be Date, Datetime, Int64, Int32, UInt64 or UInt32.
///
/// If `by` is not sorted ascending, both columns are reordered by `by` and the permutation is
/// passed to `rolling_apply_agg` as well.
#[cfg(feature = "rolling_window_by")]
fn rolling_agg_by<T, Out, NoNullsAgg, NullsAgg>(
    ca: &ChunkedArray<T>,
    by: &Series,
    options: RollingOptionsDynamicWindow,
) -> PolarsResult<Series>
where
    T: PolarsNumericType,
    T::Native: NativeType + IsFloat,
    Out: NativeType,
    NoNullsAgg: RollingAggWindowNoNulls<T::Native, Out>,
    NullsAgg: RollingAggWindowNulls<T::Native, Out>,
{
    use crate::chunkedarray::rolling_window::rolling_kernels::shared::{
        RollingAggWindowNoNullsWrapper, RollingAggWindowNullsWrapper, rolling_apply_agg,
    };

    if ca.is_empty() {
        return Ok(Series::new_empty(ca.name().clone(), ca.dtype()));
    }

    polars_ensure!(
        ca.len() == by.len(),
        InvalidOperation: "`by` column in `rolling_*_by` must be the same length as values column"
    );
    ensure_duration_matches_dtype(options.window_size, by.dtype(), "window_size")?;
    polars_ensure!(
        !options.window_size.is_zero() && !options.window_size.negative,
        InvalidOperation: "`window_size` must be strictly positive"
    );

    // Convert every supported `by` dtype into a timezone-naive Datetime. The time zone (only
    // present for Datetime input) is kept aside and passed to the kernel separately.
    let (by_dt, tz) = match by.dtype() {
        DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz),
        DataType::Date => (
            by.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,
            &None,
        ),
        DataType::Int64 => (
            by.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,
            &None,
        ),
        DataType::Int32 | DataType::UInt64 | DataType::UInt32 => (
            by.cast(&DataType::Int64)?
                .cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,
            &None,
        ),
        dt => polars_bail!(InvalidOperation:
            "in `rolling_*_by` operation, `by` argument of dtype `{}` is not supported (expected `{}`)",
            dt,
            "Date/Datetime/Int64/Int32/UInt64/UInt32"),
    };

    let mut sorted_ca = ca.rechunk();
    let by_dt = by_dt.rechunk();
    let already_sorted = by_dt.is_sorted(SortOptions {
        descending: false,
        ..Default::default()
    })?;
    let by_logical = by_dt.datetime().unwrap();
    let tu = by_logical.time_unit();
    let mut by_physical = Cow::Borrowed(by_logical.physical());

    // The kernel walks `by` in ascending order. When it is not sorted, compute the permutation
    // that sorts it and apply it to both columns.
    let sort_idx = if already_sorted {
        None
    } else {
        Some(by_physical.arg_sort(Default::default()))
    };
    if let Some(idx) = sort_idx.as_ref() {
        // SAFETY: `idx` is in-bounds because we checked that `ca.len() == by.len()` and
        // it is derived from `by`.
        sorted_ca = Cow::Owned(unsafe { sorted_ca.take_unchecked(idx) });
        // SAFETY: `idx` is in-bounds because it is derived from `by`.
        by_physical = Cow::Owned(unsafe { by_physical.take_unchecked(idx) });
    }

    let by_values = by_physical.cont_slice().unwrap();
    let chunk = sorted_ca.downcast_iter().next().unwrap();
    let vals = chunk.values().as_slice();
    let order = sort_idx.as_ref().map(|idx| idx.cont_slice().unwrap());

    // Two separate branches, so that the nulls and no-nulls kernels each get their own
    // compiled instantiation of `rolling_apply_agg`.
    let out: ArrayRef = if ca.null_count() > 0 {
        let validity = chunk.validity().unwrap();
        let mut window = RollingAggWindowNullsWrapper(NullsAgg::new(
            vals,
            validity,
            0,
            0,
            options.fn_params,
            None,
        ));

        rolling_apply_agg(
            &mut window,
            options.window_size,
            by_values,
            options.closed_window,
            options.min_periods,
            tu,
            tz.as_ref(),
            order,
        )?
    } else {
        let mut window =
            RollingAggWindowNoNullsWrapper(NoNullsAgg::new(vals, 0, 0, options.fn_params, None));

        rolling_apply_agg(
            &mut window,
            options.window_size,
            by_values,
            options.closed_window,
            options.min_periods,
            tu,
            tz.as_ref(),
            order,
        )?
    };

    Series::try_from((ca.name().clone(), out))
}
189
190
pub trait SeriesOpsTime: AsSeries {
191
/// Apply a rolling mean to a Series based on another Series.
192
#[cfg(feature = "rolling_window_by")]
193
fn rolling_mean_by(
194
&self,
195
by: &Series,
196
options: RollingOptionsDynamicWindow,
197
) -> PolarsResult<Series> {
198
let s = self.as_series().to_float()?;
199
with_match_physical_float_polars_type!(s.dtype(), |$T| {
200
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
201
rolling_agg_by::<$T, _, MeanWindow<_>, MeanWindow<_>>(ca, by, options)
202
})
203
}
204
/// Apply a rolling mean to a Series.
205
///
206
/// See: [`RollingAgg::rolling_mean`]
207
#[cfg(feature = "rolling_window")]
208
fn rolling_mean(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
209
let s = self.as_series().to_float()?;
210
with_match_physical_float_polars_type!(s.dtype(), |$T| {
211
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
212
rolling_agg(
213
ca,
214
options,
215
&rolling::no_nulls::rolling_mean,
216
&rolling::nulls::rolling_mean,
217
)
218
})
219
}
220
/// Apply a rolling sum to a Series based on another Series.
221
#[cfg(feature = "rolling_window_by")]
222
fn rolling_sum_by(
223
&self,
224
by: &Series,
225
options: RollingOptionsDynamicWindow,
226
) -> PolarsResult<Series> {
227
let mut s = self.as_series().clone();
228
if s.dtype() == &DataType::Boolean {
229
s = s.cast(&DataType::IDX_DTYPE).unwrap();
230
}
231
if matches!(
232
s.dtype(),
233
DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16
234
) {
235
s = s.cast(&DataType::Int64).unwrap();
236
}
237
238
polars_ensure!(
239
s.dtype().is_primitive_numeric() && !s.dtype().is_unknown(),
240
op = "rolling_sum_by",
241
s.dtype()
242
);
243
244
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
245
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
246
type Native = <$T as PolarsNumericType>::Native;
247
type SM<'a> = SumWindow<'a, Native, Native>;
248
rolling_agg_by::<$T, _, SM, SM>(ca, by, options)
249
})
250
}
251
252
/// Apply a rolling sum to a Series.
253
#[cfg(feature = "rolling_window")]
254
fn rolling_sum(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
255
let mut s = self.as_series().clone();
256
if options.weights.is_some() {
257
s = s.to_float()?;
258
} else if s.dtype() == &DataType::Boolean {
259
s = s.cast(&DataType::IDX_DTYPE).unwrap();
260
} else if matches!(
261
s.dtype(),
262
DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16
263
) {
264
s = s.cast(&DataType::Int64).unwrap();
265
}
266
267
polars_ensure!(
268
s.dtype().is_primitive_numeric() && !s.dtype().is_unknown(),
269
op = "rolling_sum",
270
s.dtype()
271
);
272
273
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
274
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
275
rolling_agg(
276
ca,
277
options,
278
&rolling::no_nulls::rolling_sum,
279
&rolling::nulls::rolling_sum,
280
)
281
})
282
}
283
284
/// Apply a rolling quantile to a Series based on another Series.
285
#[cfg(feature = "rolling_window_by")]
286
fn rolling_quantile_by(
287
&self,
288
by: &Series,
289
options: RollingOptionsDynamicWindow,
290
) -> PolarsResult<Series> {
291
let s = self.as_series().to_float()?;
292
with_match_physical_float_polars_type!(s.dtype(), |$T| {
293
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
294
rolling_agg_by::<
295
$T,
296
_,
297
no_nulls::QuantileWindow<_>,
298
nulls::QuantileWindow<_>
299
>(ca, by, options)
300
})
301
}
302
303
/// Apply a rolling quantile to a Series.
304
#[cfg(feature = "rolling_window")]
305
fn rolling_quantile(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
306
let s = self.as_series().to_float()?;
307
with_match_physical_float_polars_type!(s.dtype(), |$T| {
308
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
309
rolling_agg(
310
ca,
311
options,
312
&rolling::no_nulls::rolling_quantile,
313
&rolling::nulls::rolling_quantile,
314
)
315
})
316
}
317
318
/// Apply a rolling min to a Series based on another Series.
319
#[cfg(feature = "rolling_window_by")]
320
fn rolling_min_by(
321
&self,
322
by: &Series,
323
options: RollingOptionsDynamicWindow,
324
) -> PolarsResult<Series> {
325
let s = self.as_series().clone();
326
327
let dt = s.dtype();
328
match dt {
329
// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.
330
&DataType::Boolean => {
331
return s
332
.cast(&DataType::UInt8)?
333
.rolling_min_by(by, options)?
334
.cast(&DataType::Boolean);
335
},
336
dt if dt.is_temporal() => {
337
return s.to_physical_repr().rolling_min_by(by, options)?.cast(dt);
338
},
339
dt => {
340
polars_ensure!(
341
dt.is_primitive_numeric() && !dt.is_unknown(),
342
op = "rolling_min_by",
343
dt
344
);
345
},
346
}
347
348
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
349
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
350
rolling_agg_by::<
351
$T,
352
_,
353
no_nulls::MinWindow<_>,
354
nulls::MinWindow<_>
355
>(ca, by, options)
356
})
357
}
358
359
/// Apply a rolling min to a Series.
360
#[cfg(feature = "rolling_window")]
361
fn rolling_min(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
362
let mut s = self.as_series().clone();
363
if options.weights.is_some() {
364
s = s.to_float()?;
365
}
366
367
let dt = s.dtype();
368
match dt {
369
// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.
370
&DataType::Boolean => {
371
return s
372
.cast(&DataType::UInt8)?
373
.rolling_min(options)?
374
.cast(&DataType::Boolean);
375
},
376
dt if dt.is_temporal() => {
377
return s.to_physical_repr().rolling_min(options)?.cast(dt);
378
},
379
dt => {
380
polars_ensure!(
381
dt.is_primitive_numeric() && !dt.is_unknown(),
382
op = "rolling_min",
383
dt
384
);
385
},
386
}
387
388
with_match_physical_numeric_polars_type!(dt, |$T| {
389
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
390
rolling_agg(
391
ca,
392
options,
393
&rolling::no_nulls::rolling_min,
394
&rolling::nulls::rolling_min,
395
)
396
})
397
}
398
399
/// Apply a rolling max to a Series based on another Series.
400
#[cfg(feature = "rolling_window_by")]
401
fn rolling_max_by(
402
&self,
403
by: &Series,
404
options: RollingOptionsDynamicWindow,
405
) -> PolarsResult<Series> {
406
let s = self.as_series().clone();
407
408
let dt = s.dtype();
409
match dt {
410
// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.
411
&DataType::Boolean => {
412
return s
413
.cast(&DataType::UInt8)?
414
.rolling_max_by(by, options)?
415
.cast(&DataType::Boolean);
416
},
417
dt if dt.is_temporal() => {
418
return s.to_physical_repr().rolling_max_by(by, options)?.cast(dt);
419
},
420
dt => {
421
polars_ensure!(
422
dt.is_primitive_numeric() && !dt.is_unknown(),
423
op = "rolling_max_by",
424
dt
425
);
426
},
427
}
428
429
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
430
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
431
rolling_agg_by::<
432
$T,
433
_,
434
no_nulls::MaxWindow<_>,
435
nulls::MaxWindow<_>
436
>(ca, by, options)
437
})
438
}
439
440
/// Apply a rolling max to a Series.
441
#[cfg(feature = "rolling_window")]
442
fn rolling_max(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
443
let mut s = self.as_series().clone();
444
if options.weights.is_some() {
445
s = s.to_float()?;
446
}
447
448
let dt = s.dtype();
449
match dt {
450
// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.
451
&DataType::Boolean => {
452
return s
453
.cast(&DataType::UInt8)?
454
.rolling_max(options)?
455
.cast(&DataType::Boolean);
456
},
457
dt if dt.is_temporal() => {
458
return s.to_physical_repr().rolling_max(options)?.cast(dt);
459
},
460
dt => {
461
polars_ensure!(
462
dt.is_primitive_numeric() && !dt.is_unknown(),
463
op = "rolling_max",
464
dt
465
);
466
},
467
}
468
469
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
470
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
471
rolling_agg(
472
ca,
473
options,
474
&rolling::no_nulls::rolling_max,
475
&rolling::nulls::rolling_max,
476
)
477
})
478
}
479
480
/// Apply a rolling variance to a Series based on another Series.
481
#[cfg(feature = "rolling_window_by")]
482
fn rolling_var_by(
483
&self,
484
by: &Series,
485
options: RollingOptionsDynamicWindow,
486
) -> PolarsResult<Series> {
487
let s = self.as_series().to_float()?;
488
489
with_match_physical_float_polars_type!(s.dtype(), |$T| {
490
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
491
492
rolling_agg_by::<
493
$T,
494
_,
495
no_nulls::MomentWindow<_, no_nulls::VarianceMoment>,
496
nulls::MomentWindow<_, nulls::VarianceMoment>
497
>(ca, by, options)
498
})
499
}
500
501
/// Apply a rolling variance to a Series.
502
#[cfg(feature = "rolling_window")]
503
fn rolling_var(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
504
let s = self.as_series().to_float()?;
505
506
with_match_physical_float_polars_type!(s.dtype(), |$T| {
507
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
508
509
rolling_agg(
510
ca,
511
options,
512
&rolling::no_nulls::rolling_var,
513
&rolling::nulls::rolling_var,
514
)
515
})
516
}
517
518
/// Apply a rolling std_dev to a Series based on another Series.
519
#[cfg(feature = "rolling_window_by")]
520
fn rolling_std_by(
521
&self,
522
by: &Series,
523
options: RollingOptionsDynamicWindow,
524
) -> PolarsResult<Series> {
525
self.rolling_var_by(by, options).map(|mut s| {
526
with_match_physical_float_polars_type!(s.dtype(), |$T| {
527
let ca: &mut ChunkedArray<$T> = s._get_inner_mut().as_mut();
528
ca.apply_mut(|v| v.sqrt());
529
});
530
531
s
532
})
533
}
534
535
/// Apply a rolling std_dev to a Series.
536
#[cfg(feature = "rolling_window")]
537
fn rolling_std(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
538
self.rolling_var(options).map(|mut s| {
539
with_match_physical_float_polars_type!(s.dtype(), |$T| {
540
let ca: &mut ChunkedArray<$T> = s._get_inner_mut().as_mut();
541
ca.apply_mut(|v| v.sqrt());
542
});
543
544
s
545
})
546
}
547
548
/// Apply a rolling rank to a Series based on another Series.
549
#[cfg(feature = "rolling_window_by")]
550
fn rolling_rank_by(
551
&self,
552
by: &Series,
553
options: RollingOptionsDynamicWindow,
554
) -> PolarsResult<Series> {
555
if !matches!(
556
options.closed_window,
557
ClosedWindow::Right | ClosedWindow::Both
558
) {
559
polars_bail!(InvalidOperation: "`rolling_rank_by` window needs to be closed on the right side (i.e., `closed` must be `right` or `both`)");
560
}
561
562
let s = self.as_series().clone();
563
564
match s.dtype() {
565
DataType::Boolean => return s.cast(&DataType::UInt8)?.rolling_rank_by(by, options),
566
dt if dt.is_temporal() => return s.to_physical_repr().rolling_rank_by(by, options),
567
dt => {
568
polars_ensure!(
569
dt.is_primitive_numeric() && !dt.is_unknown(),
570
op = "rolling_rank_by",
571
dt
572
);
573
},
574
}
575
576
let method = if let Some(RollingFnParams::Rank { method, .. }) = options.fn_params {
577
method
578
} else {
579
unreachable!("expected RollingFnParams::Rank");
580
};
581
582
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
583
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
584
585
match method {
586
RollingRankMethod::Average => rolling_agg_by::<
587
$T,
588
_,
589
no_nulls::RankWindowAvg<_>,
590
nulls::RankWindowAvg<_>
591
>(ca, by, options),
592
RollingRankMethod::Min => rolling_agg_by::<
593
$T,
594
_,
595
no_nulls::RankWindowMin<_>,
596
nulls::RankWindowMin<_>
597
>(ca, by, options),
598
RollingRankMethod::Max => rolling_agg_by::<
599
$T,
600
_,
601
no_nulls::RankWindowMax<_>,
602
nulls::RankWindowMax<_>
603
>(ca, by, options),
604
RollingRankMethod::Dense => rolling_agg_by::<
605
$T,
606
_,
607
no_nulls::RankWindowDense<_>,
608
nulls::RankWindowDense<_>
609
>(ca, by, options),
610
RollingRankMethod::Random => rolling_agg_by::<
611
$T,
612
_,
613
no_nulls::RankWindowRandom<_>,
614
nulls::RankWindowRandom<_>
615
>(ca, by, options),
616
_ => todo!()
617
}
618
})
619
}
620
621
/// Apply a rolling rank to a Series.
622
#[cfg(feature = "rolling_window")]
623
fn rolling_rank(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
624
let s = self.as_series();
625
626
match s.dtype() {
627
DataType::Boolean => return s.cast(&DataType::UInt8)?.rolling_rank(options),
628
dt if dt.is_temporal() => return s.to_physical_repr().rolling_rank(options),
629
dt => {
630
polars_ensure!(
631
dt.is_primitive_numeric() && !dt.is_unknown(),
632
op = "rolling_rank",
633
dt
634
);
635
},
636
}
637
638
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
639
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
640
let mut ca = ca.clone();
641
642
rolling_agg(
643
&ca,
644
options,
645
&rolling::no_nulls::rolling_rank,
646
&rolling::nulls::rolling_rank,
647
)
648
})
649
}
650
}
651
652
/// [`Series`] uses the default implementations of every [`SeriesOpsTime`] method.
impl SeriesOpsTime for Series {}
653
654