Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/mod.rs
6939 views
1
mod min_max;
2
pub mod moment;
3
pub mod no_nulls;
4
pub mod nulls;
5
pub mod quantile_filter;
6
pub(super) mod window;
7
use std::hash::Hash;
8
use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
9
10
use arrow::array::{ArrayRef, PrimitiveArray};
11
use arrow::bitmap::{Bitmap, MutableBitmap};
12
use arrow::types::NativeType;
13
use num_traits::{Bounded, Float, NumCast, One, Zero};
14
use polars_utils::float::IsFloat;
15
#[cfg(feature = "serde")]
16
use serde::{Deserialize, Serialize};
17
use strum_macros::IntoStaticStr;
18
use window::*;
19
20
// Readability aliases for the `usize` values passed to and returned from the
// window-offset functions (`det_offsets`, `det_offsets_center`) and the
// `det_offsets_fn` callback taken by `create_validity`.

/// Row whose window is being computed.
type Idx = usize;
/// Number of rows a full window spans.
type WindowSize = usize;
/// Total number of rows in the input.
type Len = usize;
/// First row of a window (inclusive).
type Start = usize;
/// One past the last row of a window (exclusive).
type End = usize;
25
26
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
27
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
28
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
29
#[strum(serialize_all = "snake_case")]
30
pub enum QuantileMethod {
31
#[default]
32
Nearest,
33
Lower,
34
Higher,
35
Midpoint,
36
Linear,
37
Equiprobable,
38
}
39
40
#[deprecated(note = "use QuantileMethod instead")]
41
pub type QuantileInterpolOptions = QuantileMethod;
42
43
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
44
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
45
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
46
pub enum RollingFnParams {
47
Quantile(RollingQuantileParams),
48
Var(RollingVarParams),
49
Skew { bias: bool },
50
Kurtosis { fisher: bool, bias: bool },
51
}
52
53
fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
54
(i.saturating_sub(window_size - 1), i + 1)
55
}
56
fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) {
57
let right_window = window_size.div_ceil(2);
58
(
59
i.saturating_sub(window_size - right_window),
60
std::cmp::min(len, i + right_window),
61
)
62
}
63
64
fn create_validity<Fo>(
65
min_periods: usize,
66
len: usize,
67
window_size: usize,
68
det_offsets_fn: Fo,
69
) -> Option<MutableBitmap>
70
where
71
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
72
{
73
if min_periods > 1 {
74
let mut validity = MutableBitmap::with_capacity(len);
75
validity.extend_constant(len, true);
76
77
// Set the null values at the boundaries
78
79
// Head.
80
for i in 0..len {
81
let (start, end) = det_offsets_fn(i, window_size, len);
82
if (end - start) < min_periods {
83
validity.set(i, false)
84
} else {
85
break;
86
}
87
}
88
// Tail.
89
for i in (0..len).rev() {
90
let (start, end) = det_offsets_fn(i, window_size, len);
91
if (end - start) < min_periods {
92
validity.set(i, false)
93
} else {
94
break;
95
}
96
}
97
98
Some(validity)
99
} else {
100
None
101
}
102
}
103
104
// Parameters for a rolling variance (carried by `RollingFnParams::Var`).
// NOTE(review): `ddof` presumably is the "delta degrees of freedom" subtracted
// from the window length in the divisor (`n - ddof`); confirm against the
// variance kernels.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingVarParams {
    pub ddof: u8,
}
111
112
#[derive(Clone, Copy, Debug, PartialEq)]
113
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
114
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
115
pub struct RollingQuantileParams {
116
pub prob: f64,
117
pub method: QuantileMethod,
118
}
119
120
impl Hash for RollingQuantileParams {
121
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
122
// Will not be NaN, so hash + eq symmetry will hold.
123
self.prob.to_bits().hash(state);
124
self.method.hash(state);
125
}
126
}
127
128