GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/binary.rs
use polars_core::POOL;
use polars_core::prelude::*;
#[cfg(feature = "round_series")]
use polars_ops::prelude::floor_div_series;

use super::*;
use crate::expressions::{
    AggState, AggregationContext, PartitionedAggregation, PhysicalExpr, UpdateGroups,
};
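
/// Physical expression for a binary operation: applies `op` to the columns produced by
/// the `left` and `right` expressions.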
#[derive(Clone)]
pub struct BinaryExpr {
    left: Arc<dyn PhysicalExpr>,
    op: Operator,
    right: Arc<dyn PhysicalExpr>,
    expr: Expr,
    has_literal: bool,
    allow_threading: bool,
    is_scalar: bool,
    output_field: Field,
}

impl BinaryExpr {
    #[expect(clippy::too_many_arguments)]
    pub fn new(
        left: Arc<dyn PhysicalExpr>,
        op: Operator,
        right: Arc<dyn PhysicalExpr>,
        expr: Expr,
        has_literal: bool,
        allow_threading: bool,
        is_scalar: bool,
        output_field: Field,
    ) -> Self {
        Self {
            left,
            op,
            right,
            expr,
            has_literal,
            allow_threading,
            is_scalar,
            output_field,
        }
    }
}

/// Can perform some of the operations in place.
fn apply_operator_owned(left: Column, right: Column, op: Operator) -> PolarsResult<Column> {
    match op {
        Operator::Plus => left.try_add_owned(right),
        Operator::Minus => left.try_sub_owned(right),
        Operator::Multiply
            if left.dtype().is_primitive_numeric() && right.dtype().is_primitive_numeric() =>
        {
            left.try_mul_owned(right)
        },
        _ => apply_operator(&left, &right, op),
    }
}
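
/// Applies `op` to two `Column`s and returns the resulting `Column`.
///
/// A minimal usage sketch (illustrative only, not from the original source; assumes `a` and
/// `b` are existing `Column`s of equal length, or of length 1 for broadcasting):
///
/// ```ignore
/// let sum = apply_operator(&a, &b, Operator::Plus)?;
/// let mask = apply_operator(&a, &b, Operator::Gt)?;
/// ```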
pub fn apply_operator(left: &Column, right: &Column, op: Operator) -> PolarsResult<Column> {
    use DataType::*;
    match op {
        Operator::Gt => ChunkCompareIneq::gt(left, right).map(|ca| ca.into_column()),
        Operator::GtEq => ChunkCompareIneq::gt_eq(left, right).map(|ca| ca.into_column()),
        Operator::Lt => ChunkCompareIneq::lt(left, right).map(|ca| ca.into_column()),
        Operator::LtEq => ChunkCompareIneq::lt_eq(left, right).map(|ca| ca.into_column()),
        Operator::Eq => ChunkCompareEq::equal(left, right).map(|ca| ca.into_column()),
        Operator::NotEq => ChunkCompareEq::not_equal(left, right).map(|ca| ca.into_column()),
        Operator::Plus => left + right,
        Operator::Minus => left - right,
        Operator::Multiply => left * right,
        Operator::Divide => left / right,
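        // True division: decimal, temporal, float and nested dtypes use their native
        // division; for the remaining dtypes both sides are cast to Float64 first
        // (unless the rhs is temporal), so e.g. integer / integer yields floats.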
        Operator::TrueDivide => match left.dtype() {
            #[cfg(feature = "dtype-decimal")]
            Decimal(_, _) => left / right,
            Duration(_) | Date | Datetime(_, _) | Float32 | Float64 => left / right,
            #[cfg(feature = "dtype-array")]
            Array(..) => left / right,
            #[cfg(feature = "dtype-array")]
            _ if right.dtype().is_array() => left / right,
            List(_) => left / right,
            _ if right.dtype().is_list() => left / right,
            _ if left.dtype().is_string() || right.dtype().is_string() => {
                polars_bail!(InvalidOperation: "cannot divide using strings")
            },
            _ => {
                if right.dtype().is_temporal() {
                    return left / right;
                }
                left.cast(&Float64)? / right.cast(&Float64)?
            },
        },
        Operator::FloorDivide => {
            #[cfg(feature = "round_series")]
            {
                floor_div_series(
                    left.as_materialized_series(),
                    right.as_materialized_series(),
                )
                .map(Column::from)
            }
            #[cfg(not(feature = "round_series"))]
            {
                panic!("activate 'round_series' feature")
            }
        },
        Operator::And => left.bitand(right),
        Operator::Or => left.bitor(right),
        Operator::LogicalOr => left
            .cast(&DataType::Boolean)?
            .bitor(&right.cast(&DataType::Boolean)?),
        Operator::LogicalAnd => left
            .cast(&DataType::Boolean)?
            .bitand(&right.cast(&DataType::Boolean)?),
        Operator::Xor => left.bitxor(right),
        Operator::Modulus => left % right,
        Operator::EqValidity => left.equal_missing(right).map(|ca| ca.into_column()),
        Operator::NotEqValidity => left.not_equal_missing(right).map(|ca| ca.into_column()),
    }
}

impl BinaryExpr {
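    /// Applies the operator directly to the underlying values of both sides (no per-group
    /// iteration) and stores the result back into the left aggregation context.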
    fn apply_elementwise<'a>(
        &self,
        mut ac_l: AggregationContext<'a>,
        ac_r: AggregationContext,
        aggregated: bool,
    ) -> PolarsResult<AggregationContext<'a>> {
        // We want to be able to mutate in place, so we clone the lhs and then take it out
        // of the aggregation context to make sure the context's copy is dropped.
        let lhs = ac_l.get_values().clone();
        let rhs = ac_r.get_values().clone();

        // Drop lhs so that we might operate in place.
        drop(ac_l.take());

        let out = apply_operator_owned(lhs, rhs, self.op)?;
        ac_l.with_values(out, aggregated, Some(&self.expr))?;
        Ok(ac_l)
    }

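    /// Both sides are literal scalars: applies the operator once and broadcasts the result
    /// so that every group gets the same aggregated value.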
    fn apply_all_literal<'a>(
        &self,
        mut ac_l: AggregationContext<'a>,
        mut ac_r: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let name = self.output_field.name().clone();
        ac_l.groups();
        ac_r.groups();
        polars_ensure!(ac_l.groups.len() == ac_r.groups.len(), ComputeError: "lhs and rhs should have same group length");
        let left_c = ac_l.get_values().rechunk().into_column();
        let right_c = ac_r.get_values().rechunk().into_column();
        let res_c = apply_operator(&left_c, &right_c, self.op)?;
        ac_l.with_update_groups(UpdateGroups::WithSeriesLen);
        let res_s = if res_c.len() == 1 {
            res_c.new_from_index(0, ac_l.groups.len())
        } else {
            ListChunked::full(name, res_c.as_materialized_series(), ac_l.groups.len()).into_column()
        };
        ac_l.with_values(res_s, true, Some(&self.expr))?;
        Ok(ac_l)
    }

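    /// Group-aware fallback: iterates the groups of both sides, applies the operator per
    /// group and collects the results into an aggregated list.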
    fn apply_group_aware<'a>(
        &self,
        mut ac_l: AggregationContext<'a>,
        mut ac_r: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let name = self.output_field.name().clone();
        let mut ca = ac_l
            .iter_groups(false)
            .zip(ac_r.iter_groups(false))
            .map(|(l, r)| {
                Some(apply_operator(
                    &l?.as_ref().clone().into_column(),
                    &r?.as_ref().clone().into_column(),
                    self.op,
                ))
            })
            .map(|opt_res| opt_res.transpose())
            .collect::<PolarsResult<ListChunked>>()?
            .with_name(name.clone());
        if ca.is_empty() {
            ca = ListChunked::full_null_with_dtype(name, 0, self.output_field.dtype());
        }

        ac_l.with_update_groups(UpdateGroups::WithSeriesLen);
        ac_l.with_agg_state(AggState::AggregatedList(ca.into_column()));
        Ok(ac_l)
    }
}

impl PhysicalExpr for BinaryExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

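    /// Evaluates `left` and `right` on the full `DataFrame` (in parallel on the rayon pool
    /// when threading is allowed) and then applies the operator.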
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        // Window functions may set a global state that determines their output state,
        // so we don't let them run in parallel, as they would race. They also saturate
        // the thread pool by themselves, so that's fine.
        let has_window = state.has_window();

        let (lhs, rhs);
        if has_window {
            let mut state = state.split();
            state.remove_cache_window_flag();
            lhs = self.left.evaluate(df, &state)?;
            rhs = self.right.evaluate(df, &state)?;
        } else if !self.allow_threading || self.has_literal {
            // Literals are free to evaluate, so don't pay the parallelism cost.
            lhs = self.left.evaluate(df, state)?;
            rhs = self.right.evaluate(df, state)?;
        } else {
            let (opt_lhs, opt_rhs) = POOL.install(|| {
                rayon::join(
                    || self.left.evaluate(df, state),
                    || self.right.evaluate(df, state),
                )
            });
            (lhs, rhs) = (opt_lhs?, opt_rhs?);
        };
        polars_ensure!(
            lhs.len() == rhs.len() || lhs.len() == 1 || rhs.len() == 1,
            expr = self.expr,
            ShapeMismatch: "cannot evaluate two Series of different lengths ({} and {})",
            lhs.len(), rhs.len(),
        );
        apply_operator_owned(lhs, rhs, self.op)
    }

    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let (result_a, result_b) = POOL.install(|| {
            rayon::join(
                || self.left.evaluate_on_groups(df, groups, state),
                || self.right.evaluate_on_groups(df, groups, state),
            )
        });
        let mut ac_l = result_a?;
        let ac_r = result_b?;

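        // Dispatch on the aggregation state of both sides: use the cheap elementwise path
        // where the shapes allow it, and fall back to the group-aware path otherwise.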
        match (ac_l.agg_state(), ac_r.agg_state()) {
            (AggState::LiteralScalar(s), AggState::NotAggregated(_))
            | (AggState::NotAggregated(_), AggState::LiteralScalar(s)) => match s.len() {
                1 => self.apply_elementwise(ac_l, ac_r, false),
                _ => self.apply_group_aware(ac_l, ac_r),
            },
            (AggState::LiteralScalar(_), AggState::LiteralScalar(_)) => {
                self.apply_all_literal(ac_l, ac_r)
            },
            (AggState::NotAggregated(_), AggState::NotAggregated(_)) => {
                self.apply_elementwise(ac_l, ac_r, false)
            },
            (
                AggState::AggregatedScalar(_) | AggState::LiteralScalar(_),
                AggState::AggregatedScalar(_) | AggState::LiteralScalar(_),
            ) => self.apply_elementwise(ac_l, ac_r, true),
            (AggState::AggregatedScalar(_), AggState::NotAggregated(_))
            | (AggState::NotAggregated(_), AggState::AggregatedScalar(_)) => {
                self.apply_group_aware(ac_l, ac_r)
            },
            (AggState::AggregatedList(lhs), AggState::AggregatedList(rhs)) => {
                let lhs = lhs.list().unwrap();
                let rhs = rhs.list().unwrap();
                let out = lhs.apply_to_inner(&|lhs| {
                    apply_operator(&lhs.into_column(), &rhs.get_inner().into_column(), self.op)
                        .map(|c| c.take_materialized_series())
                })?;
                ac_l.with_values(out.into_column(), true, Some(&self.expr))?;
                Ok(ac_l)
            },
            _ => self.apply_group_aware(ac_l, ac_r),
        }
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.expr.to_field(input_schema)
    }

    fn is_scalar(&self) -> bool {
        self.is_scalar
    }

    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        Some(self)
    }
}

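/// For a partitioned aggregation, the operator is applied to the partitioned results of
/// both inputs; `finalize` then passes the combined column through unchanged.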
impl PartitionedAggregation for BinaryExpr {
    fn evaluate_partitioned(
        &self,
        df: &DataFrame,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        let left = self.left.as_partitioned_aggregator().unwrap();
        let right = self.right.as_partitioned_aggregator().unwrap();
        let left = left.evaluate_partitioned(df, groups, state)?;
        let right = right.evaluate_partitioned(df, groups, state)?;
        apply_operator(&left, &right, self.op)
    }

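    /// The operator was already applied per partition in `evaluate_partitioned`, so
    /// finalizing is a pass-through.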
    fn finalize(
        &self,
        partitioned: Column,
        _groups: &GroupPositions,
        _state: &ExecutionState,
    ) -> PolarsResult<Column> {
        Ok(partitioned)
    }
}