Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/ewm_by.rs
8500 views
1
use bytemuck::allocation::zeroed_vec;
2
use num_traits::{Float, FromPrimitive, One, Zero};
3
use polars_core::prelude::*;
4
use polars_core::utils::binary_concatenate_validities;
5
6
pub fn ewm_mean_by(
7
s: &Series,
8
times: &Series,
9
half_life: i64,
10
times_is_sorted: bool,
11
) -> PolarsResult<Series> {
12
fn func<T>(
13
values: &ChunkedArray<T>,
14
times: &Int64Chunked,
15
half_life: i64,
16
times_is_sorted: bool,
17
) -> PolarsResult<Series>
18
where
19
T: PolarsFloatType,
20
T::Native: Float + Zero + One,
21
{
22
if times_is_sorted {
23
Ok(ewm_mean_by_impl_sorted(values, times, half_life).into_series())
24
} else {
25
Ok(ewm_mean_by_impl(values, times, half_life).into_series())
26
}
27
}
28
29
polars_ensure!(
30
s.len() == times.len(),
31
length_mismatch = "ewm_mean_by",
32
s.len(),
33
times.len()
34
);
35
36
match (s.dtype(), times.dtype()) {
37
(DataType::Float64, DataType::Int64) => func(
38
s.f64().unwrap(),
39
times.i64().unwrap(),
40
half_life,
41
times_is_sorted,
42
),
43
(DataType::Float32, DataType::Int64) => func(
44
s.f32().unwrap(),
45
times.i64().unwrap(),
46
half_life,
47
times_is_sorted,
48
),
49
#[cfg(feature = "dtype-f16")]
50
(DataType::Float16, DataType::Int64) => func(
51
s.f16().unwrap(),
52
times.i64().unwrap(),
53
half_life,
54
times_is_sorted,
55
),
56
#[cfg(feature = "dtype-datetime")]
57
(_, DataType::Datetime(time_unit, _)) => {
58
let half_life = adjust_half_life_to_time_unit(half_life, time_unit);
59
ewm_mean_by(
60
s,
61
&times.cast(&DataType::Int64)?,
62
half_life,
63
times_is_sorted,
64
)
65
},
66
#[cfg(feature = "dtype-date")]
67
(_, DataType::Date) => ewm_mean_by(
68
s,
69
&times.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,
70
half_life,
71
times_is_sorted,
72
),
73
(_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => ewm_mean_by(
74
s,
75
&times.cast(&DataType::Int64)?,
76
half_life,
77
times_is_sorted,
78
),
79
(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {
80
ewm_mean_by(
81
&s.cast(&DataType::Float64)?,
82
times,
83
half_life,
84
times_is_sorted,
85
)
86
},
87
_ => {
88
polars_bail!(InvalidOperation: "expected series to be Float64, Float32, Float16, \
89
Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \
90
UInt64, or UInt32")
91
},
92
}
93
}
94
95
/// Sort on behalf of user
96
fn ewm_mean_by_impl<T>(
97
values: &ChunkedArray<T>,
98
times: &Int64Chunked,
99
half_life: i64,
100
) -> ChunkedArray<T>
101
where
102
T: PolarsFloatType,
103
T::Native: Float + Zero + One,
104
ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
105
{
106
let sorting_indices = times.arg_sort(Default::default());
107
let sorted_values = unsafe { values.take_unchecked(&sorting_indices) };
108
let sorted_times = unsafe { times.take_unchecked(&sorting_indices) };
109
let sorting_indices = sorting_indices
110
.cont_slice()
111
.expect("`arg_sort` should have returned a single chunk");
112
113
let mut out: Vec<_> = zeroed_vec(sorted_times.len());
114
115
let mut skip_rows: usize = 0;
116
let mut prev_time: i64 = 0;
117
let mut prev_result = T::Native::zero();
118
for (idx, (value, time)) in sorted_values.iter().zip(sorted_times.iter()).enumerate() {
119
if let (Some(time), Some(value)) = (time, value) {
120
prev_time = time;
121
prev_result = value;
122
unsafe {
123
let out_idx = sorting_indices.get_unchecked(idx);
124
*out.get_unchecked_mut(*out_idx as usize) = prev_result;
125
}
126
skip_rows = idx + 1;
127
break;
128
};
129
}
130
sorted_values
131
.iter()
132
.zip(sorted_times.iter())
133
.enumerate()
134
.skip(skip_rows)
135
.for_each(|(idx, (value, time))| {
136
if let (Some(time), Some(value)) = (time, value) {
137
let result = update(value, prev_result, time, prev_time, half_life);
138
prev_time = time;
139
prev_result = result;
140
unsafe {
141
let out_idx = sorting_indices.get_unchecked(idx);
142
*out.get_unchecked_mut(*out_idx as usize) = result;
143
}
144
};
145
});
146
let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest()));
147
if (times.null_count() > 0) || (values.null_count() > 0) {
148
let validity = binary_concatenate_validities(times, values);
149
arr = arr.with_validity_typed(validity);
150
}
151
ChunkedArray::with_chunk(values.name().clone(), arr)
152
}
153
154
/// Fastpath if `times` is known to already be sorted.
155
fn ewm_mean_by_impl_sorted<T>(
156
values: &ChunkedArray<T>,
157
times: &Int64Chunked,
158
half_life: i64,
159
) -> ChunkedArray<T>
160
where
161
T: PolarsFloatType,
162
T::Native: Float + Zero + One,
163
{
164
let mut out: Vec<_> = zeroed_vec(times.len());
165
166
let mut skip_rows: usize = 0;
167
let mut prev_time: i64 = 0;
168
let mut prev_result = T::Native::zero();
169
for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() {
170
if let (Some(time), Some(value)) = (time, value) {
171
prev_time = time;
172
prev_result = value;
173
unsafe {
174
*out.get_unchecked_mut(idx) = prev_result;
175
}
176
skip_rows = idx + 1;
177
break;
178
}
179
}
180
values
181
.iter()
182
.zip(times.iter())
183
.enumerate()
184
.skip(skip_rows)
185
.for_each(|(idx, (value, time))| {
186
if let (Some(time), Some(value)) = (time, value) {
187
let result = update(value, prev_result, time, prev_time, half_life);
188
prev_time = time;
189
prev_result = result;
190
unsafe {
191
*out.get_unchecked_mut(idx) = result;
192
}
193
};
194
});
195
let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest()));
196
if (times.null_count() > 0) || (values.null_count() > 0) {
197
let validity = binary_concatenate_validities(times, values);
198
arr = arr.with_validity_typed(validity);
199
}
200
ChunkedArray::with_chunk(values.name().clone(), arr)
201
}
202
203
fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 {
204
match time_unit {
205
TimeUnit::Milliseconds => half_life / 1_000_000,
206
TimeUnit::Microseconds => half_life / 1_000,
207
TimeUnit::Nanoseconds => half_life,
208
}
209
}
210
211
fn update<T>(value: T, prev_result: T, time: i64, prev_time: i64, half_life: i64) -> T
212
where
213
T: Float + Zero + One + FromPrimitive,
214
{
215
if value != prev_result {
216
let delta_time = time - prev_time;
217
// equivalent to: alpha = 1 - exp(-delta_time*ln(2) / half_life)
218
let one_minus_alpha = T::from_f64(0.5)
219
.unwrap()
220
.powf(T::from_i64(delta_time).unwrap() / T::from_i64(half_life).unwrap());
221
let alpha = T::one() - one_minus_alpha;
222
alpha * value + one_minus_alpha * prev_result
223
} else {
224
value
225
}
226
}
227
228