Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/no_nulls/quantile.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::legacy::utils::CustomIterTools;
3
use num_traits::ToPrimitive;
4
use polars_error::polars_ensure;
5
6
use super::QuantileMethod::*;
7
use super::*;
8
use crate::rolling::quantile_filter::SealedRolling;
9
10
pub struct QuantileWindow<'a, T: NativeType> {
11
sorted: SortedBuf<'a, T>,
12
prob: f64,
13
method: QuantileMethod,
14
}
15
16
impl<
17
'a,
18
T: NativeType
19
+ Float
20
+ std::iter::Sum
21
+ AddAssign
22
+ SubAssign
23
+ Div<Output = T>
24
+ NumCast
25
+ One
26
+ Zero
27
+ SealedRolling
28
+ Sub<Output = T>,
29
> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
30
{
31
fn new(
32
slice: &'a [T],
33
start: usize,
34
end: usize,
35
params: Option<RollingFnParams>,
36
window_size: Option<usize>,
37
) -> Self {
38
let params = params.unwrap();
39
let RollingFnParams::Quantile(params) = params else {
40
unreachable!("expected Quantile params");
41
};
42
43
Self {
44
sorted: SortedBuf::new(slice, start, end, window_size),
45
prob: params.prob,
46
method: params.method,
47
}
48
}
49
50
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
51
self.sorted.update(start, end);
52
let length = self.sorted.len();
53
54
let idx = match self.method {
55
Linear => {
56
// Maybe add a fast path for median case? They could branch depending on odd/even.
57
let length_f = length as f64;
58
let idx = ((length_f - 1.0) * self.prob).floor() as usize;
59
60
let float_idx_top = (length_f - 1.0) * self.prob;
61
let top_idx = float_idx_top.ceil() as usize;
62
return if idx == top_idx {
63
Some(self.sorted.get(idx))
64
} else {
65
let proportion = T::from(float_idx_top - idx as f64).unwrap();
66
let mut vals = self.sorted.index_range(idx..top_idx + 1);
67
let vi = *vals.next().unwrap();
68
let vj = *vals.next().unwrap();
69
70
Some(proportion * (vj - vi) + vi)
71
};
72
},
73
Midpoint => {
74
let length_f = length as f64;
75
let idx = (length_f * self.prob) as usize;
76
let idx = std::cmp::min(idx, length - 1);
77
78
let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;
79
return if top_idx == idx {
80
Some(self.sorted.get(idx))
81
} else {
82
let top_idx = idx + 1;
83
let mut vals = self.sorted.index_range(idx..top_idx + 1);
84
let mid = *vals.next().unwrap();
85
let mid_plus_1 = *vals.next().unwrap();
86
87
Some((mid + mid_plus_1) / (T::one() + T::one()))
88
};
89
},
90
Nearest => {
91
let idx = (((length as f64) - 1.0) * self.prob).round() as usize;
92
std::cmp::min(idx, length - 1)
93
},
94
Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,
95
Higher => {
96
let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
97
std::cmp::min(idx, length - 1)
98
},
99
Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
100
};
101
102
Some(self.sorted.get(idx))
103
}
104
}
105
106
pub fn rolling_quantile<T>(
107
values: &[T],
108
window_size: usize,
109
min_periods: usize,
110
center: bool,
111
weights: Option<&[f64]>,
112
params: Option<RollingFnParams>,
113
) -> PolarsResult<ArrayRef>
114
where
115
T: NativeType
116
+ IsFloat
117
+ Float
118
+ std::iter::Sum
119
+ AddAssign
120
+ SubAssign
121
+ Div<Output = T>
122
+ NumCast
123
+ One
124
+ Zero
125
+ SealedRolling
126
+ PartialOrd
127
+ Sub<Output = T>,
128
{
129
let offset_fn = match center {
130
true => det_offsets_center,
131
false => det_offsets,
132
};
133
match weights {
134
None => {
135
if !center {
136
let params = params.as_ref().unwrap();
137
let RollingFnParams::Quantile(params) = params else {
138
unreachable!("expected Quantile params");
139
};
140
let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
141
params.method,
142
min_periods,
143
window_size,
144
values,
145
params.prob,
146
);
147
let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
148
return Ok(Box::new(PrimitiveArray::new(
149
T::PRIMITIVE.into(),
150
out.into(),
151
validity.map(|b| b.into()),
152
)));
153
}
154
155
rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
156
values,
157
window_size,
158
min_periods,
159
offset_fn,
160
params,
161
)
162
},
163
Some(weights) => {
164
let wsum = weights.iter().sum();
165
polars_ensure!(
166
wsum != 0.0,
167
ComputeError: "Weighted quantile is undefined if weights sum to 0"
168
);
169
let params = params.unwrap();
170
let RollingFnParams::Quantile(params) = params else {
171
unreachable!("expected Quantile params");
172
};
173
174
Ok(rolling_apply_weighted_quantile(
175
values,
176
params.prob,
177
params.method,
178
window_size,
179
min_periods,
180
offset_fn,
181
weights,
182
wsum,
183
))
184
},
185
}
186
}
187
188
#[inline]
189
fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
190
where
191
T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
192
{
193
// There are a few ways to compute a weighted quantile but no "canonical" way.
194
// This is mostly taken from the Julia implementation which was readable and reasonable
195
// https://juliastats.org/StatsBase.jl/stable/scalarstats/#Quantile-and-Related-Functions-1
196
let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());
197
198
// Once the cumulative weight crosses h, we've found our ind{ex/ices}. The definition may look
199
// odd but it's the equivalent of taking h = p * (n - 1) + 1 if your data is indexed from 1.
200
let h: f64 = p * (wsum - buf[0].1) + buf[0].1;
201
for &(v, w) in buf.iter() {
202
if s > h {
203
break;
204
}
205
(s_old, v_old, vk) = (s, vk, v);
206
s += w;
207
}
208
match (h == s_old, method) {
209
(true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter
210
(_, Lower) => v_old,
211
(_, Higher) => vk,
212
(_, Nearest) => {
213
if s - h > h - s_old {
214
v_old
215
} else {
216
vk
217
}
218
},
219
(_, Equiprobable) => {
220
let threshold = (wsum * p).ceil() - 1.0;
221
if s > threshold { vk } else { v_old }
222
},
223
(_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
224
// This is seemingly the canonical way to do it.
225
(_, Linear) => {
226
v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)
227
},
228
}
229
}
230
231
#[allow(clippy::too_many_arguments)]
232
fn rolling_apply_weighted_quantile<T, Fo>(
233
values: &[T],
234
p: f64,
235
method: QuantileMethod,
236
window_size: usize,
237
min_periods: usize,
238
det_offsets_fn: Fo,
239
weights: &[f64],
240
wsum: f64,
241
) -> ArrayRef
242
where
243
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
244
T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
245
{
246
assert_eq!(weights.len(), window_size);
247
// Keep nonzero weights and their indices to know which values we need each iteration.
248
let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();
249
let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];
250
let len = values.len();
251
let out = (0..len)
252
.map(|idx| {
253
// Don't need end. Window size is constant and we computed offsets from start above.
254
let (start, _) = det_offsets_fn(idx, window_size, len);
255
256
// Sorting is not ideal, see https://github.com/tobiasschoch/wquantile for something faster
257
unsafe {
258
buf.iter_mut()
259
.zip(nz_idx_wts.iter())
260
.for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
261
}
262
buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
263
compute_wq(&buf, p, wsum, method)
264
})
265
.collect_trusted::<Vec<T>>();
266
267
let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
268
Box::new(PrimitiveArray::new(
269
T::PRIMITIVE.into(),
270
out.into(),
271
validity.map(|b| b.into()),
272
))
273
}
274
275
#[cfg(test)]
mod test {
    use super::*;

    /// Downcasts a rolling result to `f64` values and collects them as options.
    fn collect_f64(arr: ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    /// Wraps a probability and method into rolling quantile parameters.
    fn quantile_params(prob: f64, method: QuantileMethod) -> Option<RollingFnParams> {
        Some(RollingFnParams::Quantile(RollingQuantileParams { prob, method }))
    }

    #[test]
    fn test_rolling_median() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];
        let med_pars = quantile_params(0.5, Linear);

        // (window_size, min_periods, center, expected)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, Some(1.5), Some(2.5), Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.5), Some(2.5), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.5), Some(2.0), Some(2.5)]),
            (4, 1, true, [Some(1.5), Some(2.0), Some(2.5), Some(3.0)]),
            (4, 4, true, [None, None, Some(2.5), None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let out = collect_f64(
                rolling_quantile(values, window_size, min_periods, center, None, med_pars)
                    .unwrap(),
            );
            assert_eq!(
                out, expected,
                "window_size={window_size}, min_periods={min_periods}, center={center}"
            );
        }
    }

    #[test]
    fn test_rolling_quantile_limits() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            // prob = 0 must agree with the rolling minimum.
            let expected_min = collect_f64(rolling_min(values, 2, 2, false, None, None).unwrap());
            let got_min = collect_f64(
                rolling_quantile(values, 2, 2, false, None, quantile_params(0.0, method))
                    .unwrap(),
            );
            assert_eq!(expected_min, got_min);

            // prob = 1 must agree with the rolling maximum.
            let expected_max = collect_f64(rolling_max(values, 2, 2, false, None, None).unwrap());
            let got_max = collect_f64(
                rolling_quantile(values, 2, 2, false, None, quantile_params(1.0, method))
                    .unwrap(),
            );
            assert_eq!(expected_max, got_max);
        }
    }
}
352
353