// Path: blob/main/crates/polars-core/src/frame/arithmetic.rs
use std::ops::{Add, Div, Mul, Rem, Sub};12use rayon::prelude::*;34use crate::POOL;5use crate::prelude::*;6use crate::utils::try_get_supertype;78/// Get the supertype that is valid for all columns in the [`DataFrame`].9/// This reduces casting of the rhs in arithmetic.10fn get_supertype_all(df: &DataFrame, rhs: &Series) -> PolarsResult<DataType> {11df.columns.iter().try_fold(rhs.dtype().clone(), |dt, s| {12try_get_supertype(s.dtype(), &dt)13})14}1516macro_rules! impl_arithmetic {17($self:expr, $rhs:expr, $operand:expr) => {{18let st = get_supertype_all($self, $rhs)?;19let rhs = $rhs.cast(&st)?;20let cols = POOL.install(|| {21$self22.par_materialized_column_iter()23.map(|s| $operand(&s.cast(&st)?, &rhs))24.map(|s| s.map(Column::from))25.collect::<PolarsResult<_>>()26})?;27Ok(unsafe { DataFrame::new_no_checks($self.height(), cols) })28}};29}3031impl Add<&Series> for &DataFrame {32type Output = PolarsResult<DataFrame>;3334fn add(self, rhs: &Series) -> Self::Output {35impl_arithmetic!(self, rhs, std::ops::Add::add)36}37}3839impl Add<&Series> for DataFrame {40type Output = PolarsResult<DataFrame>;4142fn add(self, rhs: &Series) -> Self::Output {43(&self).add(rhs)44}45}4647impl Sub<&Series> for &DataFrame {48type Output = PolarsResult<DataFrame>;4950fn sub(self, rhs: &Series) -> Self::Output {51impl_arithmetic!(self, rhs, std::ops::Sub::sub)52}53}5455impl Sub<&Series> for DataFrame {56type Output = PolarsResult<DataFrame>;5758fn sub(self, rhs: &Series) -> Self::Output {59(&self).sub(rhs)60}61}6263impl Mul<&Series> for &DataFrame {64type Output = PolarsResult<DataFrame>;6566fn mul(self, rhs: &Series) -> Self::Output {67impl_arithmetic!(self, rhs, std::ops::Mul::mul)68}69}7071impl Mul<&Series> for DataFrame {72type Output = PolarsResult<DataFrame>;7374fn mul(self, rhs: &Series) -> Self::Output {75(&self).mul(rhs)76}77}7879impl Div<&Series> for &DataFrame {80type Output = PolarsResult<DataFrame>;8182fn div(self, rhs: &Series) -> Self::Output {83impl_arithmetic!(self, rhs, 
std::ops::Div::div)84}85}8687impl Div<&Series> for DataFrame {88type Output = PolarsResult<DataFrame>;8990fn div(self, rhs: &Series) -> Self::Output {91(&self).div(rhs)92}93}9495impl Rem<&Series> for &DataFrame {96type Output = PolarsResult<DataFrame>;9798fn rem(self, rhs: &Series) -> Self::Output {99impl_arithmetic!(self, rhs, std::ops::Rem::rem)100}101}102103impl Rem<&Series> for DataFrame {104type Output = PolarsResult<DataFrame>;105106fn rem(self, rhs: &Series) -> Self::Output {107(&self).rem(rhs)108}109}110111impl DataFrame {112fn binary_aligned(113&self,114other: &DataFrame,115f: &(dyn Fn(&Series, &Series) -> PolarsResult<Series> + Sync + Send),116) -> PolarsResult<DataFrame> {117let max_len = std::cmp::max(self.height(), other.height());118let max_width = std::cmp::max(self.width(), other.width());119let cols = self120.get_columns()121.par_iter()122.zip(other.get_columns().par_iter())123.map(|(l, r)| {124let l = l.as_materialized_series();125let r = r.as_materialized_series();126127let diff_l = max_len - l.len();128let diff_r = max_len - r.len();129130let st = try_get_supertype(l.dtype(), r.dtype())?;131let mut l = l.cast(&st)?;132let mut r = r.cast(&st)?;133134if diff_l > 0 {135l = l.extend_constant(AnyValue::Null, diff_l)?;136};137if diff_r > 0 {138r = r.extend_constant(AnyValue::Null, diff_r)?;139};140141f(&l, &r).map(Column::from)142});143let mut cols = POOL.install(|| cols.collect::<PolarsResult<Vec<_>>>())?;144145let col_len = cols.len();146if col_len < max_width {147let df = if col_len < self.width() { self } else { other };148149for i in col_len..max_len {150let s = &df.get_columns().get(i).ok_or_else(|| polars_err!(InvalidOperation: "cannot do arithmetic on DataFrames with shapes: {:?} and {:?}", self.shape(), other.shape()))?;151let name = s.name();152let dtype = s.dtype();153154// trick to fill a series with nulls155let vals: &[Option<i32>] = &[None];156let s = Series::new(name.clone(), vals).cast(dtype)?;157cols.push(s.new_from_index(0, 
max_len).into())158}159}160DataFrame::new(cols)161}162}163164impl Add<&DataFrame> for &DataFrame {165type Output = PolarsResult<DataFrame>;166167fn add(self, rhs: &DataFrame) -> Self::Output {168self.binary_aligned(rhs, &|a, b| a + b)169}170}171172impl Sub<&DataFrame> for &DataFrame {173type Output = PolarsResult<DataFrame>;174175fn sub(self, rhs: &DataFrame) -> Self::Output {176self.binary_aligned(rhs, &|a, b| a - b)177}178}179180impl Div<&DataFrame> for &DataFrame {181type Output = PolarsResult<DataFrame>;182183fn div(self, rhs: &DataFrame) -> Self::Output {184self.binary_aligned(rhs, &|a, b| a / b)185}186}187188impl Mul<&DataFrame> for &DataFrame {189type Output = PolarsResult<DataFrame>;190191fn mul(self, rhs: &DataFrame) -> Self::Output {192self.binary_aligned(rhs, &|a, b| a * b)193}194}195196impl Rem<&DataFrame> for &DataFrame {197type Output = PolarsResult<DataFrame>;198199fn rem(self, rhs: &DataFrame) -> Self::Output {200self.binary_aligned(rhs, &|a, b| a % b)201}202}203204205