Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/dispatch/rolling.rs
7884 views
1
use std::ops::BitAnd;
2
3
use arrow::temporal_conversions::MICROSECONDS_IN_DAY as US_IN_DAY;
4
use polars_core::error::PolarsResult;
5
use polars_core::prelude::{
6
AnyValue, ChunkCast, Column, DataType, IntoColumn, NamedFrom, RollingOptionsFixedWindow,
7
TimeUnit,
8
};
9
use polars_core::scalar::Scalar;
10
use polars_core::series::Series;
11
#[cfg(feature = "cov")]
12
use polars_plan::dsl::RollingCovOptions;
13
use polars_plan::prelude::PlanCallback;
14
use polars_time::prelude::SeriesOpsTime;
15
use polars_utils::pl_str::PlSmallStr;
16
17
fn roll_with_temporal_conversion<F: FnOnce(&Series) -> PolarsResult<Series>>(
18
s: &Column,
19
op: F,
20
) -> PolarsResult<Column> {
21
let dt = s.dtype();
22
let s = if dt.is_temporal() {
23
&s.to_physical_repr()
24
} else {
25
s
26
};
27
28
// @scalar-opt
29
let out = op(s.as_materialized_series())?;
30
31
Ok(match dt {
32
DataType::Date => (out * US_IN_DAY as f64)
33
.cast(&DataType::Int64)?
34
.into_datetime(TimeUnit::Microseconds, None),
35
DataType::Datetime(tu, tz) => out.cast(&DataType::Int64)?.into_datetime(*tu, tz.clone()),
36
DataType::Duration(tu) => out.cast(&DataType::Int64)?.into_duration(*tu),
37
DataType::Time => out.cast(&DataType::Int64)?.into_time(),
38
_ => out,
39
}
40
.into_column())
41
}
42
43
pub(super) fn rolling_min(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
44
// @scalar-opt
45
s.as_materialized_series()
46
.rolling_min(options)
47
.map(Column::from)
48
}
49
50
pub(super) fn rolling_max(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
51
// @scalar-opt
52
s.as_materialized_series()
53
.rolling_max(options)
54
.map(Column::from)
55
}
56
57
pub(super) fn rolling_mean(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
58
roll_with_temporal_conversion(s, |s| s.rolling_mean(options))
59
}
60
61
pub(super) fn rolling_sum(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
62
// @scalar-opt
63
s.as_materialized_series()
64
.rolling_sum(options)
65
.map(Column::from)
66
}
67
68
pub(super) fn rolling_quantile(
69
s: &Column,
70
options: RollingOptionsFixedWindow,
71
) -> PolarsResult<Column> {
72
roll_with_temporal_conversion(s, |s| s.rolling_quantile(options))
73
}
74
75
pub(super) fn rolling_var(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
76
// @scalar-opt
77
s.as_materialized_series()
78
.rolling_var(options)
79
.map(Column::from)
80
}
81
82
pub(super) fn rolling_std(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
83
// @scalar-opt
84
s.as_materialized_series()
85
.rolling_std(options)
86
.map(Column::from)
87
}
88
89
pub(super) fn rolling_rank(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
90
// @scalar-opt
91
s.as_materialized_series()
92
.rolling_rank(options)
93
.map(Column::from)
94
}
95
96
#[cfg(feature = "moment")]
97
pub(super) fn rolling_skew(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
98
// @scalar-opt
99
let s = s.as_materialized_series();
100
polars_ops::series::rolling_skew(s, options).map(Column::from)
101
}
102
103
#[cfg(feature = "moment")]
104
pub(super) fn rolling_kurtosis(
105
s: &Column,
106
options: RollingOptionsFixedWindow,
107
) -> PolarsResult<Column> {
108
// @scalar-opt
109
let s = s.as_materialized_series();
110
polars_ops::series::rolling_kurtosis(s, options).map(Column::from)
111
}
112
113
#[cfg(feature = "cov")]
114
fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series {
115
match dtype {
116
DataType::Float64 => {
117
let values = (0..len)
118
.map(|v| std::cmp::min(window_size, v + 1) as f64)
119
.collect::<Vec<_>>();
120
Series::new(PlSmallStr::EMPTY, values)
121
},
122
DataType::Float32 => {
123
let values = (0..len)
124
.map(|v| std::cmp::min(window_size, v + 1) as f32)
125
.collect::<Vec<_>>();
126
Series::new(PlSmallStr::EMPTY, values)
127
},
128
#[cfg(feature = "dtype-f16")]
129
DataType::Float16 => {
130
use num_traits::AsPrimitive;
131
use polars_utils::float16::pf16;
132
let values = (0..len)
133
.map(|v| std::cmp::min(window_size, v + 1))
134
.map(AsPrimitive::<pf16>::as_)
135
.collect::<Vec<_>>();
136
Series::new(PlSmallStr::EMPTY, values)
137
},
138
_ => unreachable!(),
139
}
140
}
141
142
#[cfg(feature = "cov")]
143
pub(super) fn rolling_corr_cov(
144
s: &[Column],
145
rolling_options: RollingOptionsFixedWindow,
146
cov_options: RollingCovOptions,
147
is_corr: bool,
148
) -> PolarsResult<Column> {
149
let mut x = s[0].as_materialized_series().rechunk();
150
let mut y = s[1].as_materialized_series().rechunk();
151
152
if !x.dtype().is_float() {
153
x = x.cast(&DataType::Float64)?;
154
}
155
if !y.dtype().is_float() {
156
y = y.cast(&DataType::Float64)?;
157
}
158
let dtype = x.dtype().clone();
159
160
let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?;
161
let rolling_options_count = RollingOptionsFixedWindow {
162
window_size: rolling_options.window_size,
163
min_periods: 0,
164
..Default::default()
165
};
166
167
let count_x_y = if (x.null_count() + y.null_count()) > 0 {
168
// mask out nulls on both sides before compute mean/var
169
170
let valids = x.is_not_null().bitand(y.is_not_null());
171
let valids_arr = valids.downcast_as_array();
172
let valids_bitmap = valids_arr.values();
173
174
unsafe {
175
let xarr = &mut x.chunks_mut()[0];
176
*xarr = xarr.with_validity(Some(valids_bitmap.clone()));
177
let yarr = &mut y.chunks_mut()[0];
178
*yarr = yarr.with_validity(Some(valids_bitmap.clone()));
179
x.compute_len();
180
y.compute_len();
181
}
182
valids
183
.cast(&dtype)
184
.unwrap()
185
.rolling_sum(rolling_options_count)?
186
} else {
187
det_count_x_y(rolling_options.window_size, x.len(), &dtype)
188
};
189
190
let mean_x = x.rolling_mean(rolling_options.clone())?;
191
let mean_y = y.rolling_mean(rolling_options.clone())?;
192
let ddof = Series::new(
193
PlSmallStr::EMPTY,
194
&[AnyValue::from(cov_options.ddof).cast(&dtype)],
195
);
196
197
let numerator = ((mean_x_y - (mean_x * mean_y).unwrap()).unwrap()
198
* (count_x_y.clone() / (count_x_y - ddof).unwrap()).unwrap())
199
.unwrap();
200
201
if is_corr {
202
let var_x = x.rolling_var(rolling_options.clone())?;
203
let var_y = y.rolling_var(rolling_options)?;
204
205
let base = (var_x * var_y).unwrap();
206
let sc = Scalar::new(
207
base.dtype().clone(),
208
AnyValue::Float64(0.5).cast(&dtype).into_static(),
209
);
210
let denominator = super::pow::pow(&mut [base.into_column(), sc.into_column("".into())])
211
.unwrap()
212
.take_materialized_series();
213
214
Ok((numerator / denominator)?.into_column())
215
} else {
216
Ok(numerator.into_column())
217
}
218
}
219
220
pub fn rolling_map(
221
c: &Column,
222
rolling_options: RollingOptionsFixedWindow,
223
f: PlanCallback<Series, Series>,
224
) -> PolarsResult<Column> {
225
c.as_materialized_series()
226
.rolling_map(
227
&(|s: &Series| f.call(s.clone())?.strict_cast(s.dtype())) as &_,
228
rolling_options,
229
)
230
.map(Column::from)
231
}
232
233