Path: blob/main/crates/polars-ops/src/series/ops/ewm_by.rs
6939 views
use bytemuck::allocation::zeroed_vec;1use num_traits::{Float, FromPrimitive, One, Zero};2use polars_core::prelude::*;3use polars_core::utils::binary_concatenate_validities;45pub fn ewm_mean_by(6s: &Series,7times: &Series,8half_life: i64,9times_is_sorted: bool,10) -> PolarsResult<Series> {11fn func<T>(12values: &ChunkedArray<T>,13times: &Int64Chunked,14half_life: i64,15times_is_sorted: bool,16) -> PolarsResult<Series>17where18T: PolarsFloatType,19T::Native: Float + Zero + One,20{21if times_is_sorted {22Ok(ewm_mean_by_impl_sorted(values, times, half_life).into_series())23} else {24Ok(ewm_mean_by_impl(values, times, half_life).into_series())25}26}2728polars_ensure!(29s.len() == times.len(),30length_mismatch = "ewm_mean_by",31s.len(),32times.len()33);3435match (s.dtype(), times.dtype()) {36(DataType::Float64, DataType::Int64) => func(37s.f64().unwrap(),38times.i64().unwrap(),39half_life,40times_is_sorted,41),42(DataType::Float32, DataType::Int64) => func(43s.f32().unwrap(),44times.i64().unwrap(),45half_life,46times_is_sorted,47),48#[cfg(feature = "dtype-datetime")]49(_, DataType::Datetime(time_unit, _)) => {50let half_life = adjust_half_life_to_time_unit(half_life, time_unit);51ewm_mean_by(52s,53×.cast(&DataType::Int64)?,54half_life,55times_is_sorted,56)57},58#[cfg(feature = "dtype-date")]59(_, DataType::Date) => ewm_mean_by(60s,61×.cast(&DataType::Datetime(TimeUnit::Microseconds, None))?,62half_life,63times_is_sorted,64),65(_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => ewm_mean_by(66s,67×.cast(&DataType::Int64)?,68half_life,69times_is_sorted,70),71(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {72ewm_mean_by(73&s.cast(&DataType::Float64)?,74times,75half_life,76times_is_sorted,77)78},79_ => {80polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \81Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \82UInt64, or UInt32")83},84}85}8687/// Sort on behalf of user88fn ewm_mean_by_impl<T>(89values: &ChunkedArray<T>,90times: &Int64Chunked,91half_life: i64,92) -> ChunkedArray<T>93where94T: PolarsFloatType,95T::Native: Float + Zero + One,96ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,97{98let sorting_indices = times.arg_sort(Default::default());99let sorted_values = unsafe { values.take_unchecked(&sorting_indices) };100let sorted_times = unsafe { times.take_unchecked(&sorting_indices) };101let sorting_indices = sorting_indices102.cont_slice()103.expect("`arg_sort` should have returned a single chunk");104105let mut out: Vec<_> = zeroed_vec(sorted_times.len());106107let mut skip_rows: usize = 0;108let mut prev_time: i64 = 0;109let mut prev_result = T::Native::zero();110for (idx, (value, time)) in sorted_values.iter().zip(sorted_times.iter()).enumerate() {111if let (Some(time), Some(value)) = (time, value) {112prev_time = time;113prev_result = value;114unsafe {115let out_idx = sorting_indices.get_unchecked(idx);116*out.get_unchecked_mut(*out_idx as usize) = prev_result;117}118skip_rows = idx + 1;119break;120};121}122sorted_values123.iter()124.zip(sorted_times.iter())125.enumerate()126.skip(skip_rows)127.for_each(|(idx, (value, time))| {128if let (Some(time), Some(value)) = (time, value) {129let result = update(value, prev_result, time, prev_time, half_life);130prev_time = time;131prev_result = result;132unsafe {133let out_idx = sorting_indices.get_unchecked(idx);134*out.get_unchecked_mut(*out_idx as usize) = result;135}136};137});138let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest()));139if (times.null_count() > 0) || (values.null_count() > 0) {140let validity = binary_concatenate_validities(times, values);141arr = arr.with_validity_typed(validity);142}143ChunkedArray::with_chunk(values.name().clone(), arr)144}145146/// Fastpath if `times` is known to already be sorted.147fn ewm_mean_by_impl_sorted<T>(148values: &ChunkedArray<T>,149times: &Int64Chunked,150half_life: i64,151) -> ChunkedArray<T>152where153T: PolarsFloatType,154T::Native: Float + Zero + One,155{156let mut out: Vec<_> = zeroed_vec(times.len());157158let mut skip_rows: usize = 0;159let mut prev_time: i64 = 0;160let mut prev_result = T::Native::zero();161for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() {162if let (Some(time), Some(value)) = (time, value) {163prev_time = time;164prev_result = value;165unsafe {166*out.get_unchecked_mut(idx) = prev_result;167}168skip_rows = idx + 1;169break;170}171}172values173.iter()174.zip(times.iter())175.enumerate()176.skip(skip_rows)177.for_each(|(idx, (value, time))| {178if let (Some(time), Some(value)) = (time, value) {179let result = update(value, prev_result, time, prev_time, half_life);180prev_time = time;181prev_result = result;182unsafe {183*out.get_unchecked_mut(idx) = result;184}185};186});187let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(CompatLevel::newest()));188if (times.null_count() > 0) || (values.null_count() > 0) {189let validity = binary_concatenate_validities(times, values);190arr = arr.with_validity_typed(validity);191}192ChunkedArray::with_chunk(values.name().clone(), arr)193}194195fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 {196match time_unit {197TimeUnit::Milliseconds => half_life / 1_000_000,198TimeUnit::Microseconds => half_life / 1_000,199TimeUnit::Nanoseconds => half_life,200}201}202203fn update<T>(value: T, prev_result: T, time: i64, prev_time: i64, half_life: i64) -> T204where205T: Float + Zero + One + FromPrimitive,206{207if value != prev_result {208let delta_time = time - prev_time;209// equivalent to: alpha = 1 - exp(-delta_time*ln(2) / half_life)210let one_minus_alpha = T::from_f64(0.5)211.unwrap()212.powf(T::from_i64(delta_time).unwrap() / T::from_i64(half_life).unwrap());213let alpha = T::one() - one_minus_alpha;214alpha * value + one_minus_alpha * prev_result215} else {216value217}218}219220221