Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-core/src/frame/arithmetic.rs
6940 views
1
use std::ops::{Add, Div, Mul, Rem, Sub};
2
3
use rayon::prelude::*;
4
5
use crate::POOL;
6
use crate::prelude::*;
7
use crate::utils::try_get_supertype;
8
9
/// Get the supertype that is valid for all columns in the [`DataFrame`].
10
/// This reduces casting of the rhs in arithmetic.
11
fn get_supertype_all(df: &DataFrame, rhs: &Series) -> PolarsResult<DataType> {
12
df.columns.iter().try_fold(rhs.dtype().clone(), |dt, s| {
13
try_get_supertype(s.dtype(), &dt)
14
})
15
}
16
17
/// Expands to the body of a `DataFrame <op> Series` operator implementation.
///
/// * `$self` - the frame (an expression of type `&DataFrame`).
/// * `$rhs` - the right-hand side series (`&Series`).
/// * `$operand` - a callable taking two `&Series` and returning `PolarsResult<Series>`.
///
/// Both `$rhs` and every materialized column are cast to the supertype shared by all
/// of them, `$operand` is applied per column inside `POOL`, and the results are
/// assembled into a new frame. The expansion uses `?`, so it must be placed in a
/// function that returns a `PolarsResult`.
macro_rules! impl_arithmetic {
    ($self:expr, $rhs:expr, $operand:expr) => {{
        let supertype = get_supertype_all($self, $rhs)?;
        let rhs_cast = $rhs.cast(&supertype)?;
        let columns = POOL.install(|| {
            $self
                .par_materialized_column_iter()
                .map(|lhs| {
                    let lhs_cast = lhs.cast(&supertype)?;
                    $operand(&lhs_cast, &rhs_cast).map(Column::from)
                })
                .collect::<PolarsResult<_>>()
        })?;
        // SAFETY: NOTE(review): this relies on every result column having exactly
        // `$self.height()` rows - confirm that `$operand` cannot change the length
        // (e.g. through broadcasting of unit-length inputs).
        Ok(unsafe { DataFrame::new_no_checks($self.height(), columns) })
    }};
}
31
32
impl Add<&Series> for &DataFrame {
33
type Output = PolarsResult<DataFrame>;
34
35
fn add(self, rhs: &Series) -> Self::Output {
36
impl_arithmetic!(self, rhs, std::ops::Add::add)
37
}
38
}
39
40
impl Add<&Series> for DataFrame {
41
type Output = PolarsResult<DataFrame>;
42
43
fn add(self, rhs: &Series) -> Self::Output {
44
(&self).add(rhs)
45
}
46
}
47
48
impl Sub<&Series> for &DataFrame {
49
type Output = PolarsResult<DataFrame>;
50
51
fn sub(self, rhs: &Series) -> Self::Output {
52
impl_arithmetic!(self, rhs, std::ops::Sub::sub)
53
}
54
}
55
56
impl Sub<&Series> for DataFrame {
57
type Output = PolarsResult<DataFrame>;
58
59
fn sub(self, rhs: &Series) -> Self::Output {
60
(&self).sub(rhs)
61
}
62
}
63
64
impl Mul<&Series> for &DataFrame {
65
type Output = PolarsResult<DataFrame>;
66
67
fn mul(self, rhs: &Series) -> Self::Output {
68
impl_arithmetic!(self, rhs, std::ops::Mul::mul)
69
}
70
}
71
72
impl Mul<&Series> for DataFrame {
73
type Output = PolarsResult<DataFrame>;
74
75
fn mul(self, rhs: &Series) -> Self::Output {
76
(&self).mul(rhs)
77
}
78
}
79
80
impl Div<&Series> for &DataFrame {
81
type Output = PolarsResult<DataFrame>;
82
83
fn div(self, rhs: &Series) -> Self::Output {
84
impl_arithmetic!(self, rhs, std::ops::Div::div)
85
}
86
}
87
88
impl Div<&Series> for DataFrame {
89
type Output = PolarsResult<DataFrame>;
90
91
fn div(self, rhs: &Series) -> Self::Output {
92
(&self).div(rhs)
93
}
94
}
95
96
impl Rem<&Series> for &DataFrame {
97
type Output = PolarsResult<DataFrame>;
98
99
fn rem(self, rhs: &Series) -> Self::Output {
100
impl_arithmetic!(self, rhs, std::ops::Rem::rem)
101
}
102
}
103
104
impl Rem<&Series> for DataFrame {
105
type Output = PolarsResult<DataFrame>;
106
107
fn rem(self, rhs: &Series) -> Self::Output {
108
(&self).rem(rhs)
109
}
110
}
111
112
impl DataFrame {
113
fn binary_aligned(
114
&self,
115
other: &DataFrame,
116
f: &(dyn Fn(&Series, &Series) -> PolarsResult<Series> + Sync + Send),
117
) -> PolarsResult<DataFrame> {
118
let max_len = std::cmp::max(self.height(), other.height());
119
let max_width = std::cmp::max(self.width(), other.width());
120
let cols = self
121
.get_columns()
122
.par_iter()
123
.zip(other.get_columns().par_iter())
124
.map(|(l, r)| {
125
let l = l.as_materialized_series();
126
let r = r.as_materialized_series();
127
128
let diff_l = max_len - l.len();
129
let diff_r = max_len - r.len();
130
131
let st = try_get_supertype(l.dtype(), r.dtype())?;
132
let mut l = l.cast(&st)?;
133
let mut r = r.cast(&st)?;
134
135
if diff_l > 0 {
136
l = l.extend_constant(AnyValue::Null, diff_l)?;
137
};
138
if diff_r > 0 {
139
r = r.extend_constant(AnyValue::Null, diff_r)?;
140
};
141
142
f(&l, &r).map(Column::from)
143
});
144
let mut cols = POOL.install(|| cols.collect::<PolarsResult<Vec<_>>>())?;
145
146
let col_len = cols.len();
147
if col_len < max_width {
148
let df = if col_len < self.width() { self } else { other };
149
150
for i in col_len..max_len {
151
let s = &df.get_columns().get(i).ok_or_else(|| polars_err!(InvalidOperation: "cannot do arithmetic on DataFrames with shapes: {:?} and {:?}", self.shape(), other.shape()))?;
152
let name = s.name();
153
let dtype = s.dtype();
154
155
// trick to fill a series with nulls
156
let vals: &[Option<i32>] = &[None];
157
let s = Series::new(name.clone(), vals).cast(dtype)?;
158
cols.push(s.new_from_index(0, max_len).into())
159
}
160
}
161
DataFrame::new(cols)
162
}
163
}
164
165
impl Add<&DataFrame> for &DataFrame {
166
type Output = PolarsResult<DataFrame>;
167
168
fn add(self, rhs: &DataFrame) -> Self::Output {
169
self.binary_aligned(rhs, &|a, b| a + b)
170
}
171
}
172
173
impl Sub<&DataFrame> for &DataFrame {
174
type Output = PolarsResult<DataFrame>;
175
176
fn sub(self, rhs: &DataFrame) -> Self::Output {
177
self.binary_aligned(rhs, &|a, b| a - b)
178
}
179
}
180
181
impl Div<&DataFrame> for &DataFrame {
182
type Output = PolarsResult<DataFrame>;
183
184
fn div(self, rhs: &DataFrame) -> Self::Output {
185
self.binary_aligned(rhs, &|a, b| a / b)
186
}
187
}
188
189
impl Mul<&DataFrame> for &DataFrame {
190
type Output = PolarsResult<DataFrame>;
191
192
fn mul(self, rhs: &DataFrame) -> Self::Output {
193
self.binary_aligned(rhs, &|a, b| a * b)
194
}
195
}
196
197
impl Rem<&DataFrame> for &DataFrame {
198
type Output = PolarsResult<DataFrame>;
199
200
fn rem(self, rhs: &DataFrame) -> Self::Output {
201
self.binary_aligned(rhs, &|a, b| a % b)
202
}
203
}
204
205