Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/binary.rs
8406 views
1
use polars_core::POOL;
2
use polars_core::prelude::*;
3
#[cfg(feature = "round_series")]
4
use polars_ops::prelude::floor_div_series;
5
6
use super::*;
7
use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
8
9
#[derive(Clone)]
10
pub struct BinaryExpr {
11
left: Arc<dyn PhysicalExpr>,
12
op: Operator,
13
right: Arc<dyn PhysicalExpr>,
14
expr: Expr,
15
has_literal: bool,
16
allow_threading: bool,
17
is_scalar: bool,
18
output_field: Field,
19
}
20
21
impl BinaryExpr {
22
#[expect(clippy::too_many_arguments)]
23
pub fn new(
24
left: Arc<dyn PhysicalExpr>,
25
op: Operator,
26
right: Arc<dyn PhysicalExpr>,
27
expr: Expr,
28
has_literal: bool,
29
allow_threading: bool,
30
is_scalar: bool,
31
output_field: Field,
32
) -> Self {
33
Self {
34
left,
35
op,
36
right,
37
expr,
38
has_literal,
39
allow_threading,
40
is_scalar,
41
output_field,
42
}
43
}
44
}
45
46
/// Can partially do operations in place.
47
fn apply_operator_owned(left: Column, right: Column, op: Operator) -> PolarsResult<Column> {
48
match op {
49
Operator::Plus => left.try_add_owned(right),
50
Operator::Minus => left.try_sub_owned(right),
51
Operator::Multiply
52
if left.dtype().is_primitive_numeric() && right.dtype().is_primitive_numeric() =>
53
{
54
left.try_mul_owned(right)
55
},
56
_ => apply_operator(&left, &right, op),
57
}
58
}
59
60
pub fn apply_operator(left: &Column, right: &Column, op: Operator) -> PolarsResult<Column> {
61
use DataType::*;
62
match op {
63
Operator::Gt => ChunkCompareIneq::gt(left, right).map(|ca| ca.into_column()),
64
Operator::GtEq => ChunkCompareIneq::gt_eq(left, right).map(|ca| ca.into_column()),
65
Operator::Lt => ChunkCompareIneq::lt(left, right).map(|ca| ca.into_column()),
66
Operator::LtEq => ChunkCompareIneq::lt_eq(left, right).map(|ca| ca.into_column()),
67
Operator::Eq => ChunkCompareEq::equal(left, right).map(|ca| ca.into_column()),
68
Operator::NotEq => ChunkCompareEq::not_equal(left, right).map(|ca| ca.into_column()),
69
Operator::Plus => left + right,
70
Operator::Minus => left - right,
71
Operator::Multiply => left * right,
72
Operator::RustDivide => left / right,
73
Operator::TrueDivide => match left.dtype() {
74
#[cfg(feature = "dtype-decimal")]
75
Decimal(_, _) => left / right,
76
#[cfg(feature = "dtype-f16")]
77
Float16 => left / right,
78
Duration(_) | Date | Datetime(_, _) | Float32 | Float64 => left / right,
79
#[cfg(feature = "dtype-array")]
80
Array(..) => left / right,
81
#[cfg(feature = "dtype-array")]
82
_ if right.dtype().is_array() => left / right,
83
List(_) => left / right,
84
_ if right.dtype().is_list() => left / right,
85
_ if left.dtype().is_string() || right.dtype().is_string() => {
86
polars_bail!(InvalidOperation: "cannot divide using strings")
87
},
88
_ => {
89
if right.dtype().is_temporal() {
90
return left / right;
91
}
92
left.cast(&Float64)? / right.cast(&Float64)?
93
},
94
},
95
Operator::FloorDivide => {
96
#[cfg(feature = "round_series")]
97
{
98
floor_div_series(
99
left.as_materialized_series(),
100
right.as_materialized_series(),
101
)
102
.map(Column::from)
103
}
104
#[cfg(not(feature = "round_series"))]
105
{
106
panic!("activate 'round_series' feature")
107
}
108
},
109
Operator::And => left.bitand(right),
110
Operator::Or => left.bitor(right),
111
Operator::LogicalOr => left
112
.cast(&DataType::Boolean)?
113
.bitor(&right.cast(&DataType::Boolean)?),
114
Operator::LogicalAnd => left
115
.cast(&DataType::Boolean)?
116
.bitand(&right.cast(&DataType::Boolean)?),
117
Operator::Xor => left.bitxor(right),
118
Operator::Modulus => left % right,
119
Operator::EqValidity => left.equal_missing(right).map(|ca| ca.into_column()),
120
Operator::NotEqValidity => left.not_equal_missing(right).map(|ca| ca.into_column()),
121
}
122
}
123
124
impl BinaryExpr {
125
fn apply_elementwise<'a>(
126
&self,
127
mut ac_l: AggregationContext<'a>,
128
mut ac_r: AggregationContext<'a>,
129
aggregated: bool,
130
) -> PolarsResult<AggregationContext<'a>> {
131
// At this stage, there is no combination of AggregatedList and NotAggregated ACs.
132
133
// Check group lengths in case of all AggList
134
if [&ac_l, &ac_r]
135
.iter()
136
.all(|ac| matches!(ac.state, AggState::AggregatedList(_)))
137
{
138
ac_l.groups().check_lengths(ac_r.groups())?;
139
}
140
141
match (ac_l.agg_state(), ac_r.agg_state()) {
142
(AggState::AggregatedList(s), _) | (_, AggState::AggregatedList(s)) => {
143
let ca = s.list().unwrap();
144
let [col_l, col_r] = [&ac_l, &ac_r].map(|ac| ac.flat_naive().into_owned());
145
146
let out = ca.apply_to_inner(&|_| {
147
apply_operator(&col_l, &col_r, self.op).map(|c| c.take_materialized_series())
148
})?;
149
let out = out.into_column();
150
151
if ac_l.is_literal() {
152
std::mem::swap(&mut ac_l, &mut ac_r);
153
}
154
155
ac_l.with_values(out.into_column(), true, Some(&self.expr))?;
156
Ok(ac_l)
157
},
158
159
_ => {
160
// We want to be able to mutate in place, so we take the lhs to make sure that we drop.
161
let lhs = ac_l.get_values().clone();
162
let rhs = ac_r.get_values().clone();
163
164
let out = apply_operator_owned(lhs, rhs, self.op)?;
165
166
if ac_l.is_literal() {
167
std::mem::swap(&mut ac_l, &mut ac_r);
168
}
169
170
// Drop lhs so that we might operate in place.
171
drop(ac_l.take());
172
173
ac_l.with_values(out, aggregated, Some(&self.expr))?;
174
Ok(ac_l)
175
},
176
}
177
}
178
179
fn apply_all_literal<'a>(
180
&self,
181
mut ac_l: AggregationContext<'a>,
182
ac_r: AggregationContext<'a>,
183
) -> PolarsResult<AggregationContext<'a>> {
184
debug_assert!(ac_l.is_literal() && ac_r.is_literal());
185
polars_ensure!(ac_l.groups.len() == ac_r.groups.len(),
186
ComputeError: "lhs and rhs should have same number of groups");
187
188
let left_c = ac_l.get_values().rechunk().into_column();
189
let right_c = ac_r.get_values().rechunk().into_column();
190
let res_c = apply_operator(&left_c, &right_c, self.op)?;
191
polars_ensure!(res_c.len() == 1,
192
ComputeError: "binary operation on literals expected 1 value, found {}", res_c.len());
193
194
ac_l.with_literal(res_c);
195
Ok(ac_l)
196
}
197
198
fn apply_group_aware<'a>(
199
&self,
200
mut ac_l: AggregationContext<'a>,
201
mut ac_r: AggregationContext<'a>,
202
) -> PolarsResult<AggregationContext<'a>> {
203
let name = self.output_field.name().clone();
204
let mut ca = ac_l
205
.iter_groups(false)
206
.zip(ac_r.iter_groups(false))
207
.map(|(l, r)| {
208
Some(apply_operator(
209
&l?.as_ref().clone().into_column(),
210
&r?.as_ref().clone().into_column(),
211
self.op,
212
))
213
})
214
.map(|opt_res| opt_res.transpose())
215
.collect::<PolarsResult<ListChunked>>()?
216
.with_name(name.clone());
217
if ca.is_empty() {
218
ca = ListChunked::full_null_with_dtype(name, 0, self.output_field.dtype());
219
}
220
221
ac_l.with_update_groups(UpdateGroups::WithSeriesLen);
222
ac_l.with_agg_state(AggState::AggregatedList(ca.into_column()));
223
Ok(ac_l)
224
}
225
}
226
227
impl PhysicalExpr for BinaryExpr {
228
fn as_expression(&self) -> Option<&Expr> {
229
Some(&self.expr)
230
}
231
232
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
233
// Window functions may set a global state that determine their output
234
// state, so we don't let them run in parallel as they race
235
// they also saturate the thread pool by themselves, so that's fine.
236
let has_window = state.has_window();
237
238
let (lhs, rhs);
239
if has_window {
240
let mut state = state.split();
241
state.remove_cache_window_flag();
242
lhs = self.left.evaluate(df, &state)?;
243
rhs = self.right.evaluate(df, &state)?;
244
} else if !self.allow_threading || self.has_literal {
245
// Literals are free, don't pay par cost.
246
lhs = self.left.evaluate(df, state)?;
247
rhs = self.right.evaluate(df, state)?;
248
} else {
249
let (opt_lhs, opt_rhs) = POOL.install(|| {
250
rayon::join(
251
|| self.left.evaluate(df, state),
252
|| self.right.evaluate(df, state),
253
)
254
});
255
(lhs, rhs) = (opt_lhs?, opt_rhs?);
256
};
257
polars_ensure!(
258
lhs.len() == rhs.len() || lhs.len() == 1 || rhs.len() == 1,
259
expr = self.expr,
260
ShapeMismatch: "cannot evaluate two Series of different lengths ({} and {})",
261
lhs.len(), rhs.len(),
262
);
263
apply_operator_owned(lhs, rhs, self.op)
264
}
265
266
#[allow(clippy::ptr_arg)]
267
fn evaluate_on_groups<'a>(
268
&self,
269
df: &DataFrame,
270
groups: &'a GroupPositions,
271
state: &ExecutionState,
272
) -> PolarsResult<AggregationContext<'a>> {
273
let (result_a, result_b) = POOL.install(|| {
274
rayon::join(
275
|| self.left.evaluate_on_groups(df, groups, state),
276
|| self.right.evaluate_on_groups(df, groups, state),
277
)
278
});
279
let mut ac_l = result_a?;
280
let mut ac_r = result_b?;
281
282
// Aggregate NotAggregated into AggregatedList, but only if strictly required AND
283
// when there is no risk of memory explosion.
284
// See ApplyExpr for additional context
285
let mut has_agg_list = false;
286
let mut has_agg_scalar = false;
287
let mut has_not_agg = false;
288
let mut has_not_agg_with_overlapping_groups = false;
289
let mut not_agg_groups_may_diverge = false;
290
291
let mut previous: Option<&AggregationContext<'_>> = None;
292
for ac in [&ac_l, &ac_r] {
293
match ac.state {
294
AggState::AggregatedList(_) => {
295
has_agg_list = true;
296
},
297
AggState::AggregatedScalar(_) => has_agg_scalar = true,
298
AggState::NotAggregated(_) => {
299
has_not_agg = true;
300
if let Some(p) = previous {
301
not_agg_groups_may_diverge |= !p.groups.is_same(&ac.groups)
302
}
303
previous = Some(ac);
304
if ac.groups.is_overlapping() {
305
has_not_agg_with_overlapping_groups = true;
306
}
307
},
308
_ => {},
309
}
310
}
311
312
let all_literal = !(has_agg_list || has_agg_scalar || has_not_agg);
313
let elementwise_must_aggregate =
314
has_not_agg && (has_agg_list || not_agg_groups_may_diverge);
315
let mut aggregated = has_agg_list || has_agg_scalar;
316
317
// Arithmetic on Decimal is fallible
318
let has_decimal_dtype =
319
ac_l.get_values().dtype().is_decimal() || ac_r.get_values().dtype().is_decimal();
320
let is_fallible = has_decimal_dtype && self.op.is_arithmetic();
321
322
// Broadcast in NotAgg or AggList requires group_aware
323
let check_broadcast = [&ac_l, &ac_r].iter().all(|ac| {
324
matches!(
325
ac.agg_state(),
326
AggState::NotAggregated(_) | AggState::AggregatedList(_)
327
)
328
});
329
let has_broadcast = check_broadcast
330
&& ac_l
331
.groups()
332
.iter()
333
.zip(ac_r.groups().iter())
334
.any(|(l, r)| l.len() != r.len() && (l.len() == 1 || r.len() == 1));
335
336
// Dispatch
337
// See ApplyExpr for reference logic, except that we do any required
338
// aggregation inline. All BinaryExprs are elementwise.
339
if all_literal {
340
// Fast path
341
self.apply_all_literal(ac_l, ac_r)
342
} else if has_agg_scalar && (has_agg_list || has_not_agg) {
343
// Not compatible
344
self.apply_group_aware(ac_l, ac_r)
345
} else if elementwise_must_aggregate && has_not_agg_with_overlapping_groups {
346
// Compatible but calling aggregated() is too expensive
347
self.apply_group_aware(ac_l, ac_r)
348
} else if is_fallible
349
&& (!ac_l.groups_cover_all_values() || !ac_r.groups_cover_all_values())
350
{
351
// Fallible expression and there are elements that are masked out.
352
self.apply_group_aware(ac_l, ac_r)
353
} else {
354
if elementwise_must_aggregate {
355
for ac in [&mut ac_l, &mut ac_r] {
356
if matches!(ac.state, AggState::NotAggregated(_)) {
357
ac.aggregated();
358
}
359
}
360
aggregated = true;
361
}
362
if has_broadcast {
363
self.apply_group_aware(ac_l, ac_r)
364
} else {
365
self.apply_elementwise(ac_l, ac_r, aggregated)
366
}
367
}
368
}
369
370
fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
371
Ok(self.output_field.clone())
372
}
373
374
fn is_scalar(&self) -> bool {
375
self.is_scalar
376
}
377
}
378
379