Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/mod.rs
8424 views
1
use std::fmt::Debug;
2
3
use arrow::array::PrimitiveArray;
4
use arrow::datatypes::ArrowDataType;
5
use arrow::legacy::error::PolarsResult;
6
use arrow::legacy::utils::CustomIterTools;
7
use arrow::types::NativeType;
8
use num_traits::{Float, Num, NumCast};
9
10
mod mean;
11
mod min_max;
12
mod moment;
13
mod quantile;
14
pub mod rank;
15
mod sum;
16
17
pub use mean::*;
18
pub use min_max::*;
19
pub use moment::*;
20
pub use quantile::*;
21
pub use rank::*;
22
pub use sum::*;
23
24
use super::*;
25
26
pub trait RollingAggWindowNoNulls<T: NativeType, Out: NativeType = T> {
27
type This<'a>: RollingAggWindowNoNulls<T, Out>;
28
29
fn new(
30
slice: &[T],
31
start: usize,
32
end: usize,
33
params: Option<RollingFnParams>,
34
window_size: Option<usize>,
35
) -> Self::This<'_>;
36
37
/// Update and recompute the window
38
///
39
/// # Safety
40
/// `start` and `end` must be within the windows bounds
41
unsafe fn update(&mut self, new_start: usize, new_end: usize);
42
43
/// Get the aggregate of the current window relative to the value at `idx`.
44
fn get_agg(&self, idx: usize) -> Option<Out>;
45
46
/// Returns the length of the underlying input.
47
fn slice_len(&self) -> usize;
48
}
49
50
// Use an aggregation window that maintains the state
51
pub(super) fn rolling_apply_agg_window<Agg, T, O, Fo>(
52
values: &[T],
53
window_size: usize,
54
min_periods: usize,
55
det_offsets_fn: Fo,
56
params: Option<RollingFnParams>,
57
) -> PolarsResult<ArrayRef>
58
where
59
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
60
Agg: RollingAggWindowNoNulls<T, O>,
61
T: Debug + NativeType + Num,
62
O: Debug + NativeType + Num,
63
{
64
let len = values.len();
65
let (start, end) = det_offsets_fn(0, window_size, len);
66
let mut agg_window = Agg::new(values, start, end, params, Some(window_size));
67
let out = (0..len).map(|idx| {
68
let (start, end) = det_offsets_fn(idx, window_size, len);
69
if end - start < min_periods {
70
None
71
} else {
72
// SAFETY:
73
// we are in bounds
74
unsafe { agg_window.update(start, end) }
75
agg_window.get_agg(idx)
76
}
77
});
78
let arr = PrimitiveArray::from_trusted_len_iter(out);
79
Ok(Box::new(arr))
80
}
81
82
pub(super) fn rolling_apply_weights<T, Fo, Fa>(
83
values: &[T],
84
window_size: usize,
85
min_periods: usize,
86
det_offsets_fn: Fo,
87
aggregator: Fa,
88
weights: &[T],
89
centered: bool,
90
) -> PolarsResult<ArrayRef>
91
where
92
T: NativeType + num_traits::Zero + std::ops::Div<Output = T> + Copy,
93
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
94
Fa: Fn(&[T], &[T]) -> T,
95
{
96
assert_eq!(weights.len(), window_size);
97
let len = values.len();
98
let out = (0..len)
99
.map(|idx| {
100
let (start, end) = det_offsets_fn(idx, window_size, len);
101
let vals = unsafe { values.get_unchecked(start..end) };
102
let win_len = end - start;
103
let weights_start = if centered {
104
// When using centered weights, we need to find the right location
105
// in the weights array specifically by aligning the center of the
106
// window with idx, to handle cases where the window is smaller than
107
// weights array.
108
let center = (window_size / 2) as isize;
109
let offset = center - (idx as isize - start as isize);
110
offset.max(0) as usize
111
} else if start == 0 {
112
// When start is 0, we need to work backwards from the end of the
113
// weights array to ensure we are lined up correctly (since the
114
// start of the values array is implicitly cut off)
115
weights.len() - win_len
116
} else {
117
0
118
};
119
let weights_slice = &weights[weights_start..weights_start + win_len];
120
aggregator(vals, weights_slice)
121
})
122
.collect_trusted::<Vec<T>>();
123
124
let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
125
Ok(Box::new(PrimitiveArray::new(
126
ArrowDataType::from(T::PRIMITIVE),
127
out.into(),
128
validity.map(|b| b.into()),
129
)))
130
}
131
132
fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T
133
where
134
T: Float + std::ops::AddAssign,
135
{
136
// Compute weighted mean and weighted sum of squares in a single pass
137
let (wssq, wmean, total_weight) = vals.iter().zip(weights).fold(
138
(T::zero(), T::zero(), T::zero()),
139
|(wssq, wsum, wtot), (&v, &w)| (wssq + v * v * w, wsum + v * w, wtot + w),
140
);
141
if total_weight.is_zero() {
142
T::zero() // Will get masked to null.
143
} else {
144
let mean = wmean / total_weight;
145
(wssq / total_weight) - (mean * mean)
146
}
147
}
148
149
/// Weighted sum `sum_i(values[i] * weights[i])`.
///
/// Elements are paired position by position; if the slices differ in length, the
/// surplus elements of the longer one are ignored. Empty input yields the additive
/// identity produced by `Sum`.
pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T
where
    T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,
{
    std::iter::zip(values, weights)
        .map(|(&value, &weight)| value * weight)
        .sum()
}
155
156
/// Compute the weighted mean of values, given weights (not necessarily normalized).
157
/// Returns sum_i(values[i] * weights[i]) / sum_i(weights[i])
158
pub(crate) fn compute_mean_weights<T>(values: &[T], weights: &[T]) -> T
159
where
160
T: std::iter::Sum<T>
161
+ Copy
162
+ std::ops::Mul<Output = T>
163
+ std::ops::Div<Output = T>
164
+ num_traits::Zero,
165
{
166
let (weighted_sum, total_weight) = values
167
.iter()
168
.zip(weights)
169
.fold((T::zero(), T::zero()), |(wsum, wtot), (&v, &w)| {
170
(wsum + v * w, wtot + w)
171
});
172
if total_weight.is_zero() {
173
T::zero() // Will get masked to null.
174
} else {
175
weighted_sum / total_weight
176
}
177
}
178
179
pub(super) fn coerce_weights<T: NumCast>(weights: &[f64]) -> Vec<T>
180
where
181
{
182
weights
183
.iter()
184
.map(|v| NumCast::from(*v).unwrap())
185
.collect::<Vec<_>>()
186
}
187
188