// Path: crates/polars-ops/src/series/ops/horizontal.rs
use std::borrow::Cow;12use polars_core::chunked_array::cast::CastOptions;3use polars_core::prelude::*;4use polars_core::series::arithmetic::coerce_lhs_rhs;5use polars_core::utils::dtypes_to_supertype;6use polars_core::{POOL, with_match_physical_numeric_polars_type};7use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};89fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> {10let mut length = 1;11for c in cs {12let len = c.len();13if len != 1 && len != length {14if length == 1 {15length = len;16} else {17polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})");18}19}20}21Ok(())22}2324pub trait MinMaxHorizontal {25/// Aggregate the column horizontally to their min values.26fn min_horizontal(&self) -> PolarsResult<Option<Column>>;27/// Aggregate the column horizontally to their max values.28fn max_horizontal(&self) -> PolarsResult<Option<Column>>;29}3031impl MinMaxHorizontal for DataFrame {32fn min_horizontal(&self) -> PolarsResult<Option<Column>> {33min_horizontal(self.get_columns())34}35fn max_horizontal(&self) -> PolarsResult<Option<Column>> {36max_horizontal(self.get_columns())37}38}3940#[derive(Copy, Clone, Debug, PartialEq)]41pub enum NullStrategy {42Ignore,43Propagate,44}4546pub trait SumMeanHorizontal {47/// Sum all values horizontally across columns.48fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;4950/// Compute the mean of all numeric values horizontally across columns.51fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;52}5354impl SumMeanHorizontal for DataFrame {55fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {56sum_horizontal(self.get_columns(), null_strategy)57}58fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {59mean_horizontal(self.get_columns(), null_strategy)60}61}6263fn min_binary<T>(left: &ChunkedArray<T>, right: 
&ChunkedArray<T>) -> ChunkedArray<T>64where65T: PolarsNumericType,66T::Native: PartialOrd,67{68let op = |l: T::Native, r: T::Native| {69if l < r { l } else { r }70};71arity::binary_elementwise_values(left, right, op)72}7374fn max_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>75where76T: PolarsNumericType,77T::Native: PartialOrd,78{79let op = |l: T::Native, r: T::Native| {80if l > r { l } else { r }81};82arity::binary_elementwise_values(left, right, op)83}8485fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult<Column> {86if left.dtype().to_physical().is_primitive_numeric()87&& right.dtype().to_physical().is_primitive_numeric()88&& left.null_count() == 089&& right.null_count() == 090&& left.len() == right.len()91{92match (left, right) {93(Column::Series(left), Column::Series(right)) => {94let (lhs, rhs) = coerce_lhs_rhs(left, right)?;95let logical = lhs.dtype();96let lhs = lhs.to_physical_repr();97let rhs = rhs.to_physical_repr();9899with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {100let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();101let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();102103unsafe {104if min {105min_binary(a, b).into_series().from_physical_unchecked(logical)106} else {107max_binary(a, b).into_series().from_physical_unchecked(logical)108}109}110})111.map(Column::from)112},113_ => {114let mask = if min {115left.lt(right)?116} else {117left.gt(right)?118};119120left.zip_with(&mask, right)121},122}123} else {124let mask = if min {125left.lt(right)? & left.is_not_null() | right.is_null()126} else {127left.gt(right)? 
& left.is_not_null() | right.is_null()128};129left.zip_with(&mask, right)130}131}132133pub fn max_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {134validate_column_lengths(columns)?;135136let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false);137138match columns.len() {1390 => Ok(None),1401 => Ok(Some(columns[0].clone())),1412 => max_fn(&columns[0], &columns[1]).map(Some),142_ => {143// the try_reduce_with is a bit slower in parallelism,144// but I don't think it matters here as we parallelize over columns, not over elements145POOL.install(|| {146columns147.par_iter()148.map(|s| Ok(Cow::Borrowed(s)))149.try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned))150// we can unwrap the option, because we are certain there is a column151// we started this operation on 3 columns152.unwrap()153.map(|cow| Some(cow.into_owned()))154})155},156}157}158159pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {160validate_column_lengths(columns)?;161162let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true);163164match columns.len() {1650 => Ok(None),1661 => Ok(Some(columns[0].clone())),1672 => min_fn(&columns[0], &columns[1]).map(Some),168_ => {169// the try_reduce_with is a bit slower in parallelism,170// but I don't think it matters here as we parallelize over columns, not over elements171POOL.install(|| {172columns173.par_iter()174.map(|s| Ok(Cow::Borrowed(s)))175.try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned))176// we can unwrap the option, because we are certain there is a column177// we started this operation on 3 columns178.unwrap()179.map(|cow| Some(cow.into_owned()))180})181},182}183}184185pub fn sum_horizontal(186columns: &[Column],187null_strategy: NullStrategy,188) -> PolarsResult<Option<Column>> {189validate_column_lengths(columns)?;190let ignore_nulls = null_strategy == NullStrategy::Ignore;191192let apply_null_strategy = |s: Series| -> PolarsResult<Series> {193if ignore_nulls 
&& s.null_count() > 0 {194s.fill_null(FillNullStrategy::Zero)195} else {196Ok(s)197}198};199200let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {201let acc: Series = apply_null_strategy(acc)?;202let s = apply_null_strategy(s)?;203// This will do owned arithmetic and can be mutable204std::ops::Add::add(acc, s)205};206207// @scalar-opt208let non_null_cols = columns209.iter()210.filter(|x| x.dtype() != &DataType::Null)211.map(|c| c.as_materialized_series())212.collect::<Vec<_>>();213214// If we have any null columns and null strategy is not `Ignore`, we can return immediately.215if !ignore_nulls && non_null_cols.len() < columns.len() {216// We must determine the correct return dtype.217let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {218DataType::Boolean => IDX_DTYPE,219dt => dt,220};221return Ok(Some(Column::full_null(222columns[0].name().clone(),223columns[0].len(),224&return_dtype,225)));226}227228match non_null_cols.len() {2290 => {230if columns.is_empty() {231Ok(None)232} else {233// all columns are null dtype, so result is null dtype234Ok(Some(columns[0].clone()))235}236},2371 => Ok(Some(238apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {239non_null_cols[0].cast(&IDX_DTYPE)?240} else {241non_null_cols[0].clone()242})?243.into(),244)),2452 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())246.map(Column::from)247.map(Some),248_ => {249// the try_reduce_with is a bit slower in parallelism,250// but I don't think it matters here as we parallelize over columns, not over elements251let out = POOL.install(|| {252non_null_cols253.into_par_iter()254.cloned()255.map(Ok)256.try_reduce_with(sum_fn)257// We can unwrap because we started with at least 3 columns, so we always get a Some258.unwrap()259});260out.map(Column::from).map(Some)261},262}263}264265pub fn mean_horizontal(266columns: &[Column],267null_strategy: NullStrategy,268) -> PolarsResult<Option<Column>> 
{269validate_column_lengths(columns)?;270271let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {272let dtype = s.dtype();273dtype.is_primitive_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()274});275276if !non_numeric_columns.is_empty() {277let col = non_numeric_columns.first().cloned();278polars_bail!(279InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",280col.unwrap().name(),281col.unwrap().dtype(),282);283}284let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();285let num_rows = columns.len();286match num_rows {2870 => Ok(None),2881 => Ok(Some(match columns[0].dtype() {289dt if dt != &DataType::Float32 && !dt.is_decimal() => {290columns[0].cast(&DataType::Float64)?291},292_ => columns[0].clone(),293})),294_ => {295let sum = || sum_horizontal(columns.as_slice(), null_strategy);296let null_count = || {297columns298.par_iter()299.map(|c| {300c.is_null()301.into_column()302.cast_with_options(&DataType::UInt32, CastOptions::NonStrict)303})304.reduce_with(|l, r| {305let l = l?;306let r = r?;307let result = std::ops::Add::add(&l, &r)?;308PolarsResult::Ok(result)309})310// we can unwrap the option, because we are certain there is a column311// we started this operation on 2 columns312.unwrap()313};314315let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count));316let sum = sum?;317let null_count = null_count?;318319// value lengths: len - null_count320let value_length: UInt32Chunked = (Column::new_scalar(321PlSmallStr::EMPTY,322Scalar::from(num_rows as u32),323null_count.len(),324) - null_count)?325.u32()326.unwrap()327.clone();328329// make sure that we do not divide by zero330// by replacing with None331let dt = if sum332.as_ref()333.is_some_and(|s| s.dtype() == &DataType::Float32)334{335&DataType::Float32336} else {337&DataType::Float64338};339let value_length = value_length340.set(&value_length.equal(0), 
None)?341.into_column()342.cast(dt)?;343344sum.map(|sum| std::ops::Div::div(&sum, &value_length))345.transpose()346},347}348}349350pub fn coalesce_columns(s: &[Column]) -> PolarsResult<Column> {351// TODO! this can be faster if we have more than two inputs.352polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list");353let mut out = s[0].clone();354for s in s {355if !out.null_count() == 0 {356return Ok(out);357} else {358let mask = out.is_not_null();359out = out360.as_materialized_series()361.zip_with_same_type(&mask, s.as_materialized_series())?362.into();363}364}365Ok(out)366}367368#[cfg(test)]369mod tests {370use super::*;371372#[test]373#[cfg_attr(miri, ignore)]374fn test_horizontal_agg() {375let a = Column::new("a".into(), [1, 2, 6]);376let b = Column::new("b".into(), [Some(1), None, None]);377let c = Column::new("c".into(), [Some(4), None, Some(3)]);378379let df = DataFrame::new(vec![a, b, c]).unwrap();380assert_eq!(381Vec::from(382df.mean_horizontal(NullStrategy::Ignore)383.unwrap()384.unwrap()385.f64()386.unwrap()387),388&[Some(2.0), Some(2.0), Some(4.5)]389);390assert_eq!(391Vec::from(392df.sum_horizontal(NullStrategy::Ignore)393.unwrap()394.unwrap()395.i32()396.unwrap()397),398&[Some(6), Some(2), Some(9)]399);400assert_eq!(401Vec::from(df.min_horizontal().unwrap().unwrap().i32().unwrap()),402&[Some(1), Some(2), Some(3)]403);404assert_eq!(405Vec::from(df.max_horizontal().unwrap().unwrap().i32().unwrap()),406&[Some(4), Some(2), Some(6)]407);408}409}410411412