Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/min_max.rs
6939 views
1
use polars_utils::min_max::{MaxPropagateNan, MinMaxPolicy, MinPropagateNan};
2
3
use super::super::min_max::MinMaxWindow;
4
use super::*;
5
6
pub type MinWindow<'a, T> = MinMaxWindow<'a, T, MinPropagateNan>;
7
pub type MaxWindow<'a, T> = MinMaxWindow<'a, T, MaxPropagateNan>;
8
9
fn weighted_min_max<T, P>(values: &[T], weights: &[T]) -> T
10
where
11
T: NativeType + std::ops::Mul<Output = T>,
12
P: MinMaxPolicy,
13
{
14
values
15
.iter()
16
.zip(weights)
17
.map(|(v, w)| *v * *w)
18
.reduce(P::best)
19
.unwrap()
20
}
21
22
/// Generates `pub fn $rolling_m`, a rolling min/max kernel over a slice of
/// values, where `$policy` decides which element of each window wins.
///
/// * `weights == None`: each window is aggregated with
///   `MinMaxWindow<T, $policy>` via `rolling_apply_agg_window`.
/// * `weights == Some(w)`: the `f64` weights are cast to `T`. Each window's
///   values are multiplied element-wise by them and reduced with
///   `weighted_min_max`. This path asserts that `T` is a float type.
///
/// `center` chooses between `det_offsets_center` and `det_offsets` for the
/// window bounds. `_params` is accepted but not used.
macro_rules! rolling_minmax_func {
    ($rolling_m:ident, $policy:ident) => {
        pub fn $rolling_m<T>(
            values: &[T],
            window_size: usize,
            min_periods: usize,
            center: bool,
            weights: Option<&[f64]>,
            _params: Option<RollingFnParams>,
        ) -> PolarsResult<ArrayRef>
        where
            T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,
        {
            let offset_fn = if center {
                det_offsets_center
            } else {
                det_offsets
            };

            // Unweighted case: delegate to the sliding-window aggregator.
            let Some(weights) = weights else {
                return rolling_apply_agg_window::<MinMaxWindow<T, $policy>, _, _>(
                    values,
                    window_size,
                    min_periods,
                    offset_fn,
                    None,
                );
            };

            assert!(
                T::is_float(),
                "implementation error, should only be reachable by float types"
            );
            // Weights arrive as f64; convert them once into the value type.
            let weights: Vec<T> = weights
                .iter()
                .map(|&w| NumCast::from(w).unwrap())
                .collect();
            no_nulls::rolling_apply_weights(
                values,
                window_size,
                min_periods,
                offset_fn,
                weighted_min_max::<T, $policy>,
                &weights,
                center,
            )
        }
    };
}
70
71
rolling_minmax_func!(rolling_min, MinPropagateNan);
72
rolling_minmax_func!(rolling_max, MaxPropagateNan);
73
74
#[cfg(test)]
mod test {
    use super::*;

    /// Unwraps a kernel result and collects it as `Option<f64>`s (null -> `None`).
    fn collect_f64(out: PolarsResult<ArrayRef>) -> Vec<Option<f64>> {
        let out = out.unwrap();
        let arr = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        arr.into_iter().map(|v| v.copied()).collect()
    }

    /// Asserts element-wise equality, treating two NaNs as equal
    /// (plain `==` would report `NaN != NaN`).
    fn assert_eq_nan_aware(actual: &[Option<f64>], expected: &[Option<f64>]) {
        let equal = actual.len() == expected.len()
            && actual.iter().zip(expected).all(|(a, e)| match (a, e) {
                (Some(a), Some(e)) => a == e || (a.is_nan() && e.is_nan()),
                (None, None) => true,
                _ => false,
            });
        assert!(equal, "{actual:?} != {expected:?}");
    }

    #[test]
    fn test_rolling_min_max() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        let out = collect_f64(rolling_min(values, 2, 2, false, None, None));
        assert_eq!(out, &[None, Some(1.0), Some(3.0), Some(3.0)]);
        let out = collect_f64(rolling_max(values, 2, 2, false, None, None));
        assert_eq!(out, &[None, Some(5.0), Some(5.0), Some(4.0)]);

        let out = collect_f64(rolling_min(values, 2, 1, false, None, None));
        assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.0)]);
        let out = collect_f64(rolling_max(values, 2, 1, false, None, None));
        assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(4.0)]);

        let out = collect_f64(rolling_max(values, 3, 1, false, None, None));
        assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(5.0)]);

        // NaN handling: every window that contains the NaN yields NaN.
        let values = &[1.0, 2.0, 3.0, f64::NAN, 5.0, 6.0, 7.0];

        let out = collect_f64(rolling_min(values, 3, 3, false, None, None));
        assert_eq_nan_aware(
            &out,
            &[
                None,
                None,
                Some(1.0),
                Some(f64::NAN),
                Some(f64::NAN),
                Some(f64::NAN),
                Some(5.0),
            ],
        );

        let out = collect_f64(rolling_max(values, 3, 3, false, None, None));
        assert_eq_nan_aware(
            &out,
            &[
                None,
                None,
                Some(3.0),
                Some(f64::NAN),
                Some(f64::NAN),
                Some(f64::NAN),
                Some(7.0),
            ],
        );
    }

    #[test]
    fn test_rolling_min_max_weighted() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];
        let weights: &[f64] = &[1.0, 2.0];

        // Each full window [a, b] is reduced over the products [a * 1.0, b * 2.0];
        // the first slot has fewer than `min_periods` values and is null.
        let out = collect_f64(rolling_min(values, 2, 2, false, Some(weights), None));
        assert_eq!(out, &[None, Some(1.0), Some(5.0), Some(3.0)]);
        let out = collect_f64(rolling_max(values, 2, 2, false, Some(weights), None));
        assert_eq!(out, &[None, Some(10.0), Some(6.0), Some(8.0)]);
    }
}
147
148