Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/shared.rs
8395 views
1
//! This module implements the logic shared between the `nulls` and `no_nulls` rolling kernels.
2
3
use arrow::array::{ArrayRef, PrimitiveArray};
4
use arrow::bitmap::MutableBitmap;
5
use arrow::trusted_len::TrustedLen;
6
use arrow::types::NativeType;
7
use bytemuck::allocation::zeroed_vec;
8
#[cfg(feature = "timezones")]
9
use chrono_tz::Tz;
10
use polars_compute::rolling::no_nulls::RollingAggWindowNoNulls;
11
use polars_compute::rolling::nulls::RollingAggWindowNulls;
12
use polars_core::prelude::*;
13
14
use crate::windows::duration::Duration;
15
use crate::windows::group_by::{ClosedWindow, group_by_values_iter};
16
17
/// Uniform interface over a stateful rolling aggregation window.
///
/// The kernels in this module are written against this trait so that the same
/// driver code can be used with either wrapper type defined in this file.
///
/// `T` is the input element type and `Out` is the type of the aggregate.
pub(crate) trait RollingAggWindow<T, Out>
where
    T: NativeType,
    Out: NativeType,
{
    /// Update the window state for the bounds `start` and `end`.
    ///
    /// # Safety
    /// `start` and `end` must be in bounds of `slice` and associated structures.
    unsafe fn update(&mut self, start: usize, end: usize);

    /// Get the aggregate of the current window relative to the value at `idx`.
    ///
    /// Returns `None` when the window has no aggregate to report.
    fn get_agg(&self, idx: usize) -> Option<Out>;

    /// Returns the length of the underlying input.
    fn slice_len(&self) -> usize;
}
28
29
/// Newtype around an aggregator so that it can implement [`RollingAggWindow`]
/// by delegating to its `RollingAggWindowNoNulls` implementation.
///
/// `#[repr(transparent)]` keeps the layout identical to the wrapped value.
#[repr(transparent)]
pub(crate) struct RollingAggWindowNoNullsWrapper<T>(pub T);
31
/// Newtype around an aggregator so that it can implement [`RollingAggWindow`]
/// by delegating to its `RollingAggWindowNulls` implementation.
///
/// `#[repr(transparent)]` keeps the layout identical to the wrapped value.
#[repr(transparent)]
pub(crate) struct RollingAggWindowNullsWrapper<T>(pub T);
33
34
impl<T: NativeType, Out: NativeType, Agg: RollingAggWindowNoNulls<T, Out>> RollingAggWindow<T, Out>
35
for RollingAggWindowNoNullsWrapper<Agg>
36
{
37
unsafe fn update(&mut self, start: usize, end: usize) {
38
// SAFETY: Caller MUST uphold function safety contract.
39
unsafe { self.0.update(start, end) }
40
}
41
42
fn get_agg(&self, idx: usize) -> Option<Out> {
43
self.0.get_agg(idx)
44
}
45
46
fn slice_len(&self) -> usize {
47
self.0.slice_len()
48
}
49
}
50
51
impl<T: NativeType, Out: NativeType, Agg: RollingAggWindowNulls<T, Out>> RollingAggWindow<T, Out>
52
for RollingAggWindowNullsWrapper<Agg>
53
{
54
unsafe fn update(&mut self, start: usize, end: usize) {
55
// SAFETY: Caller MUST uphold function safety contract.
56
unsafe { self.0.update(start, end) }
57
}
58
59
fn get_agg(&self, idx: usize) -> Option<Out> {
60
self.0.get_agg(idx)
61
}
62
63
fn slice_len(&self) -> usize {
64
self.0.slice_len()
65
}
66
}
67
68
#[expect(clippy::too_many_arguments)]
69
pub(crate) fn rolling_apply_agg<T, Out, Agg>(
70
agg_window: &mut Agg,
71
period: Duration,
72
time: &[i64],
73
closed_window: ClosedWindow,
74
min_periods: usize,
75
tu: TimeUnit,
76
tz: Option<&TimeZone>,
77
sorting_indices: Option<&[IdxSize]>,
78
) -> PolarsResult<ArrayRef>
79
where
80
T: NativeType,
81
Out: NativeType,
82
Agg: RollingAggWindow<T, Out>,
83
{
84
let offset_iter = match tz {
85
#[cfg(feature = "timezones")]
86
Some(tz) => group_by_values_iter(period, time, closed_window, tu, tz.parse::<Tz>().ok()),
87
_ => group_by_values_iter(period, time, closed_window, tu, None),
88
}?;
89
90
if let Some(indices) = sorting_indices {
91
rolling_apply_agg_window(agg_window, offset_iter, min_periods, indices)
92
} else {
93
rolling_apply_agg_window_sorted(agg_window, offset_iter, min_periods)
94
}
95
}
96
97
// Use an aggregation window that maintains the state.
98
// Fastpath if values were known to already be sorted by time.
99
fn rolling_apply_agg_window_sorted<Agg, O, T, Out>(
100
agg_window: &mut Agg,
101
offsets: O,
102
min_periods: usize,
103
) -> PolarsResult<ArrayRef>
104
where
105
Agg: RollingAggWindow<T, Out>,
106
O: Iterator<Item = PolarsResult<(IdxSize, IdxSize)>> + TrustedLen,
107
T: NativeType,
108
Out: NativeType,
109
{
110
let out = offsets
111
.enumerate()
112
.map(|(idx, result)| {
113
result.map(|(start, len)| {
114
let end = start + len;
115
116
// On the Python side, if `min_periods` wasn't specified, it is set to
117
// `1`. In that case, this condition is the same as checking
118
// `if start == end`.
119
if len < (min_periods as IdxSize) {
120
None
121
} else {
122
// SAFETY: we are in bounds
123
unsafe { agg_window.update(start as usize, end as usize) }
124
agg_window.get_agg(idx)
125
}
126
})
127
})
128
.collect::<PolarsResult<PrimitiveArray<Out>>>()?;
129
130
Ok(Box::new(out))
131
}
132
133
// Use an aggregation window that maintains the state
134
fn rolling_apply_agg_window<Agg, O, T, Out>(
135
agg_window: &mut Agg,
136
offsets: O,
137
min_periods: usize,
138
sorting_indices: &[IdxSize],
139
) -> PolarsResult<ArrayRef>
140
where
141
Agg: RollingAggWindow<T, Out>,
142
O: Iterator<Item = PolarsResult<(IdxSize, IdxSize)>> + TrustedLen,
143
T: NativeType,
144
Out: NativeType,
145
{
146
let mut out = zeroed_vec(agg_window.slice_len());
147
let mut validity: Option<MutableBitmap> = None;
148
offsets.enumerate().try_for_each(|(idx, result)| {
149
let (start, len) = result?;
150
let end = start + len;
151
let out_idx = unsafe { sorting_indices.get_unchecked(idx) };
152
153
// On the Python side, if `min_periods` wasn't specified, it is set to
154
// `1`. In that case, this condition is the same as checking
155
// `if start == end`.
156
if len >= (min_periods as IdxSize) {
157
// SAFETY:
158
// we are in bound
159
unsafe { agg_window.update(start as usize, end as usize) };
160
let res = agg_window.get_agg(*out_idx as usize);
161
162
if let Some(res) = res {
163
// SAFETY: `idx` is in bounds because `sorting_indices` was just taken from
164
// `by`, which has already been checked to be the same length as the values.
165
unsafe { *out.get_unchecked_mut(*out_idx as usize) = res };
166
} else {
167
instantiate_bitmap_if_null_and_set_false_at_idx(
168
&mut validity,
169
agg_window.slice_len(),
170
*out_idx as usize,
171
)
172
}
173
} else {
174
instantiate_bitmap_if_null_and_set_false_at_idx(
175
&mut validity,
176
agg_window.slice_len(),
177
*out_idx as usize,
178
)
179
}
180
Ok::<(), PolarsError>(())
181
})?;
182
183
let out = PrimitiveArray::<Out>::from_vec(out).with_validity(validity.map(|x| x.into()));
184
185
Ok(Box::new(out))
186
}
187
188
// Instantiate a bitmap when the first null value is encountered.
189
// Set the validity at index `idx` to `false`.
190
fn instantiate_bitmap_if_null_and_set_false_at_idx(
191
validity: &mut Option<MutableBitmap>,
192
len: usize,
193
idx: usize,
194
) {
195
let bitmap = validity.get_or_insert_with(|| {
196
let mut bitmap = MutableBitmap::with_capacity(len);
197
bitmap.extend_constant(len, true);
198
bitmap
199
});
200
bitmap.set(idx, false);
201
}
202
203