Path: blob/main/crates/polars-expr/src/dispatch/rolling.rs
7884 views
use std::ops::BitAnd;12use arrow::temporal_conversions::MICROSECONDS_IN_DAY as US_IN_DAY;3use polars_core::error::PolarsResult;4use polars_core::prelude::{5AnyValue, ChunkCast, Column, DataType, IntoColumn, NamedFrom, RollingOptionsFixedWindow,6TimeUnit,7};8use polars_core::scalar::Scalar;9use polars_core::series::Series;10#[cfg(feature = "cov")]11use polars_plan::dsl::RollingCovOptions;12use polars_plan::prelude::PlanCallback;13use polars_time::prelude::SeriesOpsTime;14use polars_utils::pl_str::PlSmallStr;1516fn roll_with_temporal_conversion<F: FnOnce(&Series) -> PolarsResult<Series>>(17s: &Column,18op: F,19) -> PolarsResult<Column> {20let dt = s.dtype();21let s = if dt.is_temporal() {22&s.to_physical_repr()23} else {24s25};2627// @scalar-opt28let out = op(s.as_materialized_series())?;2930Ok(match dt {31DataType::Date => (out * US_IN_DAY as f64)32.cast(&DataType::Int64)?33.into_datetime(TimeUnit::Microseconds, None),34DataType::Datetime(tu, tz) => out.cast(&DataType::Int64)?.into_datetime(*tu, tz.clone()),35DataType::Duration(tu) => out.cast(&DataType::Int64)?.into_duration(*tu),36DataType::Time => out.cast(&DataType::Int64)?.into_time(),37_ => out,38}39.into_column())40}4142pub(super) fn rolling_min(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {43// @scalar-opt44s.as_materialized_series()45.rolling_min(options)46.map(Column::from)47}4849pub(super) fn rolling_max(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {50// @scalar-opt51s.as_materialized_series()52.rolling_max(options)53.map(Column::from)54}5556pub(super) fn rolling_mean(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {57roll_with_temporal_conversion(s, |s| s.rolling_mean(options))58}5960pub(super) fn rolling_sum(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {61// @scalar-opt62s.as_materialized_series()63.rolling_sum(options)64.map(Column::from)65}6667pub(super) fn rolling_quantile(68s: &Column,69options: RollingOptionsFixedWindow,70) -> PolarsResult<Column> {71roll_with_temporal_conversion(s, |s| s.rolling_quantile(options))72}7374pub(super) fn rolling_var(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {75// @scalar-opt76s.as_materialized_series()77.rolling_var(options)78.map(Column::from)79}8081pub(super) fn rolling_std(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {82// @scalar-opt83s.as_materialized_series()84.rolling_std(options)85.map(Column::from)86}8788pub(super) fn rolling_rank(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {89// @scalar-opt90s.as_materialized_series()91.rolling_rank(options)92.map(Column::from)93}9495#[cfg(feature = "moment")]96pub(super) fn rolling_skew(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {97// @scalar-opt98let s = s.as_materialized_series();99polars_ops::series::rolling_skew(s, options).map(Column::from)100}101102#[cfg(feature = "moment")]103pub(super) fn rolling_kurtosis(104s: &Column,105options: RollingOptionsFixedWindow,106) -> PolarsResult<Column> {107// @scalar-opt108let s = s.as_materialized_series();109polars_ops::series::rolling_kurtosis(s, options).map(Column::from)110}111112#[cfg(feature = "cov")]113fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series {114match dtype {115DataType::Float64 => {116let values = (0..len)117.map(|v| std::cmp::min(window_size, v + 1) as f64)118.collect::<Vec<_>>();119Series::new(PlSmallStr::EMPTY, values)120},121DataType::Float32 => {122let values = (0..len)123.map(|v| std::cmp::min(window_size, v + 1) as f32)124.collect::<Vec<_>>();125Series::new(PlSmallStr::EMPTY, values)126},127#[cfg(feature = "dtype-f16")]128DataType::Float16 => {129use num_traits::AsPrimitive;130use polars_utils::float16::pf16;131let values = (0..len)132.map(|v| std::cmp::min(window_size, v + 1))133.map(AsPrimitive::<pf16>::as_)134.collect::<Vec<_>>();135Series::new(PlSmallStr::EMPTY, values)136},137_ => unreachable!(),138}139}140141#[cfg(feature = "cov")]142pub(super) fn rolling_corr_cov(143s: &[Column],144rolling_options: RollingOptionsFixedWindow,145cov_options: RollingCovOptions,146is_corr: bool,147) -> PolarsResult<Column> {148let mut x = s[0].as_materialized_series().rechunk();149let mut y = s[1].as_materialized_series().rechunk();150151if !x.dtype().is_float() {152x = x.cast(&DataType::Float64)?;153}154if !y.dtype().is_float() {155y = y.cast(&DataType::Float64)?;156}157let dtype = x.dtype().clone();158159let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?;160let rolling_options_count = RollingOptionsFixedWindow {161window_size: rolling_options.window_size,162min_periods: 0,163..Default::default()164};165166let count_x_y = if (x.null_count() + y.null_count()) > 0 {167// mask out nulls on both sides before compute mean/var168169let valids = x.is_not_null().bitand(y.is_not_null());170let valids_arr = valids.downcast_as_array();171let valids_bitmap = valids_arr.values();172173unsafe {174let xarr = &mut x.chunks_mut()[0];175*xarr = xarr.with_validity(Some(valids_bitmap.clone()));176let yarr = &mut y.chunks_mut()[0];177*yarr = yarr.with_validity(Some(valids_bitmap.clone()));178x.compute_len();179y.compute_len();180}181valids182.cast(&dtype)183.unwrap()184.rolling_sum(rolling_options_count)?185} else {186det_count_x_y(rolling_options.window_size, x.len(), &dtype)187};188189let mean_x = x.rolling_mean(rolling_options.clone())?;190let mean_y = y.rolling_mean(rolling_options.clone())?;191let ddof = Series::new(192PlSmallStr::EMPTY,193&[AnyValue::from(cov_options.ddof).cast(&dtype)],194);195196let numerator = ((mean_x_y - (mean_x * mean_y).unwrap()).unwrap()197* (count_x_y.clone() / (count_x_y - ddof).unwrap()).unwrap())198.unwrap();199200if is_corr {201let var_x = x.rolling_var(rolling_options.clone())?;202let var_y = y.rolling_var(rolling_options)?;203204let base = (var_x * var_y).unwrap();205let sc = Scalar::new(206base.dtype().clone(),207AnyValue::Float64(0.5).cast(&dtype).into_static(),208);209let denominator = super::pow::pow(&mut [base.into_column(), sc.into_column("".into())])210.unwrap()211.take_materialized_series();212213Ok((numerator / denominator)?.into_column())214} else {215Ok(numerator.into_column())216}217}218219pub fn rolling_map(220c: &Column,221rolling_options: RollingOptionsFixedWindow,222f: PlanCallback<Series, Series>,223) -> PolarsResult<Column> {224c.as_materialized_series()225.rolling_map(226&(|s: &Series| f.call(s.clone())?.strict_cast(s.dtype())) as &_,227rolling_options,228)229.map(Column::from)230}231232233