Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/rolling.rs
8424 views
1
use arrow::array::PrimitiveArray;
2
use polars_time::prelude::RollingWindower;
3
use polars_time::{ClosedWindow, Duration, PolarsTemporalGroupby, RollingGroupOptions};
4
use polars_utils::UnitVec;
5
6
use super::*;
7
8
/// Physical expression that evaluates an aggregation over rolling (index-driven) windows.
pub(crate) struct RollingExpr {
    /// the root column that the Function will be applied on.
    /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index
    /// TODO! support keys?
    /// The challenge is that the group_by will reorder the results and the
    /// keys, and time index would need to be updated, or the result should be joined back
    /// For now, don't support it.
    ///
    /// A function Expr. i.e. Mean, Median, Max, etc.
    pub(crate) phys_function: Arc<dyn PhysicalExpr>,
    /// Expression producing the index (time) column the rolling windows are computed over.
    pub(crate) index_column: Arc<dyn PhysicalExpr>,
    /// Window period, forwarded into `RollingGroupOptions` / `RollingWindower`.
    pub(crate) period: Duration,
    /// Window offset, forwarded into `RollingGroupOptions` / `RollingWindower`.
    pub(crate) offset: Duration,
    /// Which side(s) of each window are inclusive.
    pub(crate) closed_window: ClosedWindow,
    /// The logical expression this physical node was built from; returned by `as_expression`.
    pub(crate) expr: Expr,
    /// Pre-computed output field (name + dtype); returned verbatim by `to_field`.
    pub(crate) output_field: Field,
}
25
26
impl PhysicalExpr for RollingExpr {
    /// Evaluate the rolling aggregation over the whole frame: build (or fetch cached)
    /// rolling groups from the index column, then run the aggregation function per group.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        // Fast path: the index expression is a plain column, so the rolling groups can be
        // computed directly on `df` and cached/shared via the window cache.
        let groups = if let Some(index_column_name) = self.index_column.as_column() {
            let options = RollingGroupOptions {
                index_column: index_column_name.clone(),
                period: self.period,
                offset: self.offset,
                closed_window: self.closed_window,
            };
            // The Debug representation of the options serves as the cache key.
            let groups_key = format!("{options:?}");
            let groups = {
                // Groups must be set by expression runner.
                state.window_cache.get_groups(&groups_key)
            };

            // There can be multiple rolling expressions in a single expr.
            // E.g. `min().rolling() + max().rolling()`
            // So if we hit that we will compute them here.
            match groups {
                Some(groups) => groups,
                None => {
                    let (_time_key, groups) = df.rolling(None, &options)?;
                    state.window_cache.insert_groups(groups_key, groups.clone());
                    groups
                },
            }
        } else {
            // General path: the index is an arbitrary expression. Materialize it under a
            // temporary column name on a cloned frame and compute the groups from that.
            // Note: this path does not use the window cache.
            let index_column_name = PlSmallStr::from_static("__PL_INDEX_COL");
            let options = RollingGroupOptions {
                index_column: index_column_name.clone(),
                period: self.period,
                offset: self.offset,
                closed_window: self.closed_window,
            };

            let index_column = self.index_column.evaluate(df, state)?;

            let mut df = df.clone();
            df.with_column(index_column.with_name(index_column_name))?;
            let (_time_key, groups) = df.rolling(None, &options)?;
            groups
        };

        // Run the aggregation function once per rolling window.
        let out = self
            .phys_function
            .evaluate_on_groups(df, &groups, state)?
            .finalize();
        // One output element per group is required for the result to be well-formed.
        polars_ensure!(out.len() == groups.len(), agg_len = out.len(), groups.len());
        Ok(out.into_column())
    }

    /// Evaluate the rolling aggregation when this expression itself runs inside a grouping
    /// context: rolling windows are computed *within* each outer group, producing nested
    /// subgroups that the aggregation function is then evaluated on.
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut index_column = self.index_column.evaluate_on_groups(df, groups, state)?;

        // Called for its side effect only (result discarded); presumably this forces the
        // context's lazily-updated groups, which are read below via `index_column.groups`.
        // TODO confirm.
        index_column.groups();

        let mut index_column_data = index_column.flat_naive();
        use DataType as DT;
        // Normalize the index dtype to something castable to Datetime(i64). Integer index
        // columns are treated as nanosecond timestamps; smaller/unsigned ints are first
        // widened to Int64.
        let (time_unit, time_zone): (TimeUnit, Option<TimeZone>) = match index_column_data.dtype() {
            DT::Datetime(tu, tz) => (*tu, tz.clone()),
            DT::Date => (TimeUnit::Microseconds, None),
            DT::UInt32 | DT::UInt64 | DT::Int32 => {
                index_column_data = Cow::Owned(index_column_data.cast(&DT::Int64)?);
                (TimeUnit::Nanoseconds, None)
            },
            DT::Int64 => (TimeUnit::Nanoseconds, None),
            dt => polars_bail!(
                ComputeError:
                "expected any of the following dtypes: {{ Date, Datetime, Int32, Int64, UInt32, UInt64 }}, got {}",
                dt
            ),
        };
        let index_column_data =
            index_column_data.cast(&DataType::Datetime(time_unit, time_zone.clone()))?;

        // @NOTE: This is a bit strange since it ignores errors, but it mirrors the in-memory
        // engine.
        let tz = time_zone.and_then(|tz| tz.parse::<chrono_tz::Tz>().ok());

        polars_ensure!(
            index_column_data.null_count() == 0,
            ComputeError: "null values in `rolling` not supported, fill nulls."
        );
        // Null-free Datetime data is physically i64, so this downcast is expected to hold.
        let index_column_data = index_column_data.rechunk_to_arrow(CompatLevel::newest());
        let index_column_data = index_column_data
            .as_any()
            .downcast_ref::<PrimitiveArray<i64>>()
            .unwrap();
        let mut index_column_data = Cow::Borrowed(index_column_data.values().as_slice());
        let mut rolling =
            RollingWindower::new(self.period, self.offset, self.closed_window, time_unit, tz);

        let num_elements = groups.num_elements();

        // Convert the index groups to slices.
        //
        // This is not strictly necessary but allows us to reuse the existing `RollingWindower`
        // struct.
        let (slice_groups, overlapping, monotonic) = match &**index_column.groups {
            GroupsType::Idx(idx) => {
                // Gather the index values group-by-group into a contiguous buffer so each
                // group becomes a `[start, len]` slice over `data`.
                let mut data = Vec::with_capacity(num_elements);
                let mut slices = Vec::with_capacity(groups.len());
                for i in idx.all() {
                    slices.push([data.len() as IdxSize, i.len() as IdxSize]);
                    data.extend(i.iter().map(|i| index_column_data[*i as usize]));
                }
                index_column_data = Cow::Owned(data);
                (Cow::Owned(slices), false, true)
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => (Cow::Borrowed(groups), *overlapping, *monotonic),
        };

        // We need to make sure there are no length mismatches, otherwise we will have problems
        // down the line.
        assert_eq!(slice_groups.len(), groups.len());
        // Compare each outer group's length against the corresponding index-group slice.
        let length_mismatch = match &**groups {
            GroupsType::Idx(idx) => idx
                .all()
                .iter()
                .zip(slice_groups.iter())
                .map(|(i, [_, s])| (i.len(), *s as usize))
                .find(|(l, r)| *l != *r),
            GroupsType::Slice {
                groups,
                overlapping: _,
                monotonic: _,
            } => groups
                .iter()
                .zip(slice_groups.iter())
                .map(|([_, s1], [_, s2])| (*s1 as usize, *s2 as usize))
                .find(|(l, r)| *l != *r),
        };
        if let Some((l, r)) = length_mismatch {
            polars_bail!(length_mismatch = "rolling", l, r);
        }

        // Get the subslices within each group.
        // `windows` receives one `[start, len]` window (relative to its group) per element.
        let mut windows = Vec::with_capacity(num_elements);
        for [start, length] in slice_groups.as_ref() {
            // The windower is stateful; reset it between groups.
            rolling.reset();
            let time = &index_column_data[*start as usize..][..*length as usize];
            let offset = rolling.insert(&[time], &mut windows)?;
            let time = &time[offset as usize..];
            rolling.finalize(&[time], &mut windows);
        }

        // Create new groups as subgroups of the existing groups.
        let nested_groups = match &**groups {
            GroupsType::Idx(idx) => {
                // Translate each group-relative window into absolute row indices by looking
                // the window positions back up in the outer group's index list.
                let mut nested_groups = Vec::with_capacity(num_elements);
                let mut i = 0;
                for idx in idx.all() {
                    nested_groups.extend(windows[i..][..idx.len()].iter().map(|[s, l]| {
                        (
                            idx[*s as usize],
                            UnitVec::from_iter(idx[*s as usize..][..*l as usize].iter().copied()),
                        )
                    }));
                    i += idx.len();
                }
                GroupsType::Idx(nested_groups.into())
            },
            GroupsType::Slice {
                groups,
                overlapping: _,
                monotonic,
            } => {
                // Slice groups are contiguous, so the window just shifts by the group start.
                let mut nested_groups = Vec::with_capacity(num_elements);
                let mut i = 0;
                for [start, length] in groups {
                    nested_groups.extend(
                        windows[i..][..*length as usize]
                            .iter()
                            .map(|[s, l]| [*start + *s, *l]),
                    );
                    i += *length as usize;
                }
                // Nested rolling windows can overlap each other, hence `overlapping = true`.
                GroupsType::new_slice(nested_groups, true, *monotonic)
            },
        };

        // Aggregate the function over the per-element rolling subgroups.
        let nested_groups = nested_groups.into_sliceable();
        let out = self
            .phys_function
            .evaluate_on_groups(df, &nested_groups, state)?
            .finalize();
        polars_ensure!(
            out.len() == nested_groups.len(),
            agg_len = out.len(),
            nested_groups.len()
        );

        // The result has one value per input element, grouped by the original outer groups
        // (as slices); groups are already correct so no update is requested.
        let out = AggregationContext {
            state: AggState::NotAggregated(out.into_column()),
            groups: Cow::Owned(
                GroupsType::new_slice(slice_groups.into_owned(), overlapping, monotonic)
                    .into_sliceable(),
            ),
            update_groups: UpdateGroups::No,
            original_len: false,
        };
        Ok(out)
    }

    /// Schema lookup is pre-computed at construction time; just return the stored field.
    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
        Ok(self.output_field.clone())
    }

    /// Expose the logical expression this node was lowered from.
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    /// Rolling aggregations produce one value per input row, never a scalar.
    fn is_scalar(&self) -> bool {
        false
    }
}
251
252