// Path: blob/main/crates/polars-ops/src/series/ops/ewm_by.rs
// 8500 views
use bytemuck::allocation::zeroed_vec;1use num_traits::{Float, FromPrimitive, One, Zero};2use polars_core::prelude::*;3use polars_core::utils::binary_concatenate_validities;45pub fn ewm_mean_by(6s: &Series,7times: &Series,8half_life: i64,9times_is_sorted: bool,10) -> PolarsResult<Series> {11fn func<T>(12values: &ChunkedArray<T>,13times: &Int64Chunked,14half_life: i64,15times_is_sorted: bool,16) -> PolarsResult<Series>17where18T: PolarsFloatType,19T::Native: Float + Zero + One,20{21if times_is_sorted {22Ok(ewm_mean_by_impl_sorted(values, times, half_life).into_series())23} else {24Ok(ewm_mean_by_impl(values, times, half_life).into_series())25}26}2728polars_ensure!(29s.len() == times.len(),30length_mismatch = "ewm_mean_by",31s.len(),32times.len()33);3435match (s.dtype(), times.dtype()) {36(DataType::Float64, DataType::Int64) => func(37s.f64().unwrap(),38times.i64().unwrap(),39half_life,40times_is_sorted,41),42(DataType::Float32, DataType::Int64) => func(43s.f32().unwrap(),44times.i64().unwrap(),45half_life,46times_is_sorted,47),48#[cfg(feature = "dtype-f16")]49(DataType::Float16, DataType::Int64) => func(50s.f16().unwrap(),51times.i64().unwrap(),52half_life,53times_is_sorted,54),55#[cfg(feature = "dtype-datetime")]56(_, DataType::Datetime(time_unit, _)) => {57let half_life = adjust_half_life_to_time_unit(half_life, time_unit);58ewm_mean_by(59s,60×.cast(&DataType::Int64)?,61half_life,62times_is_sorted,63)64},65#[cfg(feature = "dtype-date")]66(_, DataType::Date) => ewm_mean_by(67s,68×.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,69half_life,70times_is_sorted,71),72(_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => ewm_mean_by(73s,74×.cast(&DataType::Int64)?,75half_life,76times_is_sorted,77),78(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {79ewm_mean_by(80&s.cast(&DataType::Float64)?,81times,82half_life,83times_is_sorted,84)85},86_ => {87polars_bail!(InvalidOperation: "expected series to be Float64, Float32, 
Float16, \88Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \89UInt64, or UInt32")90},91}92}9394/// Sort on behalf of user95fn ewm_mean_by_impl<T>(96values: &ChunkedArray<T>,97times: &Int64Chunked,98half_life: i64,99) -> ChunkedArray<T>100where101T: PolarsFloatType,102T::Native: Float + Zero + One,103ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,104{105let sorting_indices = times.arg_sort(Default::default());106let sorted_values = unsafe { values.take_unchecked(&sorting_indices) };107let sorted_times = unsafe { times.take_unchecked(&sorting_indices) };108let sorting_indices = sorting_indices109.cont_slice()110.expect("`arg_sort` should have returned a single chunk");111112let mut out: Vec<_> = zeroed_vec(sorted_times.len());113114let mut skip_rows: usize = 0;115let mut prev_time: i64 = 0;116let mut prev_result = T::Native::zero();117for (idx, (value, time)) in sorted_values.iter().zip(sorted_times.iter()).enumerate() {118if let (Some(time), Some(value)) = (time, value) {119prev_time = time;120prev_result = value;121unsafe {122let out_idx = sorting_indices.get_unchecked(idx);123*out.get_unchecked_mut(*out_idx as usize) = prev_result;124}125skip_rows = idx + 1;126break;127};128}129sorted_values130.iter()131.zip(sorted_times.iter())132.enumerate()133.skip(skip_rows)134.for_each(|(idx, (value, time))| {135if let (Some(time), Some(value)) = (time, value) {136let result = update(value, prev_result, time, prev_time, half_life);137prev_time = time;138prev_result = result;139unsafe {140let out_idx = sorting_indices.get_unchecked(idx);141*out.get_unchecked_mut(*out_idx as usize) = result;142}143};144});145let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest()));146if (times.null_count() > 0) || (values.null_count() > 0) {147let validity = binary_concatenate_validities(times, values);148arr = arr.with_validity_typed(validity);149}150ChunkedArray::with_chunk(values.name().clone(), arr)151}152153/// Fastpath if 
`times` is known to already be sorted.154fn ewm_mean_by_impl_sorted<T>(155values: &ChunkedArray<T>,156times: &Int64Chunked,157half_life: i64,158) -> ChunkedArray<T>159where160T: PolarsFloatType,161T::Native: Float + Zero + One,162{163let mut out: Vec<_> = zeroed_vec(times.len());164165let mut skip_rows: usize = 0;166let mut prev_time: i64 = 0;167let mut prev_result = T::Native::zero();168for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() {169if let (Some(time), Some(value)) = (time, value) {170prev_time = time;171prev_result = value;172unsafe {173*out.get_unchecked_mut(idx) = prev_result;174}175skip_rows = idx + 1;176break;177}178}179values180.iter()181.zip(times.iter())182.enumerate()183.skip(skip_rows)184.for_each(|(idx, (value, time))| {185if let (Some(time), Some(value)) = (time, value) {186let result = update(value, prev_result, time, prev_time, half_life);187prev_time = time;188prev_result = result;189unsafe {190*out.get_unchecked_mut(idx) = result;191}192};193});194let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest()));195if (times.null_count() > 0) || (values.null_count() > 0) {196let validity = binary_concatenate_validities(times, values);197arr = arr.with_validity_typed(validity);198}199ChunkedArray::with_chunk(values.name().clone(), arr)200}201202fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 {203match time_unit {204TimeUnit::Milliseconds => half_life / 1_000_000,205TimeUnit::Microseconds => half_life / 1_000,206TimeUnit::Nanoseconds => half_life,207}208}209210fn update<T>(value: T, prev_result: T, time: i64, prev_time: i64, half_life: i64) -> T211where212T: Float + Zero + One + FromPrimitive,213{214if value != prev_result {215let delta_time = time - prev_time;216// equivalent to: alpha = 1 - exp(-delta_time*ln(2) / half_life)217let one_minus_alpha = T::from_f64(0.5)218.unwrap()219.powf(T::from_i64(delta_time).unwrap() / T::from_i64(half_life).unwrap());220let 
alpha = T::one() - one_minus_alpha;221alpha * value + one_minus_alpha * prev_result222} else {223value224}225}226227228