Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/nulls/quantile.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use super::*;
3
use crate::rolling::quantile_filter::SealedRolling;
4
5
pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
6
sorted: SortedBufNulls<'a, T>,
7
prob: f64,
8
method: QuantileMethod,
9
}
10
11
impl<
12
'a,
13
T: NativeType
14
+ IsFloat
15
+ Float
16
+ std::iter::Sum
17
+ AddAssign
18
+ SubAssign
19
+ Div<Output = T>
20
+ NumCast
21
+ One
22
+ Zero
23
+ SealedRolling
24
+ PartialOrd
25
+ Sub<Output = T>,
26
> RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T>
27
{
28
unsafe fn new(
29
slice: &'a [T],
30
validity: &'a Bitmap,
31
start: usize,
32
end: usize,
33
params: Option<RollingFnParams>,
34
window_size: Option<usize>,
35
) -> Self {
36
let params = params.unwrap();
37
let RollingFnParams::Quantile(params) = params else {
38
unreachable!("expected Quantile params");
39
};
40
Self {
41
sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
42
prob: params.prob,
43
method: params.method,
44
}
45
}
46
47
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
48
let null_count = self.sorted.update(start, end);
49
let mut length = self.sorted.len();
50
// The min periods_issue will be taken care of when actually rolling
51
if null_count == length {
52
return None;
53
}
54
// Nulls are guaranteed to be at the front
55
length -= null_count;
56
let mut idx = match self.method {
57
QuantileMethod::Nearest => (((length as f64) - 1.0) * self.prob).round() as usize,
58
QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
59
((length as f64 - 1.0) * self.prob).floor() as usize
60
},
61
QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
62
QuantileMethod::Equiprobable => {
63
((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
64
},
65
};
66
67
idx = std::cmp::min(idx, length - 1);
68
69
// we can unwrap because we sliced of the nulls
70
match self.method {
71
QuantileMethod::Midpoint => {
72
let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
73
74
debug_assert!(idx <= top_idx);
75
let v = if idx != top_idx {
76
let mut vals = self
77
.sorted
78
.index_range(idx + null_count..top_idx + null_count + 1);
79
let low = vals.next().unwrap().unwrap();
80
let high = vals.next().unwrap().unwrap();
81
(low + high) / T::from::<f64>(2.0f64).unwrap()
82
} else {
83
self.sorted.get(idx + null_count).unwrap()
84
};
85
86
Some(v)
87
},
88
QuantileMethod::Linear => {
89
let float_idx = (length as f64 - 1.0) * self.prob;
90
let top_idx = f64::ceil(float_idx) as usize;
91
92
if top_idx == idx {
93
Some(self.sorted.get(idx + null_count).unwrap())
94
} else {
95
let mut vals = self
96
.sorted
97
.index_range(idx + null_count..top_idx + null_count + 1);
98
let low = vals.next().unwrap().unwrap();
99
let high = vals.next().unwrap().unwrap();
100
101
let proportion = T::from(float_idx - idx as f64).unwrap();
102
Some(proportion * (high - low) + low)
103
}
104
},
105
_ => Some(self.sorted.get(idx + null_count).unwrap()),
106
}
107
}
108
109
fn is_valid(&self, min_periods: usize) -> bool {
110
self.sorted.is_valid(min_periods)
111
}
112
}
113
114
pub fn rolling_quantile<T>(
115
arr: &PrimitiveArray<T>,
116
window_size: usize,
117
min_periods: usize,
118
center: bool,
119
weights: Option<&[f64]>,
120
params: Option<RollingFnParams>,
121
) -> ArrayRef
122
where
123
T: NativeType
124
+ IsFloat
125
+ Float
126
+ std::iter::Sum
127
+ AddAssign
128
+ SubAssign
129
+ Div<Output = T>
130
+ NumCast
131
+ One
132
+ Zero
133
+ SealedRolling
134
+ PartialOrd
135
+ Sub<Output = T>,
136
{
137
if weights.is_some() {
138
panic!("weights not yet supported on array with null values")
139
}
140
let offset_fn = match center {
141
true => det_offsets_center,
142
false => det_offsets,
143
};
144
/*
145
TODO: fix or remove the dancing links based rolling implementation
146
see https://github.com/pola-rs/polars/issues/23480
147
if !center {
148
let params = params.as_ref().unwrap();
149
let RollingFnParams::Quantile(params) = params else {
150
unreachable!("expected Quantile params");
151
};
152
153
let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
154
params.method,
155
min_periods,
156
window_size,
157
arr.clone(),
158
params.prob,
159
);
160
let out: PrimitiveArray<T> = out.into();
161
return Box::new(out);
162
}
163
*/
164
rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
165
arr.values().as_slice(),
166
arr.validity().as_ref().unwrap(),
167
window_size,
168
min_periods,
169
offset_fn,
170
params,
171
)
172
}
173
174
#[cfg(test)]
mod test {
    use arrow::buffer::Buffer;
    use arrow::datatypes::ArrowDataType;

    use super::*;

    /// Downcasts a rolling-kernel result to `f64` and copies it into a plain vector so
    /// it can be compared with `assert_eq!`.
    fn collect_f64(out: ArrayRef) -> Vec<Option<f64>> {
        out.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);
        // The second element is masked out as null.
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        let out = collect_f64(rolling_quantile(arr, 2, 2, false, None, med_pars));
        assert_eq!(out, &[None, None, None, Some(3.5)]);

        let out = collect_f64(rolling_quantile(arr, 2, 1, false, None, med_pars));
        assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.5)]);

        let out = collect_f64(rolling_quantile(arr, 4, 1, false, None, med_pars));
        assert_eq!(out, &[Some(1.0), Some(1.0), Some(2.0), Some(3.0)]);

        let out = collect_f64(rolling_quantile(arr, 4, 1, true, None, med_pars));
        assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(3.5)]);

        let out = collect_f64(rolling_quantile(arr, 4, 4, true, None, med_pars));
        assert_eq!(out, &[None, None, None, None]);
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        // compare quantiles to corresponding min/max/median values
        let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            // prob = 0.0 must agree with the rolling minimum for every method.
            let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                prob: 0.0,
                method,
            }));
            let out1 = collect_f64(rolling_min(values, 2, 1, false, None, None));
            let out2 = collect_f64(rolling_quantile(values, 2, 1, false, None, min_pars));
            assert_eq!(out1, out2);

            // prob = 1.0 must agree with the rolling maximum for every method.
            let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                prob: 1.0,
                method,
            }));
            let out1 = collect_f64(rolling_max(values, 2, 1, false, None, None));
            let out2 = collect_f64(rolling_quantile(values, 2, 1, false, None, max_pars));
            assert_eq!(out1, out2);
        }
    }
}
266
267