// Path: blob/main/crates/polars-plan/src/dsl/functions/correlation.rs
// 6940 views
use super::*;12/// Compute the covariance between two columns.3pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {4let function = FunctionExpr::Correlation {5method: CorrelationMethod::Covariance(ddof),6};7a.map_binary(function, b)8}910/// Compute the pearson correlation between two columns.11pub fn pearson_corr(a: Expr, b: Expr) -> Expr {12let function = FunctionExpr::Correlation {13method: CorrelationMethod::Pearson,14};15a.map_binary(function, b)16}1718/// Compute the spearman rank correlation between two columns.19/// Missing data will be excluded from the computation.20/// # Arguments21/// * propagate_nans22/// If `true` any `NaN` encountered will lead to `NaN` in the output.23/// If to `false` then `NaN` are regarded as larger than any finite number24/// and thus lead to the highest rank.25#[cfg(all(feature = "rank", feature = "propagate_nans"))]26pub fn spearman_rank_corr(a: Expr, b: Expr, propagate_nans: bool) -> Expr {27let function = FunctionExpr::Correlation {28method: CorrelationMethod::SpearmanRank(propagate_nans),29};30a.map_binary(function, b)31}3233#[cfg(all(feature = "rolling_window", feature = "cov"))]34fn dispatch_corr_cov(x: Expr, y: Expr, options: RollingCovOptions, is_corr: bool) -> Expr {35// see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L180436let rolling_options = RollingOptionsFixedWindow {37window_size: options.window_size as usize,38min_periods: options.min_periods as usize,39..Default::default()40};4142Expr::Function {43input: vec![x, y],44function: FunctionExpr::RollingExpr {45function: RollingFunction::CorrCov {46corr_cov_options: options,47is_corr,48},49options: rolling_options,50},51}52}5354#[cfg(all(feature = "rolling_window", feature = "cov"))]55pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {56dispatch_corr_cov(x, y, options, true)57}5859#[cfg(all(feature = "rolling_window", feature = "cov"))]60pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> 
Expr {61dispatch_corr_cov(x, y, options, false)62}636465