Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
// GitHub repository: pola-rs/polars
// Path: blob/main/crates/polars-ops/src/series/ops/ewm_by.rs
6939 views
1
use bytemuck::allocation::zeroed_vec;
2
use num_traits::{Float, FromPrimitive, One, Zero};
3
use polars_core::prelude::*;
4
use polars_core::utils::binary_concatenate_validities;
5
6
pub fn ewm_mean_by(
7
s: &Series,
8
times: &Series,
9
half_life: i64,
10
times_is_sorted: bool,
11
) -> PolarsResult<Series> {
12
fn func<T>(
13
values: &ChunkedArray<T>,
14
times: &Int64Chunked,
15
half_life: i64,
16
times_is_sorted: bool,
17
) -> PolarsResult<Series>
18
where
19
T: PolarsFloatType,
20
T::Native: Float + Zero + One,
21
{
22
if times_is_sorted {
23
Ok(ewm_mean_by_impl_sorted(values, times, half_life).into_series())
24
} else {
25
Ok(ewm_mean_by_impl(values, times, half_life).into_series())
26
}
27
}
28
29
polars_ensure!(
30
s.len() == times.len(),
31
length_mismatch = "ewm_mean_by",
32
s.len(),
33
times.len()
34
);
35
36
match (s.dtype(), times.dtype()) {
37
(DataType::Float64, DataType::Int64) => func(
38
s.f64().unwrap(),
39
times.i64().unwrap(),
40
half_life,
41
times_is_sorted,
42
),
43
(DataType::Float32, DataType::Int64) => func(
44
s.f32().unwrap(),
45
times.i64().unwrap(),
46
half_life,
47
times_is_sorted,
48
),
49
#[cfg(feature = "dtype-datetime")]
50
(_, DataType::Datetime(time_unit, _)) => {
51
let half_life = adjust_half_life_to_time_unit(half_life, time_unit);
52
ewm_mean_by(
53
s,
54
&times.cast(&DataType::Int64)?,
55
half_life,
56
times_is_sorted,
57
)
58
},
59
#[cfg(feature = "dtype-date")]
60
(_, DataType::Date) => ewm_mean_by(
61
s,
62
&times.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,
63
half_life,
64
times_is_sorted,
65
),
66
(_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => ewm_mean_by(
67
s,
68
&times.cast(&DataType::Int64)?,
69
half_life,
70
times_is_sorted,
71
),
72
(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {
73
ewm_mean_by(
74
&s.cast(&DataType::Float64)?,
75
times,
76
half_life,
77
times_is_sorted,
78
)
79
},
80
_ => {
81
polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \
82
Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \
83
UInt64, or UInt32")
84
},
85
}
86
}
87
88
/// Sort on behalf of user
89
fn ewm_mean_by_impl<T>(
90
values: &ChunkedArray<T>,
91
times: &Int64Chunked,
92
half_life: i64,
93
) -> ChunkedArray<T>
94
where
95
T: PolarsFloatType,
96
T::Native: Float + Zero + One,
97
ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
98
{
99
let sorting_indices = times.arg_sort(Default::default());
100
let sorted_values = unsafe { values.take_unchecked(&sorting_indices) };
101
let sorted_times = unsafe { times.take_unchecked(&sorting_indices) };
102
let sorting_indices = sorting_indices
103
.cont_slice()
104
.expect("`arg_sort` should have returned a single chunk");
105
106
let mut out: Vec<_> = zeroed_vec(sorted_times.len());
107
108
let mut skip_rows: usize = 0;
109
let mut prev_time: i64 = 0;
110
let mut prev_result = T::Native::zero();
111
for (idx, (value, time)) in sorted_values.iter().zip(sorted_times.iter()).enumerate() {
112
if let (Some(time), Some(value)) = (time, value) {
113
prev_time = time;
114
prev_result = value;
115
unsafe {
116
let out_idx = sorting_indices.get_unchecked(idx);
117
*out.get_unchecked_mut(*out_idx as usize) = prev_result;
118
}
119
skip_rows = idx + 1;
120
break;
121
};
122
}
123
sorted_values
124
.iter()
125
.zip(sorted_times.iter())
126
.enumerate()
127
.skip(skip_rows)
128
.for_each(|(idx, (value, time))| {
129
if let (Some(time), Some(value)) = (time, value) {
130
let result = update(value, prev_result, time, prev_time, half_life);
131
prev_time = time;
132
prev_result = result;
133
unsafe {
134
let out_idx = sorting_indices.get_unchecked(idx);
135
*out.get_unchecked_mut(*out_idx as usize) = result;
136
}
137
};
138
});
139
let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest()));
140
if (times.null_count() > 0) || (values.null_count() > 0) {
141
let validity = binary_concatenate_validities(times, values);
142
arr = arr.with_validity_typed(validity);
143
}
144
ChunkedArray::with_chunk(values.name().clone(), arr)
145
}
146
147
/// Fastpath if `times` is known to already be sorted.
148
fn ewm_mean_by_impl_sorted<T>(
149
values: &ChunkedArray<T>,
150
times: &Int64Chunked,
151
half_life: i64,
152
) -> ChunkedArray<T>
153
where
154
T: PolarsFloatType,
155
T::Native: Float + Zero + One,
156
{
157
let mut out: Vec<_> = zeroed_vec(times.len());
158
159
let mut skip_rows: usize = 0;
160
let mut prev_time: i64 = 0;
161
let mut prev_result = T::Native::zero();
162
for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() {
163
if let (Some(time), Some(value)) = (time, value) {
164
prev_time = time;
165
prev_result = value;
166
unsafe {
167
*out.get_unchecked_mut(idx) = prev_result;
168
}
169
skip_rows = idx + 1;
170
break;
171
}
172
}
173
values
174
.iter()
175
.zip(times.iter())
176
.enumerate()
177
.skip(skip_rows)
178
.for_each(|(idx, (value, time))| {
179
if let (Some(time), Some(value)) = (time, value) {
180
let result = update(value, prev_result, time, prev_time, half_life);
181
prev_time = time;
182
prev_result = result;
183
unsafe {
184
*out.get_unchecked_mut(idx) = result;
185
}
186
};
187
});
188
let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest()));
189
if (times.null_count() > 0) || (values.null_count() > 0) {
190
let validity = binary_concatenate_validities(times, values);
191
arr = arr.with_validity_typed(validity);
192
}
193
ChunkedArray::with_chunk(values.name().clone(), arr)
194
}
195
196
fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 {
197
match time_unit {
198
TimeUnit::Milliseconds => half_life / 1_000_000,
199
TimeUnit::Microseconds => half_life / 1_000,
200
TimeUnit::Nanoseconds => half_life,
201
}
202
}
203
204
fn update<T>(value: T, prev_result: T, time: i64, prev_time: i64, half_life: i64) -> T
205
where
206
T: Float + Zero + One + FromPrimitive,
207
{
208
if value != prev_result {
209
let delta_time = time - prev_time;
210
// equivalent to: alpha = 1 - exp(-delta_time*ln(2) / half_life)
211
let one_minus_alpha = T::from_f64(0.5)
212
.unwrap()
213
.powf(T::from_i64(delta_time).unwrap() / T::from_i64(half_life).unwrap());
214
let alpha = T::one() - one_minus_alpha;
215
alpha * value + one_minus_alpha * prev_result
216
} else {
217
value
218
}
219
}
220
221