Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-core/src/frame/arithmetic.rs
8430 views
1
use std::ops::{Add, Div, Mul, Rem, Sub};
2
3
use rayon::prelude::*;
4
5
use crate::POOL;
6
use crate::prelude::*;
7
use crate::utils::try_get_supertype;
8
9
/// Get the supertype that is valid for all columns in the [`DataFrame`].
10
/// This reduces casting of the rhs in arithmetic.
11
fn get_supertype_all(df: &DataFrame, rhs: &Series) -> PolarsResult<DataType> {
12
df.columns().iter().try_fold(rhs.dtype().clone(), |dt, s| {
13
try_get_supertype(s.dtype(), &dt)
14
})
15
}
16
17
/// Expands to the body of a `DataFrame <op> Series` operator implementation.
///
/// The expansion computes the supertype shared by `$rhs` and all columns of
/// `$self` (via `get_supertype_all`), casts `$rhs` and each column to it, applies
/// `$operand` to every column through `try_apply_columns_par`, and wraps the
/// resulting columns in a new `DataFrame` with the height of `$self`.
///
/// The expansion uses `?`, so it must appear in a function returning a
/// `PolarsResult`.
macro_rules! impl_arithmetic {
    ($self:expr, $rhs:expr, $operand:expr) => {{
        let supertype = get_supertype_all($self, $rhs)?;
        let rhs_cast = $rhs.cast(&supertype)?;
        let columns = $self.try_apply_columns_par(|col| {
            let lhs = col.as_materialized_series().cast(&supertype)?;
            $operand(&lhs, &rhs_cast).map(Column::from)
        })?;
        // SAFETY: assumes every resulting column has length `$self.height()`.
        // NOTE(review): confirm this still holds when `$rhs` has a different
        // length than the frame (e.g. if the series operation broadcasts).
        Ok(unsafe { DataFrame::new_unchecked($self.height(), columns) })
    }};
}
28
29
impl Add<&Series> for &DataFrame {
30
type Output = PolarsResult<DataFrame>;
31
32
fn add(self, rhs: &Series) -> Self::Output {
33
impl_arithmetic!(self, rhs, std::ops::Add::add)
34
}
35
}
36
37
impl Add<&Series> for DataFrame {
38
type Output = PolarsResult<DataFrame>;
39
40
fn add(self, rhs: &Series) -> Self::Output {
41
(&self).add(rhs)
42
}
43
}
44
45
impl Sub<&Series> for &DataFrame {
46
type Output = PolarsResult<DataFrame>;
47
48
fn sub(self, rhs: &Series) -> Self::Output {
49
impl_arithmetic!(self, rhs, std::ops::Sub::sub)
50
}
51
}
52
53
impl Sub<&Series> for DataFrame {
54
type Output = PolarsResult<DataFrame>;
55
56
fn sub(self, rhs: &Series) -> Self::Output {
57
(&self).sub(rhs)
58
}
59
}
60
61
impl Mul<&Series> for &DataFrame {
62
type Output = PolarsResult<DataFrame>;
63
64
fn mul(self, rhs: &Series) -> Self::Output {
65
impl_arithmetic!(self, rhs, std::ops::Mul::mul)
66
}
67
}
68
69
impl Mul<&Series> for DataFrame {
70
type Output = PolarsResult<DataFrame>;
71
72
fn mul(self, rhs: &Series) -> Self::Output {
73
(&self).mul(rhs)
74
}
75
}
76
77
impl Div<&Series> for &DataFrame {
78
type Output = PolarsResult<DataFrame>;
79
80
fn div(self, rhs: &Series) -> Self::Output {
81
impl_arithmetic!(self, rhs, std::ops::Div::div)
82
}
83
}
84
85
impl Div<&Series> for DataFrame {
86
type Output = PolarsResult<DataFrame>;
87
88
fn div(self, rhs: &Series) -> Self::Output {
89
(&self).div(rhs)
90
}
91
}
92
93
impl Rem<&Series> for &DataFrame {
94
type Output = PolarsResult<DataFrame>;
95
96
fn rem(self, rhs: &Series) -> Self::Output {
97
impl_arithmetic!(self, rhs, std::ops::Rem::rem)
98
}
99
}
100
101
impl Rem<&Series> for DataFrame {
102
type Output = PolarsResult<DataFrame>;
103
104
fn rem(self, rhs: &Series) -> Self::Output {
105
(&self).rem(rhs)
106
}
107
}
108
109
impl DataFrame {
110
fn binary_aligned(
111
&self,
112
other: &DataFrame,
113
f: &(dyn Fn(&Series, &Series) -> PolarsResult<Series> + Sync + Send),
114
) -> PolarsResult<DataFrame> {
115
let max_len = std::cmp::max(self.height(), other.height());
116
let max_width = std::cmp::max(self.width(), other.width());
117
let cols = self
118
.columns()
119
.par_iter()
120
.zip(other.columns().par_iter())
121
.map(|(l, r)| {
122
let l = l.as_materialized_series();
123
let r = r.as_materialized_series();
124
125
let diff_l = max_len - l.len();
126
let diff_r = max_len - r.len();
127
128
let st = try_get_supertype(l.dtype(), r.dtype())?;
129
let mut l = l.cast(&st)?;
130
let mut r = r.cast(&st)?;
131
132
if diff_l > 0 {
133
l = l.extend_constant(AnyValue::Null, diff_l)?;
134
};
135
if diff_r > 0 {
136
r = r.extend_constant(AnyValue::Null, diff_r)?;
137
};
138
139
f(&l, &r).map(Column::from)
140
});
141
let mut cols = POOL.install(|| cols.collect::<PolarsResult<Vec<_>>>())?;
142
143
let col_len = cols.len();
144
if col_len < max_width {
145
let df = if col_len < self.width() { self } else { other };
146
147
for i in col_len..max_len {
148
let s = &df.columns().get(i).ok_or_else(|| polars_err!(InvalidOperation: "cannot do arithmetic on DataFrames with shapes: {:?} and {:?}", self.shape(), other.shape()))?;
149
let name = s.name();
150
let dtype = s.dtype();
151
152
// trick to fill a series with nulls
153
let vals: &[Option<i32>] = &[None];
154
let s = Series::new(name.clone(), vals).cast(dtype)?;
155
cols.push(s.new_from_index(0, max_len).into())
156
}
157
}
158
159
DataFrame::new_infer_height(cols)
160
}
161
}
162
163
impl Add<&DataFrame> for &DataFrame {
164
type Output = PolarsResult<DataFrame>;
165
166
fn add(self, rhs: &DataFrame) -> Self::Output {
167
self.binary_aligned(rhs, &|a, b| a + b)
168
}
169
}
170
171
impl Sub<&DataFrame> for &DataFrame {
172
type Output = PolarsResult<DataFrame>;
173
174
fn sub(self, rhs: &DataFrame) -> Self::Output {
175
self.binary_aligned(rhs, &|a, b| a - b)
176
}
177
}
178
179
impl Div<&DataFrame> for &DataFrame {
180
type Output = PolarsResult<DataFrame>;
181
182
fn div(self, rhs: &DataFrame) -> Self::Output {
183
self.binary_aligned(rhs, &|a, b| a / b)
184
}
185
}
186
187
impl Mul<&DataFrame> for &DataFrame {
188
type Output = PolarsResult<DataFrame>;
189
190
fn mul(self, rhs: &DataFrame) -> Self::Output {
191
self.binary_aligned(rhs, &|a, b| a * b)
192
}
193
}
194
195
impl Rem<&DataFrame> for &DataFrame {
196
type Output = PolarsResult<DataFrame>;
197
198
fn rem(self, rhs: &DataFrame) -> Self::Output {
199
self.binary_aligned(rhs, &|a, b| a % b)
200
}
201
}
202
203