Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/ternary.rs
6940 views
1
use polars_core::POOL;
2
use polars_core::prelude::*;
3
use polars_plan::prelude::*;
4
5
use super::*;
6
use crate::expressions::{AggregationContext, PhysicalExpr};
7
8
/// Physical expression that selects, row by row, between the `truthy` and
/// `falsy` branches based on the boolean `predicate`.
///
/// On flat data the result is `truthy.zip_with(mask, falsy)`; see the
/// `PhysicalExpr` impl for the group-wise strategies.
pub struct TernaryExpr {
    /// Condition; must evaluate to a boolean column.
    predicate: Arc<dyn PhysicalExpr>,
    /// Branch whose values are taken where the predicate selects it.
    truthy: Arc<dyn PhysicalExpr>,
    /// Branch whose values are taken for the remaining rows.
    falsy: Arc<dyn PhysicalExpr>,
    /// The logical expression this physical expression was created from
    /// (returned by `as_expression`).
    expr: Expr,
    /// Whether the inputs are evaluated concurrently on the global thread pool.
    // Can be expensive on small data to run literals in parallel.
    run_par: bool,
    /// Value reported by `is_scalar`.
    returns_scalar: bool,
}
17
18
impl TernaryExpr {
19
pub fn new(
20
predicate: Arc<dyn PhysicalExpr>,
21
truthy: Arc<dyn PhysicalExpr>,
22
falsy: Arc<dyn PhysicalExpr>,
23
expr: Expr,
24
run_par: bool,
25
returns_scalar: bool,
26
) -> Self {
27
Self {
28
predicate,
29
truthy,
30
falsy,
31
expr,
32
run_par,
33
returns_scalar,
34
}
35
}
36
}
37
38
fn finish_as_iters<'a>(
39
mut ac_truthy: AggregationContext<'a>,
40
mut ac_falsy: AggregationContext<'a>,
41
mut ac_mask: AggregationContext<'a>,
42
) -> PolarsResult<AggregationContext<'a>> {
43
let ca = ac_truthy
44
.iter_groups(false)
45
.zip(ac_falsy.iter_groups(false))
46
.zip(ac_mask.iter_groups(false))
47
.map(|((truthy, falsy), mask)| {
48
match (truthy, falsy, mask) {
49
(Some(truthy), Some(falsy), Some(mask)) => Some(
50
truthy
51
.as_ref()
52
.zip_with(mask.as_ref().bool()?, falsy.as_ref()),
53
),
54
_ => None,
55
}
56
.transpose()
57
})
58
.collect::<PolarsResult<ListChunked>>()?
59
.with_name(ac_truthy.get_values().name().clone());
60
61
// Aggregation leaves only a single chunk.
62
let arr = ca.downcast_iter().next().unwrap();
63
let list_vals_len = arr.values().len();
64
65
let mut out = ca.into_column();
66
if ac_truthy.arity_should_explode() && ac_falsy.arity_should_explode() && ac_mask.arity_should_explode() &&
67
// Exploded list should be equal to groups length.
68
list_vals_len == ac_truthy.groups.len()
69
{
70
out = out.explode(false)?
71
}
72
73
ac_truthy.with_agg_state(AggState::AggregatedList(out));
74
ac_truthy.with_update_groups(UpdateGroups::WithSeriesLen);
75
76
Ok(ac_truthy)
77
}
78
79
impl PhysicalExpr for TernaryExpr {
80
fn as_expression(&self) -> Option<&Expr> {
81
Some(&self.expr)
82
}
83
84
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
85
let mut state = state.split();
86
// Don't cache window functions as they run in parallel.
87
state.remove_cache_window_flag();
88
let mask_series = self.predicate.evaluate(df, &state)?;
89
let mask = mask_series.bool()?.clone();
90
91
let op_truthy = || self.truthy.evaluate(df, &state);
92
let op_falsy = || self.falsy.evaluate(df, &state);
93
let (truthy, falsy) = if self.run_par {
94
POOL.install(|| rayon::join(op_truthy, op_falsy))
95
} else {
96
(op_truthy(), op_falsy())
97
};
98
let truthy = truthy?;
99
let falsy = falsy?;
100
101
truthy.zip_with(&mask, &falsy)
102
}
103
104
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
105
self.truthy.to_field(input_schema)
106
}
107
108
#[allow(clippy::ptr_arg)]
109
fn evaluate_on_groups<'a>(
110
&self,
111
df: &DataFrame,
112
groups: &'a GroupPositions,
113
state: &ExecutionState,
114
) -> PolarsResult<AggregationContext<'a>> {
115
let op_mask = || self.predicate.evaluate_on_groups(df, groups, state);
116
let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state);
117
let op_falsy = || self.falsy.evaluate_on_groups(df, groups, state);
118
let (ac_mask, (ac_truthy, ac_falsy)) = if self.run_par {
119
POOL.install(|| rayon::join(op_mask, || rayon::join(op_truthy, op_falsy)))
120
} else {
121
(op_mask(), (op_truthy(), op_falsy()))
122
};
123
124
let mut ac_mask = ac_mask?;
125
let mut ac_truthy = ac_truthy?;
126
let mut ac_falsy = ac_falsy?;
127
128
use AggState::*;
129
130
// Check if there are any:
131
// - non-unit literals
132
// - AggregatedScalar or AggregatedList
133
let mut has_non_unit_literal = false;
134
let mut has_aggregated = false;
135
// If the length has changed then we must not apply on the flat values
136
// as ternary broadcasting is length-sensitive.
137
let mut non_aggregated_len_modified = false;
138
139
for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
140
match ac.agg_state() {
141
LiteralScalar(s) => {
142
has_non_unit_literal = s.len() != 1;
143
144
if has_non_unit_literal {
145
break;
146
}
147
},
148
NotAggregated(_) => {
149
non_aggregated_len_modified |= !ac.original_len;
150
},
151
AggregatedScalar(_) | AggregatedList(_) => {
152
has_aggregated = true;
153
},
154
}
155
}
156
157
if has_non_unit_literal {
158
// finish_as_iters for non-unit literals to avoid materializing the
159
// literal inputs per-group.
160
if state.verbose() {
161
eprintln!("ternary agg: finish as iters due to non-unit literal")
162
}
163
return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
164
}
165
166
if !has_aggregated && !non_aggregated_len_modified {
167
// Everything is flat (either NotAggregated or a unit literal).
168
if state.verbose() {
169
eprintln!("ternary agg: finish all not-aggregated or unit literal");
170
}
171
172
let out = ac_truthy
173
.get_values()
174
.zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;
175
176
for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
177
if matches!(ac.agg_state(), NotAggregated(_)) {
178
let ac_target = ac;
179
180
return Ok(AggregationContext {
181
state: NotAggregated(out),
182
groups: ac_target.groups.clone(),
183
update_groups: ac_target.update_groups,
184
original_len: ac_target.original_len,
185
});
186
}
187
}
188
189
ac_truthy.with_agg_state(LiteralScalar(out));
190
191
return Ok(ac_truthy);
192
}
193
194
for ac in [&mut ac_mask, &mut ac_truthy, &mut ac_falsy].into_iter() {
195
if matches!(ac.agg_state(), NotAggregated(_)) {
196
let _ = ac.aggregated();
197
}
198
}
199
200
// At this point the input agg states are one of the following:
201
// * `Literal` where `s.len() == 1`
202
// * `AggregatedList`
203
// * `AggregatedScalar`
204
205
let mut non_literal_acs = Vec::<&AggregationContext>::with_capacity(3);
206
207
// non_literal_acs will have at least 1 item because has_aggregated was
208
// true from above.
209
for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
210
if !matches!(ac.agg_state(), LiteralScalar(_)) {
211
non_literal_acs.push(ac);
212
}
213
}
214
215
for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {
216
if std::mem::discriminant(ac_l.agg_state()) != std::mem::discriminant(ac_r.agg_state())
217
{
218
// Mix of AggregatedScalar and AggregatedList is done per group,
219
// as every row of the AggregatedScalar must be broadcasted to a
220
// list of the same length as the corresponding AggregatedList
221
// row.
222
if state.verbose() {
223
eprintln!(
224
"ternary agg: finish as iters due to mix of AggregatedScalar and AggregatedList"
225
)
226
}
227
return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
228
}
229
}
230
231
// At this point, the possible combinations are:
232
// * mix of unit literals and AggregatedScalar
233
// * `zip_with` can be called directly with the series
234
// * mix of unit literals and AggregatedList
235
// * `zip_with` can be called with the flat values after the offsets
236
// have been checked for alignment
237
let ac_target = non_literal_acs.first().unwrap();
238
239
let agg_state_out = match ac_target.agg_state() {
240
AggregatedList(_) => {
241
// Ternary can be applied directly on the flattened series,
242
// given that their offsets have been checked to be equal.
243
if state.verbose() {
244
eprintln!("ternary agg: finish AggregatedList")
245
}
246
247
for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {
248
match (ac_l.agg_state(), ac_r.agg_state()) {
249
(AggregatedList(s_l), AggregatedList(s_r)) => {
250
let check = s_l.list().unwrap().offsets()?.as_slice()
251
== s_r.list().unwrap().offsets()?.as_slice();
252
253
polars_ensure!(
254
check,
255
ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation"
256
);
257
},
258
_ => unreachable!(),
259
}
260
}
261
262
let truthy = if let AggregatedList(s) = ac_truthy.agg_state() {
263
s.list().unwrap().get_inner().into_column()
264
} else {
265
ac_truthy.get_values().clone()
266
};
267
268
let falsy = if let AggregatedList(s) = ac_falsy.agg_state() {
269
s.list().unwrap().get_inner().into_column()
270
} else {
271
ac_falsy.get_values().clone()
272
};
273
274
let mask = if let AggregatedList(s) = ac_mask.agg_state() {
275
s.list().unwrap().get_inner().into_column()
276
} else {
277
ac_mask.get_values().clone()
278
};
279
280
let out = truthy.zip_with(mask.bool()?, &falsy)?;
281
282
// The output series is guaranteed to be aligned with expected
283
// offsets buffer of the result, so we construct the result
284
// ListChunked directly from the 2.
285
let out = out.rechunk();
286
// @scalar-opt
287
// @partition-opt
288
let values = out.as_materialized_series().array_ref(0);
289
let offsets = ac_target.get_values().list().unwrap().offsets()?;
290
let inner_type = out.dtype();
291
let dtype = LargeListArray::default_datatype(values.dtype().clone());
292
293
// SAFETY: offsets are correct.
294
let out = LargeListArray::new(dtype, offsets, values.clone(), None);
295
296
let mut out = ListChunked::with_chunk(truthy.name().clone(), out);
297
unsafe { out.to_logical(inner_type.clone()) };
298
299
if ac_target.get_values().list().unwrap()._can_fast_explode() {
300
out.set_fast_explode();
301
};
302
303
let out = out.into_column();
304
305
AggregatedList(out)
306
},
307
AggregatedScalar(_) => {
308
if state.verbose() {
309
eprintln!("ternary agg: finish AggregatedScalar")
310
}
311
312
let out = ac_truthy
313
.get_values()
314
.zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;
315
AggregatedScalar(out)
316
},
317
_ => {
318
unreachable!()
319
},
320
};
321
322
Ok(AggregationContext {
323
state: agg_state_out,
324
groups: ac_target.groups.clone(),
325
update_groups: ac_target.update_groups,
326
original_len: ac_target.original_len,
327
})
328
}
329
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
330
Some(self)
331
}
332
333
fn is_scalar(&self) -> bool {
334
self.returns_scalar
335
}
336
}
337
338
impl PartitionedAggregation for TernaryExpr {
339
fn evaluate_partitioned(
340
&self,
341
df: &DataFrame,
342
groups: &GroupPositions,
343
state: &ExecutionState,
344
) -> PolarsResult<Column> {
345
let truthy = self.truthy.as_partitioned_aggregator().unwrap();
346
let falsy = self.falsy.as_partitioned_aggregator().unwrap();
347
let mask = self.predicate.as_partitioned_aggregator().unwrap();
348
349
let truthy = truthy.evaluate_partitioned(df, groups, state)?;
350
let falsy = falsy.evaluate_partitioned(df, groups, state)?;
351
let mask = mask.evaluate_partitioned(df, groups, state)?;
352
let mask = mask.bool()?.clone();
353
354
truthy.zip_with(&mask, &falsy)
355
}
356
357
fn finalize(
358
&self,
359
partitioned: Column,
360
_groups: &GroupPositions,
361
_state: &ExecutionState,
362
) -> PolarsResult<Column> {
363
Ok(partitioned)
364
}
365
}
366
367