Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/mod.rs
6939 views
1
mod mean;
2
mod min_max;
3
mod moment;
4
mod quantile;
5
mod sum;
6
use std::fmt::Debug;
7
8
use arrow::array::PrimitiveArray;
9
use arrow::datatypes::ArrowDataType;
10
use arrow::legacy::error::PolarsResult;
11
use arrow::legacy::utils::CustomIterTools;
12
use arrow::types::NativeType;
13
pub use mean::*;
14
pub use min_max::*;
15
pub use moment::*;
16
use num_traits::{Float, Num, NumCast};
17
pub use quantile::*;
18
pub use sum::*;
19
20
use super::*;
21
22
pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
23
fn new(
24
slice: &'a [T],
25
start: usize,
26
end: usize,
27
params: Option<RollingFnParams>,
28
window_size: Option<usize>,
29
) -> Self;
30
31
/// Update and recompute the window
32
///
33
/// # Safety
34
/// `start` and `end` must be within the windows bounds
35
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
36
}
37
38
// Use an aggregation window that maintains the state
39
pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
40
values: &'a [T],
41
window_size: usize,
42
min_periods: usize,
43
det_offsets_fn: Fo,
44
params: Option<RollingFnParams>,
45
) -> PolarsResult<ArrayRef>
46
where
47
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
48
Agg: RollingAggWindowNoNulls<'a, T>,
49
T: Debug + NativeType + Num,
50
{
51
let len = values.len();
52
let (start, end) = det_offsets_fn(0, window_size, len);
53
let mut agg_window = Agg::new(values, start, end, params, Some(window_size));
54
if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) {
55
if validity.iter().all(|x| !x) {
56
return Ok(Box::new(PrimitiveArray::<T>::new_null(
57
T::PRIMITIVE.into(),
58
len,
59
)));
60
}
61
}
62
63
let out = (0..len).map(|idx| {
64
let (start, end) = det_offsets_fn(idx, window_size, len);
65
if end - start < min_periods {
66
None
67
} else {
68
// SAFETY:
69
// we are in bounds
70
unsafe { agg_window.update(start, end) }
71
}
72
});
73
let arr = PrimitiveArray::from_trusted_len_iter(out);
74
Ok(Box::new(arr))
75
}
76
77
pub(super) fn rolling_apply_weights<T, Fo, Fa>(
78
values: &[T],
79
window_size: usize,
80
min_periods: usize,
81
det_offsets_fn: Fo,
82
aggregator: Fa,
83
weights: &[T],
84
centered: bool,
85
) -> PolarsResult<ArrayRef>
86
where
87
T: NativeType + num_traits::Zero + std::ops::Div<Output = T> + Copy,
88
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
89
Fa: Fn(&[T], &[T]) -> T,
90
{
91
assert_eq!(weights.len(), window_size);
92
let len = values.len();
93
let out = (0..len)
94
.map(|idx| {
95
let (start, end) = det_offsets_fn(idx, window_size, len);
96
let vals = unsafe { values.get_unchecked(start..end) };
97
let win_len = end - start;
98
let weights_start = if centered {
99
// When using centered weights, we need to find the right location
100
// in the weights array specifically by aligning the center of the
101
// window with idx, to handle cases where the window is smaller than
102
// weights array.
103
let center = (window_size / 2) as isize;
104
let offset = center - (idx as isize - start as isize);
105
offset.max(0) as usize
106
} else if start == 0 {
107
// When start is 0, we need to work backwards from the end of the
108
// weights array to ensure we are lined up correctly (since the
109
// start of the values array is implicitly cut off)
110
weights.len() - win_len
111
} else {
112
0
113
};
114
let weights_slice = &weights[weights_start..weights_start + win_len];
115
aggregator(vals, weights_slice)
116
})
117
.collect_trusted::<Vec<T>>();
118
119
let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
120
Ok(Box::new(PrimitiveArray::new(
121
ArrowDataType::from(T::PRIMITIVE),
122
out.into(),
123
validity.map(|b| b.into()),
124
)))
125
}
126
127
fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T
128
where
129
T: Float + std::ops::AddAssign,
130
{
131
// Compute weighted mean and weighted sum of squares in a single pass
132
let (wssq, wmean, total_weight) = vals.iter().zip(weights).fold(
133
(T::zero(), T::zero(), T::zero()),
134
|(wssq, wsum, wtot), (&v, &w)| (wssq + v * v * w, wsum + v * w, wtot + w),
135
);
136
if total_weight.is_zero() {
137
panic!("Weighted variance is undefined if weights sum to 0");
138
}
139
let mean = wmean / total_weight;
140
(wssq / total_weight) - (mean * mean)
141
}
142
143
/// Computes the weighted sum of values: `sum_i(values[i] * weights[i])`.
///
/// The slices are paired element-wise with `zip`, so if their lengths differ
/// the surplus trailing elements of the longer slice are ignored. Empty input
/// yields the additive identity of `T` as produced by `Sum`.
pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T
where
    T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,
{
    values.iter().zip(weights).map(|(&v, &w)| v * w).sum()
}
149
150
/// Compute the weighted mean of values, given weights (not necessarily normalized).
151
/// Returns sum_i(values[i] * weights[i]) / sum_i(weights[i])
152
pub(crate) fn compute_mean_weights<T>(values: &[T], weights: &[T]) -> T
153
where
154
T: std::iter::Sum<T>
155
+ Copy
156
+ std::ops::Mul<Output = T>
157
+ std::ops::Div<Output = T>
158
+ num_traits::Zero,
159
{
160
let (weighted_sum, total_weight) = values
161
.iter()
162
.zip(weights)
163
.fold((T::zero(), T::zero()), |(wsum, wtot), (&v, &w)| {
164
(wsum + v * w, wtot + w)
165
});
166
if total_weight.is_zero() {
167
panic!("Weighted mean is undefined if weights sum to 0");
168
}
169
weighted_sum / total_weight
170
}
171
172
pub(super) fn coerce_weights<T: NumCast>(weights: &[f64]) -> Vec<T>
173
where
174
{
175
weights
176
.iter()
177
.map(|v| NumCast::from(*v).unwrap())
178
.collect::<Vec<_>>()
179
}
180
181