Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/quantile.rs
8424 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::legacy::utils::CustomIterTools;
3
use num_traits::ToPrimitive;
4
use polars_error::polars_ensure;
5
6
use super::QuantileMethod::*;
7
use super::*;
8
use crate::rolling::quantile_filter::SealedRolling;
9
10
pub struct QuantileWindow<'a, T: NativeType> {
11
sorted: SortedBuf<'a, T>,
12
prob: f64,
13
method: QuantileMethod,
14
}
15
16
impl<
17
T: NativeType
18
+ Float
19
+ std::iter::Sum
20
+ AddAssign
21
+ SubAssign
22
+ Div<Output = T>
23
+ NumCast
24
+ One
25
+ Zero
26
+ SealedRolling
27
+ Sub<Output = T>,
28
> RollingAggWindowNoNulls<T> for QuantileWindow<'_, T>
29
{
30
type This<'a> = QuantileWindow<'a, T>;
31
32
fn new<'a>(
33
slice: &'a [T],
34
start: usize,
35
end: usize,
36
params: Option<RollingFnParams>,
37
window_size: Option<usize>,
38
) -> Self::This<'a> {
39
let params = params.unwrap();
40
let RollingFnParams::Quantile(params) = params else {
41
unreachable!("expected Quantile params");
42
};
43
44
QuantileWindow {
45
sorted: SortedBuf::new(slice, start, end, window_size),
46
prob: params.prob,
47
method: params.method,
48
}
49
}
50
51
unsafe fn update(&mut self, start: usize, end: usize) {
52
self.sorted.update(start, end);
53
}
54
55
fn get_agg(&self, _idx: usize) -> Option<T> {
56
let length = self.sorted.len();
57
if length == 0 {
58
return None;
59
}
60
let idx = match self.method {
61
Linear => {
62
// Maybe add a fast path for median case? They could branch depending on odd/even.
63
let length_f = length as f64;
64
let idx = ((length_f - 1.0) * self.prob).floor() as usize;
65
66
let float_idx_top = (length_f - 1.0) * self.prob;
67
let top_idx = float_idx_top.ceil() as usize;
68
return if idx == top_idx {
69
Some(self.sorted.get(idx))
70
} else {
71
let proportion = T::from(float_idx_top - idx as f64).unwrap();
72
let vi = self.sorted.get(idx);
73
let vj = self.sorted.get(idx + 1);
74
Some(proportion * (vj - vi) + vi)
75
};
76
},
77
Midpoint => {
78
let length_f = length as f64;
79
80
let idx = ((length_f - 1.0) * self.prob).floor() as usize;
81
let idx = std::cmp::min(idx, length - 1);
82
let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;
83
84
return if top_idx == idx {
85
Some(self.sorted.get(idx))
86
} else {
87
let mid = self.sorted.get(idx);
88
let mid_plus_1 = self.sorted.get(idx + 1);
89
Some((mid + mid_plus_1) / (T::one() + T::one()))
90
};
91
},
92
Nearest => {
93
let idx = (((length as f64) - 1.0) * self.prob).round() as usize;
94
std::cmp::min(idx, length - 1)
95
},
96
Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,
97
Higher => {
98
let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
99
std::cmp::min(idx, length - 1)
100
},
101
Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
102
};
103
104
Some(self.sorted.get(idx))
105
}
106
107
fn slice_len(&self) -> usize {
108
self.sorted.slice_len()
109
}
110
}
111
112
pub fn rolling_quantile<T>(
113
values: &[T],
114
window_size: usize,
115
min_periods: usize,
116
center: bool,
117
weights: Option<&[f64]>,
118
params: Option<RollingFnParams>,
119
) -> PolarsResult<ArrayRef>
120
where
121
T: NativeType
122
+ IsFloat
123
+ Float
124
+ std::iter::Sum
125
+ AddAssign
126
+ SubAssign
127
+ Div<Output = T>
128
+ NumCast
129
+ One
130
+ Zero
131
+ SealedRolling
132
+ PartialOrd
133
+ Sub<Output = T>,
134
{
135
let offset_fn = match center {
136
true => det_offsets_center,
137
false => det_offsets,
138
};
139
match weights {
140
None => {
141
if !center {
142
let params = params.as_ref().unwrap();
143
let RollingFnParams::Quantile(params) = params else {
144
unreachable!("expected Quantile params");
145
};
146
let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
147
params.method,
148
min_periods,
149
window_size,
150
values,
151
params.prob,
152
);
153
let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
154
return Ok(Box::new(PrimitiveArray::new(
155
T::PRIMITIVE.into(),
156
out.into(),
157
validity.map(|b| b.into()),
158
)));
159
}
160
161
rolling_apply_agg_window::<QuantileWindow<_>, _, _, _>(
162
values,
163
window_size,
164
min_periods,
165
offset_fn,
166
params,
167
)
168
},
169
Some(weights) => {
170
let wsum = weights.iter().sum();
171
polars_ensure!(
172
wsum != 0.0,
173
ComputeError: "Weighted quantile is undefined if weights sum to 0"
174
);
175
let params = params.unwrap();
176
let RollingFnParams::Quantile(params) = params else {
177
unreachable!("expected Quantile params");
178
};
179
180
Ok(rolling_apply_weighted_quantile(
181
values,
182
params.prob,
183
params.method,
184
window_size,
185
min_periods,
186
offset_fn,
187
weights,
188
wsum,
189
))
190
},
191
}
192
}
193
194
#[inline]
195
fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
196
where
197
T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
198
{
199
// There are a few ways to compute a weighted quantile but no "canonical" way.
200
// This is mostly taken from the Julia implementation which was readable and reasonable
201
// https://juliastats.org/StatsBase.jl/stable/scalarstats/#Quantile-and-Related-Functions-1
202
let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());
203
204
// Once the cumulative weight crosses h, we've found our ind{ex/ices}. The definition may look
205
// odd but it's the equivalent of taking h = p * (n - 1) + 1 if your data is indexed from 1.
206
let h: f64 = p * (wsum - buf[0].1) + buf[0].1;
207
for &(v, w) in buf.iter() {
208
if s > h {
209
break;
210
}
211
(s_old, v_old, vk) = (s, vk, v);
212
s += w;
213
}
214
match (h == s_old, method) {
215
(true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter
216
(_, Lower) => v_old,
217
(_, Higher) => vk,
218
(_, Nearest) => {
219
if s - h > h - s_old {
220
v_old
221
} else {
222
vk
223
}
224
},
225
(_, Equiprobable) => {
226
let threshold = (wsum * p).ceil() - 1.0;
227
if s > threshold { vk } else { v_old }
228
},
229
(_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
230
// This is seemingly the canonical way to do it.
231
(_, Linear) => {
232
v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)
233
},
234
}
235
}
236
237
#[allow(clippy::too_many_arguments)]
238
fn rolling_apply_weighted_quantile<T, Fo>(
239
values: &[T],
240
p: f64,
241
method: QuantileMethod,
242
window_size: usize,
243
min_periods: usize,
244
det_offsets_fn: Fo,
245
weights: &[f64],
246
wsum: f64,
247
) -> ArrayRef
248
where
249
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
250
T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
251
{
252
assert_eq!(weights.len(), window_size);
253
// Keep nonzero weights and their indices to know which values we need each iteration.
254
let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();
255
let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];
256
let len = values.len();
257
let out = (0..len)
258
.map(|idx| {
259
// Don't need end. Window size is constant and we computed offsets from start above.
260
let (start, _) = det_offsets_fn(idx, window_size, len);
261
262
// Sorting is not ideal, see https://github.com/tobiasschoch/wquantile for something faster
263
unsafe {
264
buf.iter_mut()
265
.zip(nz_idx_wts.iter())
266
.for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
267
}
268
buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
269
compute_wq(&buf, p, wsum, method)
270
})
271
.collect_trusted::<Vec<T>>();
272
273
let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
274
Box::new(PrimitiveArray::new(
275
T::PRIMITIVE.into(),
276
out.into(),
277
validity.map(|b| b.into()),
278
))
279
}
280
281
#[cfg(test)]
mod test {
    use super::*;

    /// Unwraps a rolling kernel's result and collects it as nullable `f64`s.
    fn collect_f64(result: PolarsResult<ArrayRef>) -> Vec<Option<f64>> {
        let array = result.unwrap();
        let primitive = array
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        let collected: Vec<Option<f64>> = primitive.into_iter().map(|v| v.copied()).collect();
        collected
    }

    /// Runs the unweighted `rolling_quantile` and collects its output.
    fn quantile(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        center: bool,
        params: Option<RollingFnParams>,
    ) -> Vec<Option<f64>> {
        collect_f64(rolling_quantile(
            values,
            window_size,
            min_periods,
            center,
            None,
            params,
        ))
    }

    fn quantile_params(prob: f64, method: QuantileMethod) -> Option<RollingFnParams> {
        Some(RollingFnParams::Quantile(RollingQuantileParams { prob, method }))
    }

    #[test]
    fn test_rolling_median() {
        let values = &[1.0, 2.0, 3.0, 4.0];
        let median = quantile_params(0.5, Linear);

        // (window_size, min_periods, center, expected)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, Some(1.5), Some(2.5), Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.5), Some(2.5), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.5), Some(2.0), Some(2.5)]),
            (4, 1, true, [Some(1.5), Some(2.0), Some(2.5), Some(3.0)]),
            (4, 4, true, [None, None, Some(2.5), None]),
        ];
        for (window_size, min_periods, center, expected) in cases {
            assert_eq!(
                quantile(values, window_size, min_periods, center, median),
                expected,
                "window_size={window_size}, min_periods={min_periods}, center={center}"
            );
        }
    }

    #[test]
    fn test_rolling_quantile_limits() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];
        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            // prob = 0 must reproduce the rolling minimum...
            let expected_min = collect_f64(rolling_min(values, 2, 2, false, None, None));
            let lowest = quantile(values, 2, 2, false, quantile_params(0.0, method));
            assert_eq!(expected_min, lowest);

            // ...and prob = 1 the rolling maximum, whatever the method.
            let expected_max = collect_f64(rolling_max(values, 2, 2, false, None, None));
            let highest = quantile(values, 2, 2, false, quantile_params(1.0, method));
            assert_eq!(expected_max, highest);
        }
    }
}
358
359