Path: blob/main/crates/polars-core/src/frame/arithmetic.rs
8430 views
use std::ops::{Add, Div, Mul, Rem, Sub};12use rayon::prelude::*;34use crate::POOL;5use crate::prelude::*;6use crate::utils::try_get_supertype;78/// Get the supertype that is valid for all columns in the [`DataFrame`].9/// This reduces casting of the rhs in arithmetic.10fn get_supertype_all(df: &DataFrame, rhs: &Series) -> PolarsResult<DataType> {11df.columns().iter().try_fold(rhs.dtype().clone(), |dt, s| {12try_get_supertype(s.dtype(), &dt)13})14}1516macro_rules! impl_arithmetic {17($self:expr, $rhs:expr, $operand:expr) => {{18let st = get_supertype_all($self, $rhs)?;19let rhs = $rhs.cast(&st)?;20let cols = $self.try_apply_columns_par(|c| {21let s = c.as_materialized_series();22$operand(&s.cast(&st)?, &rhs).map(Column::from)23})?;24Ok(unsafe { DataFrame::new_unchecked($self.height(), cols) })25}};26}2728impl Add<&Series> for &DataFrame {29type Output = PolarsResult<DataFrame>;3031fn add(self, rhs: &Series) -> Self::Output {32impl_arithmetic!(self, rhs, std::ops::Add::add)33}34}3536impl Add<&Series> for DataFrame {37type Output = PolarsResult<DataFrame>;3839fn add(self, rhs: &Series) -> Self::Output {40(&self).add(rhs)41}42}4344impl Sub<&Series> for &DataFrame {45type Output = PolarsResult<DataFrame>;4647fn sub(self, rhs: &Series) -> Self::Output {48impl_arithmetic!(self, rhs, std::ops::Sub::sub)49}50}5152impl Sub<&Series> for DataFrame {53type Output = PolarsResult<DataFrame>;5455fn sub(self, rhs: &Series) -> Self::Output {56(&self).sub(rhs)57}58}5960impl Mul<&Series> for &DataFrame {61type Output = PolarsResult<DataFrame>;6263fn mul(self, rhs: &Series) -> Self::Output {64impl_arithmetic!(self, rhs, std::ops::Mul::mul)65}66}6768impl Mul<&Series> for DataFrame {69type Output = PolarsResult<DataFrame>;7071fn mul(self, rhs: &Series) -> Self::Output {72(&self).mul(rhs)73}74}7576impl Div<&Series> for &DataFrame {77type Output = PolarsResult<DataFrame>;7879fn div(self, rhs: &Series) -> Self::Output {80impl_arithmetic!(self, rhs, std::ops::Div::div)81}82}8384impl Div<&Series> for DataFrame {85type Output = PolarsResult<DataFrame>;8687fn div(self, rhs: &Series) -> Self::Output {88(&self).div(rhs)89}90}9192impl Rem<&Series> for &DataFrame {93type Output = PolarsResult<DataFrame>;9495fn rem(self, rhs: &Series) -> Self::Output {96impl_arithmetic!(self, rhs, std::ops::Rem::rem)97}98}99100impl Rem<&Series> for DataFrame {101type Output = PolarsResult<DataFrame>;102103fn rem(self, rhs: &Series) -> Self::Output {104(&self).rem(rhs)105}106}107108impl DataFrame {109fn binary_aligned(110&self,111other: &DataFrame,112f: &(dyn Fn(&Series, &Series) -> PolarsResult<Series> + Sync + Send),113) -> PolarsResult<DataFrame> {114let max_len = std::cmp::max(self.height(), other.height());115let max_width = std::cmp::max(self.width(), other.width());116let cols = self117.columns()118.par_iter()119.zip(other.columns().par_iter())120.map(|(l, r)| {121let l = l.as_materialized_series();122let r = r.as_materialized_series();123124let diff_l = max_len - l.len();125let diff_r = max_len - r.len();126127let st = try_get_supertype(l.dtype(), r.dtype())?;128let mut l = l.cast(&st)?;129let mut r = r.cast(&st)?;130131if diff_l > 0 {132l = l.extend_constant(AnyValue::Null, diff_l)?;133};134if diff_r > 0 {135r = r.extend_constant(AnyValue::Null, diff_r)?;136};137138f(&l, &r).map(Column::from)139});140let mut cols = POOL.install(|| cols.collect::<PolarsResult<Vec<_>>>())?;141142let col_len = cols.len();143if col_len < max_width {144let df = if col_len < self.width() { self } else { other };145146for i in col_len..max_len {147let s = &df.columns().get(i).ok_or_else(|| polars_err!(InvalidOperation: "cannot do arithmetic on DataFrames with shapes: {:?} and {:?}", self.shape(), other.shape()))?;148let name = s.name();149let dtype = s.dtype();150151// trick to fill a series with nulls152let vals: &[Option<i32>] = &[None];153let s = Series::new(name.clone(), vals).cast(dtype)?;154cols.push(s.new_from_index(0, max_len).into())155}156}157158DataFrame::new_infer_height(cols)159}160}161162impl Add<&DataFrame> for &DataFrame {163type Output = PolarsResult<DataFrame>;164165fn add(self, rhs: &DataFrame) -> Self::Output {166self.binary_aligned(rhs, &|a, b| a + b)167}168}169170impl Sub<&DataFrame> for &DataFrame {171type Output = PolarsResult<DataFrame>;172173fn sub(self, rhs: &DataFrame) -> Self::Output {174self.binary_aligned(rhs, &|a, b| a - b)175}176}177178impl Div<&DataFrame> for &DataFrame {179type Output = PolarsResult<DataFrame>;180181fn div(self, rhs: &DataFrame) -> Self::Output {182self.binary_aligned(rhs, &|a, b| a / b)183}184}185186impl Mul<&DataFrame> for &DataFrame {187type Output = PolarsResult<DataFrame>;188189fn mul(self, rhs: &DataFrame) -> Self::Output {190self.binary_aligned(rhs, &|a, b| a * b)191}192}193194impl Rem<&DataFrame> for &DataFrame {195type Output = PolarsResult<DataFrame>;196197fn rem(self, rhs: &DataFrame) -> Self::Output {198self.binary_aligned(rhs, &|a, b| a % b)199}200}201202203