Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/rolling/mod.rs
8421 views
1
mod mean;
2
mod min_max;
3
mod moment;
4
pub mod no_nulls;
5
pub mod nulls;
6
pub mod quantile_filter;
7
mod rank;
8
mod sum;
9
10
mod arg_min_max;
11
pub(super) mod window;
12
use std::hash::Hash;
13
use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
14
15
pub use arg_min_max::{ArgMaxWindow, ArgMinMaxWindow, ArgMinWindow};
16
use arrow::array::{ArrayRef, PrimitiveArray};
17
use arrow::bitmap::{Bitmap, MutableBitmap};
18
use arrow::types::NativeType;
19
pub use mean::MeanWindow;
20
use num_traits::{Bounded, Float, NumCast, One, Zero};
21
use polars_utils::float::IsFloat;
22
#[cfg(feature = "serde")]
23
use serde::{Deserialize, Serialize};
24
use strum_macros::IntoStaticStr;
25
pub use sum::SumWindow;
26
use window::*;
27
28
/// First index of a window, as returned by the offset functions in this module.
type Start = usize;
/// End index of a window, as returned by the offset functions in this module.
type End = usize;
/// Position of the element whose window is being determined.
type Idx = usize;
/// Number of elements requested per window.
type WindowSize = usize;
/// Length of the array being rolled over.
type Len = usize;
33
34
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
35
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
36
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
37
#[strum(serialize_all = "snake_case")]
38
pub enum QuantileMethod {
39
#[default]
40
Nearest,
41
Lower,
42
Higher,
43
Midpoint,
44
Linear,
45
Equiprobable,
46
}
47
48
#[deprecated(note = "use QuantileMethod instead")]
49
pub type QuantileInterpolOptions = QuantileMethod;
50
51
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
52
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
53
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
54
pub enum RollingFnParams {
55
Quantile(RollingQuantileParams),
56
Var(RollingVarParams),
57
Rank {
58
method: RollingRankMethod,
59
seed: Option<u64>,
60
},
61
Skew {
62
bias: bool,
63
},
64
Kurtosis {
65
fisher: bool,
66
bias: bool,
67
},
68
}
69
70
fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
71
(i.saturating_sub(window_size - 1), i + 1)
72
}
73
fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) {
74
let right_window = window_size.div_ceil(2);
75
(
76
i.saturating_sub(window_size - right_window),
77
std::cmp::min(len, i + right_window),
78
)
79
}
80
81
fn create_validity<Fo>(
82
min_periods: usize,
83
len: usize,
84
window_size: usize,
85
det_offsets_fn: Fo,
86
) -> Option<MutableBitmap>
87
where
88
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
89
{
90
if min_periods > 1 {
91
let mut validity = MutableBitmap::with_capacity(len);
92
validity.extend_constant(len, true);
93
94
// Set the null values at the boundaries
95
96
// Head.
97
for i in 0..len {
98
let (start, end) = det_offsets_fn(i, window_size, len);
99
if (end - start) < min_periods {
100
validity.set(i, false)
101
} else {
102
break;
103
}
104
}
105
// Tail.
106
for i in (0..len).rev() {
107
let (start, end) = det_offsets_fn(i, window_size, len);
108
if (end - start) < min_periods {
109
validity.set(i, false)
110
} else {
111
break;
112
}
113
}
114
115
Some(validity)
116
} else {
117
None
118
}
119
}
120
121
/// Parameters allowed for rolling variance operations.
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingVarParams {
    /// Presumably the "delta degrees of freedom" correction — TODO confirm against the variance kernel.
    pub ddof: u8,
}
128
129
#[derive(Clone, Copy, Debug, PartialEq)]
130
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
131
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
132
pub struct RollingQuantileParams {
133
pub prob: f64,
134
pub method: QuantileMethod,
135
}
136
137
impl Hash for RollingQuantileParams {
138
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
139
// Will not be NaN, so hash + eq symmetry will hold.
140
self.prob.to_bits().hash(state);
141
self.method.hash(state);
142
}
143
}
144
145
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
146
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
147
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
148
#[strum(serialize_all = "snake_case")]
149
pub enum RollingRankMethod {
150
#[default]
151
Average,
152
Min,
153
Max,
154
Dense,
155
Random,
156
}
157
158