// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs
1
use polars_core::{with_match_physical_float_polars_type, with_match_physical_numeric_polars_type};
2
use polars_ops::series::SeriesMethods;
3
4
use super::*;
5
use crate::prelude::*;
6
use crate::series::AsSeries;
7
8
/// Runs a fixed-window rolling aggregation over `ca` and wraps the result in a [`Series`].
///
/// Two kernels are supplied by the caller:
/// * `rolling_agg_fn` is used when the array reports zero nulls and receives the raw
///   value slice;
/// * `rolling_agg_fn_nulls` is used otherwise and receives the whole primitive array.
///
/// Both kernels are given `window_size`, `min_periods`, `center`, the optional weights and
/// the optional function parameters taken from `options`.
///
/// An empty input short-circuits to an empty `Series` with the same name and dtype.
///
/// # Errors
///
/// Returns an `InvalidOperation` error when `min_periods > window_size`, and forwards any
/// error produced by `rolling_agg_fn` or by the final `Series` construction.
#[cfg(feature = "rolling_window")]
#[allow(clippy::type_complexity)]
fn rolling_agg<T>(
    ca: &ChunkedArray<T>,
    options: RollingOptionsFixedWindow,
    rolling_agg_fn: &dyn Fn(
        &[T::Native],
        usize,
        usize,
        bool,
        Option<&[f64]>,
        Option<RollingFnParams>,
    ) -> PolarsResult<ArrayRef>,
    rolling_agg_fn_nulls: &dyn Fn(
        &PrimitiveArray<T::Native>,
        usize,
        usize,
        bool,
        Option<&[f64]>,
        Option<RollingFnParams>,
    ) -> ArrayRef,
) -> PolarsResult<Series>
where
    T: PolarsNumericType,
{
    polars_ensure!(options.min_periods <= options.window_size, InvalidOperation: "`min_periods` should be <= `window_size`");
    if ca.is_empty() {
        return Ok(Series::new_empty(ca.name().clone(), ca.dtype()));
    }

    // Work on a rechunked copy so the first chunk can be handed to the kernel.
    let rechunked = ca.rechunk();
    // The input is non-empty (checked above), so a first chunk is expected to exist.
    let chunk = rechunked.downcast_iter().next().unwrap();
    let weights = options.weights.as_deref();

    let out = if rechunked.null_count() == 0 {
        // No nulls: the kernel works directly on the value slice.
        rolling_agg_fn(
            chunk.values().as_slice(),
            options.window_size,
            options.min_periods,
            options.center,
            weights,
            options.fn_params,
        )?
    } else {
        // Nulls present: the kernel receives the whole array.
        rolling_agg_fn_nulls(
            chunk,
            options.window_size,
            options.min_periods,
            options.center,
            weights,
            options.fn_params,
        )
    };

    Series::try_from((rechunked.name().clone(), out))
}
60
61
/// Runs a dynamic-window rolling aggregation of `ca`, with windows defined over `by`.
///
/// Steps, in order:
/// 1. An empty `ca` short-circuits to an empty `Series` of the same name and dtype.
/// 2. Inputs are validated: no nulls in either column, equal lengths, a `window_size`
///    accepted by `ensure_duration_matches_dtype` for the dtype of `by`, and a strictly
///    positive `window_size`.
/// 3. `by` is cast to a timezone-less `Datetime`; the original timezone (if any) is kept
///    aside and handed to the kernel separately.
/// 4. If `by` is already sorted ascending the kernel runs on the data as-is. Otherwise both
///    columns are reordered by the arg-sort of `by` and the sorting indices are passed to
///    the kernel as its last argument.
///
/// # Errors
///
/// Returns `InvalidOperation` for any failed validation above or for an unsupported `by`
/// dtype, and forwards errors from casting, the sortedness check, the kernel, and the final
/// `Series` construction.
#[cfg(feature = "rolling_window_by")]
#[allow(clippy::type_complexity)]
fn rolling_agg_by<T>(
    ca: &ChunkedArray<T>,
    by: &Series,
    options: RollingOptionsDynamicWindow,
    rolling_agg_fn_dynamic: &dyn Fn(
        &[T::Native],
        Duration,
        &[i64],
        ClosedWindow,
        usize,
        TimeUnit,
        Option<&TimeZone>,
        Option<RollingFnParams>,
        Option<&[IdxSize]>,
    ) -> PolarsResult<ArrayRef>,
) -> PolarsResult<Series>
where
    T: PolarsNumericType,
{
    if ca.is_empty() {
        return Ok(Series::new_empty(ca.name().clone(), ca.dtype()));
    }
    polars_ensure!(by.null_count() == 0 && ca.null_count() == 0, InvalidOperation: "'Expr.rolling_*_by(...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'");
    polars_ensure!(ca.len() == by.len(), InvalidOperation: "`by` column in `rolling_*_by` must be the same length as values column");
    ensure_duration_matches_dtype(options.window_size, by.dtype(), "window_size")?;
    polars_ensure!(!(options.window_size.is_zero() || options.window_size.negative), InvalidOperation: "`window_size` must be strictly positive");

    // Normalise `by` to a timezone-less Datetime column, remembering the original timezone.
    let (by_datetime, tz) = match by.dtype() {
        DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz),
        DataType::Date => (
            by.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,
            &None,
        ),
        DataType::Int64 => (
            by.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,
            &None,
        ),
        DataType::Int32 | DataType::UInt64 | DataType::UInt32 => (
            by.cast(&DataType::Int64)?
                .cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?,
            &None,
        ),
        dt => polars_bail!(InvalidOperation:
            "in `rolling_*_by` operation, `by` argument of dtype `{}` is not supported (expected `{}`)",
            dt,
            "Date/Datetime/Int64/Int32/UInt64/UInt32"),
    };

    let values_ca = ca.rechunk();
    let by_datetime = by_datetime.rechunk();
    let by_is_sorted = by_datetime.is_sorted(SortOptions {
        descending: false,
        ..Default::default()
    })?;
    let by_ca = by_datetime.datetime().unwrap();
    let tu = by_ca.time_unit();

    // Single call site for the kernel; only the slices and the optional sorting indices
    // differ between the sorted and the unsorted path.
    let run_kernel = |values: &[T::Native], by_values: &[i64], sorting: Option<&[IdxSize]>| {
        rolling_agg_fn_dynamic(
            values,
            options.window_size,
            by_values,
            options.closed_window,
            options.min_periods,
            tu,
            tz.as_ref(),
            options.fn_params,
            sorting,
        )
    };

    let out: ArrayRef = if by_is_sorted {
        let chunk = values_ca.downcast_iter().next().unwrap();
        let by_values = by_ca.physical().cont_slice().unwrap();
        run_kernel(chunk.values().as_slice(), by_values, None)?
    } else {
        let sorting_indices = by_ca.physical().arg_sort(Default::default());
        // SAFETY: the indices come from arg-sorting `by`, and `ca.len() == by.len()` was
        // checked above, so they are expected to be in bounds for both columns.
        let sorted_values = unsafe { values_ca.take_unchecked(&sorting_indices) };
        let sorted_by = unsafe { by_ca.physical().take_unchecked(&sorting_indices) };
        let chunk = sorted_values.downcast_iter().next().unwrap();
        let by_values = sorted_by.cont_slice().unwrap();
        run_kernel(
            chunk.values().as_slice(),
            by_values,
            Some(sorting_indices.cont_slice().unwrap()),
        )?
    };

    Series::try_from((values_ca.name().clone(), out))
}
155
156
pub trait SeriesOpsTime: AsSeries {
157
/// Apply a rolling mean to a Series based on another Series.
158
#[cfg(feature = "rolling_window_by")]
159
fn rolling_mean_by(
160
&self,
161
by: &Series,
162
options: RollingOptionsDynamicWindow,
163
) -> PolarsResult<Series> {
164
let s = self.as_series().to_float()?;
165
with_match_physical_float_polars_type!(s.dtype(), |$T| {
166
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
167
rolling_agg_by(
168
ca,
169
by,
170
options,
171
&super::rolling_kernels::no_nulls::rolling_mean,
172
)
173
})
174
}
175
/// Apply a rolling mean to a Series.
176
///
177
/// See: [`RollingAgg::rolling_mean`]
178
#[cfg(feature = "rolling_window")]
179
fn rolling_mean(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
180
let s = self.as_series().to_float()?;
181
with_match_physical_float_polars_type!(s.dtype(), |$T| {
182
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
183
rolling_agg(
184
ca,
185
options,
186
&rolling::no_nulls::rolling_mean,
187
&rolling::nulls::rolling_mean,
188
)
189
})
190
}
191
/// Apply a rolling sum to a Series based on another Series.
192
#[cfg(feature = "rolling_window_by")]
193
fn rolling_sum_by(
194
&self,
195
by: &Series,
196
options: RollingOptionsDynamicWindow,
197
) -> PolarsResult<Series> {
198
let mut s = self.as_series().clone();
199
if s.dtype() == &DataType::Boolean {
200
s = s.cast(&DataType::IDX_DTYPE).unwrap();
201
}
202
if matches!(
203
s.dtype(),
204
DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16
205
) {
206
s = s.cast(&DataType::Int64).unwrap();
207
}
208
209
polars_ensure!(
210
s.dtype().is_primitive_numeric() && !s.dtype().is_unknown(),
211
op = "rolling_sum_by",
212
s.dtype()
213
);
214
215
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
216
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
217
rolling_agg_by(
218
ca,
219
by,
220
options,
221
&super::rolling_kernels::no_nulls::rolling_sum,
222
)
223
})
224
}
225
226
/// Apply a rolling sum to a Series.
227
#[cfg(feature = "rolling_window")]
228
fn rolling_sum(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
229
let mut s = self.as_series().clone();
230
if options.weights.is_some() {
231
s = s.to_float()?;
232
} else if s.dtype() == &DataType::Boolean {
233
s = s.cast(&DataType::IDX_DTYPE).unwrap();
234
} else if matches!(
235
s.dtype(),
236
DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16
237
) {
238
s = s.cast(&DataType::Int64).unwrap();
239
}
240
241
polars_ensure!(
242
s.dtype().is_primitive_numeric() && !s.dtype().is_unknown(),
243
op = "rolling_sum",
244
s.dtype()
245
);
246
247
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
248
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
249
rolling_agg(
250
ca,
251
options,
252
&rolling::no_nulls::rolling_sum,
253
&rolling::nulls::rolling_sum,
254
)
255
})
256
}
257
258
/// Apply a rolling quantile to a Series based on another Series.
259
#[cfg(feature = "rolling_window_by")]
260
fn rolling_quantile_by(
261
&self,
262
by: &Series,
263
options: RollingOptionsDynamicWindow,
264
) -> PolarsResult<Series> {
265
let s = self.as_series().to_float()?;
266
with_match_physical_float_polars_type!(s.dtype(), |$T| {
267
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
268
rolling_agg_by(
269
ca,
270
by,
271
options,
272
&super::rolling_kernels::no_nulls::rolling_quantile,
273
)
274
})
275
}
276
277
/// Apply a rolling quantile to a Series.
278
#[cfg(feature = "rolling_window")]
279
fn rolling_quantile(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
280
let s = self.as_series().to_float()?;
281
with_match_physical_float_polars_type!(s.dtype(), |$T| {
282
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
283
rolling_agg(
284
ca,
285
options,
286
&rolling::no_nulls::rolling_quantile,
287
&rolling::nulls::rolling_quantile,
288
)
289
})
290
}
291
292
/// Apply a rolling min to a Series based on another Series.
293
#[cfg(feature = "rolling_window_by")]
294
fn rolling_min_by(
295
&self,
296
by: &Series,
297
options: RollingOptionsDynamicWindow,
298
) -> PolarsResult<Series> {
299
let s = self.as_series().clone();
300
301
let dt = s.dtype();
302
match dt {
303
// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.
304
&DataType::Boolean => {
305
return s
306
.cast(&DataType::UInt8)?
307
.rolling_min_by(by, options)?
308
.cast(&DataType::Boolean);
309
},
310
dt if dt.is_temporal() => {
311
return s.to_physical_repr().rolling_min_by(by, options)?.cast(dt);
312
},
313
dt => {
314
polars_ensure!(
315
dt.is_primitive_numeric() && !dt.is_unknown(),
316
op = "rolling_min_by",
317
dt
318
);
319
},
320
}
321
322
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
323
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
324
rolling_agg_by(
325
ca,
326
by,
327
options,
328
&super::rolling_kernels::no_nulls::rolling_min,
329
)
330
})
331
}
332
333
/// Apply a rolling min to a Series.
334
#[cfg(feature = "rolling_window")]
335
fn rolling_min(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
336
let mut s = self.as_series().clone();
337
if options.weights.is_some() {
338
s = s.to_float()?;
339
}
340
341
let dt = s.dtype();
342
match dt {
343
// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.
344
&DataType::Boolean => {
345
return s
346
.cast(&DataType::UInt8)?
347
.rolling_min(options)?
348
.cast(&DataType::Boolean);
349
},
350
dt if dt.is_temporal() => {
351
return s.to_physical_repr().rolling_min(options)?.cast(dt);
352
},
353
dt => {
354
polars_ensure!(
355
dt.is_primitive_numeric() && !dt.is_unknown(),
356
op = "rolling_min",
357
dt
358
);
359
},
360
}
361
362
with_match_physical_numeric_polars_type!(dt, |$T| {
363
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
364
rolling_agg(
365
ca,
366
options,
367
&rolling::no_nulls::rolling_min,
368
&rolling::nulls::rolling_min,
369
)
370
})
371
}
372
373
/// Apply a rolling max to a Series based on another Series.
374
#[cfg(feature = "rolling_window_by")]
375
fn rolling_max_by(
376
&self,
377
by: &Series,
378
options: RollingOptionsDynamicWindow,
379
) -> PolarsResult<Series> {
380
let s = self.as_series().clone();
381
382
let dt = s.dtype();
383
match dt {
384
// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.
385
&DataType::Boolean => {
386
return s
387
.cast(&DataType::UInt8)?
388
.rolling_max_by(by, options)?
389
.cast(&DataType::Boolean);
390
},
391
dt if dt.is_temporal() => {
392
return s.to_physical_repr().rolling_max_by(by, options)?.cast(dt);
393
},
394
dt => {
395
polars_ensure!(
396
dt.is_primitive_numeric() && !dt.is_unknown(),
397
op = "rolling_max_by",
398
dt
399
);
400
},
401
}
402
403
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
404
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
405
rolling_agg_by(
406
ca,
407
by,
408
options,
409
&super::rolling_kernels::no_nulls::rolling_max,
410
)
411
})
412
}
413
414
/// Apply a rolling max to a Series.
415
#[cfg(feature = "rolling_window")]
416
fn rolling_max(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
417
let mut s = self.as_series().clone();
418
if options.weights.is_some() {
419
s = s.to_float()?;
420
}
421
422
let dt = s.dtype();
423
match dt {
424
// Our rolling kernels don't yet support boolean, use UInt8 as a workaround for now.
425
&DataType::Boolean => {
426
return s
427
.cast(&DataType::UInt8)?
428
.rolling_max(options)?
429
.cast(&DataType::Boolean);
430
},
431
dt if dt.is_temporal() => {
432
return s.to_physical_repr().rolling_max(options)?.cast(dt);
433
},
434
dt => {
435
polars_ensure!(
436
dt.is_primitive_numeric() && !dt.is_unknown(),
437
op = "rolling_max",
438
dt
439
);
440
},
441
}
442
443
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
444
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
445
rolling_agg(
446
ca,
447
options,
448
&rolling::no_nulls::rolling_max,
449
&rolling::nulls::rolling_max,
450
)
451
})
452
}
453
454
/// Apply a rolling variance to a Series based on another Series.
455
#[cfg(feature = "rolling_window_by")]
456
fn rolling_var_by(
457
&self,
458
by: &Series,
459
options: RollingOptionsDynamicWindow,
460
) -> PolarsResult<Series> {
461
let s = self.as_series().to_float()?;
462
463
with_match_physical_float_polars_type!(s.dtype(), |$T| {
464
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
465
let mut ca = ca.clone();
466
467
rolling_agg_by(
468
&ca,
469
by,
470
options,
471
&super::rolling_kernels::no_nulls::rolling_var,
472
)
473
})
474
}
475
476
/// Apply a rolling variance to a Series.
477
#[cfg(feature = "rolling_window")]
478
fn rolling_var(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
479
let s = self.as_series().to_float()?;
480
481
with_match_physical_float_polars_type!(s.dtype(), |$T| {
482
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
483
let mut ca = ca.clone();
484
485
rolling_agg(
486
&ca,
487
options,
488
&rolling::no_nulls::rolling_var,
489
&rolling::nulls::rolling_var,
490
)
491
})
492
}
493
494
/// Apply a rolling std_dev to a Series based on another Series.
495
#[cfg(feature = "rolling_window_by")]
496
fn rolling_std_by(
497
&self,
498
by: &Series,
499
options: RollingOptionsDynamicWindow,
500
) -> PolarsResult<Series> {
501
self.rolling_var_by(by, options).map(|mut s| {
502
match s.dtype().clone() {
503
DataType::Float32 => {
504
let ca: &mut ChunkedArray<Float32Type> = s._get_inner_mut().as_mut();
505
ca.apply_mut(|v| v.powf(0.5))
506
},
507
DataType::Float64 => {
508
let ca: &mut ChunkedArray<Float64Type> = s._get_inner_mut().as_mut();
509
ca.apply_mut(|v| v.powf(0.5))
510
},
511
_ => unreachable!(),
512
}
513
s
514
})
515
}
516
517
/// Apply a rolling std_dev to a Series.
518
#[cfg(feature = "rolling_window")]
519
fn rolling_std(&self, options: RollingOptionsFixedWindow) -> PolarsResult<Series> {
520
self.rolling_var(options).map(|mut s| {
521
match s.dtype().clone() {
522
DataType::Float32 => {
523
let ca: &mut ChunkedArray<Float32Type> = s._get_inner_mut().as_mut();
524
ca.apply_mut(|v| v.powf(0.5))
525
},
526
DataType::Float64 => {
527
let ca: &mut ChunkedArray<Float64Type> = s._get_inner_mut().as_mut();
528
ca.apply_mut(|v| v.powf(0.5))
529
},
530
_ => unreachable!(),
531
}
532
s
533
})
534
}
535
}
536
537
// `Series` takes every default method of `SeriesOpsTime` as-is; nothing is overridden.
impl SeriesOpsTime for Series {}
538
539