Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/nulls/quantile.rs
8421 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use super::*;
3
use crate::rolling::quantile_filter::SealedRolling;
4
5
pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
6
sorted: SortedBufNulls<'a, T>,
7
prob: f64,
8
method: QuantileMethod,
9
}
10
11
impl<
12
T: NativeType
13
+ IsFloat
14
+ Float
15
+ std::iter::Sum
16
+ AddAssign
17
+ SubAssign
18
+ Div<Output = T>
19
+ NumCast
20
+ One
21
+ Zero
22
+ SealedRolling
23
+ PartialOrd
24
+ Sub<Output = T>,
25
> RollingAggWindowNulls<T> for QuantileWindow<'_, T>
26
{
27
type This<'a> = QuantileWindow<'a, T>;
28
29
fn new<'a>(
30
slice: &'a [T],
31
validity: &'a Bitmap,
32
start: usize,
33
end: usize,
34
params: Option<RollingFnParams>,
35
window_size: Option<usize>,
36
) -> Self::This<'a> {
37
let params = params.unwrap();
38
let RollingFnParams::Quantile(params) = params else {
39
unreachable!("expected Quantile params");
40
};
41
QuantileWindow {
42
sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
43
prob: params.prob,
44
method: params.method,
45
}
46
}
47
48
unsafe fn update(&mut self, new_start: usize, new_end: usize) {
49
self.sorted.update(new_start, new_end);
50
}
51
52
fn get_agg(&self, _idx: usize) -> Option<T> {
53
let mut length = self.sorted.len();
54
let null_count = self.sorted.null_count;
55
56
// The min periods_issue will be taken care of when actually rolling
57
if null_count == length {
58
return None;
59
}
60
// Nulls are guaranteed to be at the front
61
length -= null_count;
62
let mut idx = match self.method {
63
QuantileMethod::Nearest => (((length as f64) - 1.0) * self.prob).round() as usize,
64
QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
65
((length as f64 - 1.0) * self.prob).floor() as usize
66
},
67
QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
68
QuantileMethod::Equiprobable => {
69
((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
70
},
71
};
72
73
idx = std::cmp::min(idx, length - 1);
74
75
// we can unwrap because we sliced of the nulls
76
match self.method {
77
QuantileMethod::Midpoint => {
78
let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
79
80
debug_assert!(idx <= top_idx);
81
let v = if idx != top_idx {
82
let low = self.sorted.get(idx + null_count).unwrap();
83
let high = self.sorted.get(idx + null_count + 1).unwrap();
84
(low + high) / T::from::<f64>(2.0f64).unwrap()
85
} else {
86
self.sorted.get(idx + null_count).unwrap()
87
};
88
89
Some(v)
90
},
91
QuantileMethod::Linear => {
92
let float_idx = (length as f64 - 1.0) * self.prob;
93
let top_idx = f64::ceil(float_idx) as usize;
94
95
if top_idx == idx {
96
Some(self.sorted.get(idx + null_count).unwrap())
97
} else {
98
let low = self.sorted.get(idx + null_count).unwrap();
99
let high = self.sorted.get(top_idx + null_count).unwrap();
100
let proportion = T::from(float_idx - idx as f64).unwrap();
101
Some(proportion * (high - low) + low)
102
}
103
},
104
_ => Some(self.sorted.get(idx + null_count).unwrap()),
105
}
106
}
107
108
fn is_valid(&self, min_periods: usize) -> bool {
109
self.sorted.is_valid(min_periods)
110
}
111
112
fn slice_len(&self) -> usize {
113
self.sorted.slice_len()
114
}
115
}
116
117
pub fn rolling_quantile<T>(
118
arr: &PrimitiveArray<T>,
119
window_size: usize,
120
min_periods: usize,
121
center: bool,
122
weights: Option<&[f64]>,
123
params: Option<RollingFnParams>,
124
) -> ArrayRef
125
where
126
T: NativeType
127
+ IsFloat
128
+ Float
129
+ std::iter::Sum
130
+ AddAssign
131
+ SubAssign
132
+ Div<Output = T>
133
+ NumCast
134
+ One
135
+ Zero
136
+ SealedRolling
137
+ PartialOrd
138
+ Sub<Output = T>,
139
{
140
if weights.is_some() {
141
panic!("weights not yet supported on array with null values")
142
}
143
let offset_fn = match center {
144
true => det_offsets_center,
145
false => det_offsets,
146
};
147
/*
148
TODO: fix or remove the dancing links based rolling implementation
149
see https://github.com/pola-rs/polars/issues/23480
150
if !center {
151
let params = params.as_ref().unwrap();
152
let RollingFnParams::Quantile(params) = params else {
153
unreachable!("expected Quantile params");
154
};
155
156
let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
157
params.method,
158
min_periods,
159
window_size,
160
arr.clone(),
161
params.prob,
162
);
163
let out: PrimitiveArray<T> = out.into();
164
return Box::new(out);
165
}
166
*/
167
rolling_apply_agg_window::<QuantileWindow<T>, _, _, _>(
168
arr.values().as_slice(),
169
arr.validity().as_ref().unwrap(),
170
window_size,
171
min_periods,
172
offset_fn,
173
params,
174
)
175
}
176
177
#[cfg(test)]
mod test {
    use arrow::datatypes::ArrowDataType;
    use polars_buffer::Buffer;

    use super::*;

    /// Downcasts a rolling-kernel result to `f64` and materialises it as optional values.
    fn as_options(out: ArrayRef) -> Vec<Option<f64>> {
        let prim = out
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        prim.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(vec![1.0, 2.0, 3.0, 4.0]),
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        // (window_size, min_periods, center, expected output)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, None, None, Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(2.0), Some(3.0)]),
            (4, 1, true, [Some(1.0), Some(2.0), Some(3.0), Some(3.5)]),
            (4, 4, true, [None, None, None, None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let got = as_options(rolling_quantile(
                arr,
                window_size,
                min_periods,
                center,
                None,
                med_pars,
            ));
            assert_eq!(got, &expected);
        }
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        // The 0.0 / 1.0 quantiles must agree with rolling min / max for every method.
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        let quantile_at = |prob: f64, method: QuantileMethod| {
            let pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                prob,
                method,
            }));
            as_options(rolling_quantile(values, 2, 1, false, None, pars))
        };
        let expected_min = as_options(rolling_min(values, 2, 1, false, None, None));
        let expected_max = as_options(rolling_max(values, 2, 1, false, None, None));

        for method in methods {
            assert_eq!(expected_min, quantile_at(0.0, method));
            assert_eq!(expected_max, quantile_at(1.0, method));
        }
    }
}
269
270