Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/ternary.rs
8422 views
1
use polars_core::POOL;
2
use polars_core::prelude::*;
3
use polars_plan::prelude::*;
4
5
use super::*;
6
use crate::expressions::{AggregationContext, PhysicalExpr};
7
8
pub struct TernaryExpr {
9
predicate: Arc<dyn PhysicalExpr>,
10
truthy: Arc<dyn PhysicalExpr>,
11
falsy: Arc<dyn PhysicalExpr>,
12
expr: Expr,
13
// Can be expensive on small data to run literals in parallel.
14
run_par: bool,
15
returns_scalar: bool,
16
}
17
18
impl TernaryExpr {
19
pub fn new(
20
predicate: Arc<dyn PhysicalExpr>,
21
truthy: Arc<dyn PhysicalExpr>,
22
falsy: Arc<dyn PhysicalExpr>,
23
expr: Expr,
24
run_par: bool,
25
returns_scalar: bool,
26
) -> Self {
27
Self {
28
predicate,
29
truthy,
30
falsy,
31
expr,
32
run_par,
33
returns_scalar,
34
}
35
}
36
}
37
38
fn finish_as_iters<'a>(
39
mut ac_truthy: AggregationContext<'a>,
40
mut ac_falsy: AggregationContext<'a>,
41
mut ac_mask: AggregationContext<'a>,
42
) -> PolarsResult<AggregationContext<'a>> {
43
let ca = ac_truthy
44
.iter_groups(false)
45
.zip(ac_falsy.iter_groups(false))
46
.zip(ac_mask.iter_groups(false))
47
.map(|((truthy, falsy), mask)| {
48
match (truthy, falsy, mask) {
49
(Some(truthy), Some(falsy), Some(mask)) => Some(
50
truthy
51
.as_ref()
52
.zip_with(mask.as_ref().bool()?, falsy.as_ref()),
53
),
54
_ => None,
55
}
56
.transpose()
57
})
58
.collect::<PolarsResult<ListChunked>>()?
59
.with_name(ac_truthy.get_values().name().clone());
60
61
// Aggregation leaves only a single chunk.
62
let arr = ca.downcast_iter().next().unwrap();
63
let list_vals_len = arr.values().len();
64
65
let mut out = ca.into_column();
66
if ac_truthy.arity_should_explode() && ac_falsy.arity_should_explode() && ac_mask.arity_should_explode() &&
67
// Exploded list should be equal to groups length.
68
list_vals_len == ac_truthy.groups.len()
69
{
70
out = out.explode(ExplodeOptions {
71
empty_as_null: true,
72
keep_nulls: true,
73
})?
74
}
75
76
ac_truthy.with_agg_state(AggState::AggregatedList(out));
77
ac_truthy.with_update_groups(UpdateGroups::WithSeriesLen);
78
79
Ok(ac_truthy)
80
}
81
82
impl PhysicalExpr for TernaryExpr {
83
fn as_expression(&self) -> Option<&Expr> {
84
Some(&self.expr)
85
}
86
87
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
88
let mut state = state.split();
89
// Don't cache window functions as they run in parallel.
90
state.remove_cache_window_flag();
91
let mask_series = self.predicate.evaluate(df, &state)?;
92
let mask = mask_series.bool()?.clone();
93
94
let op_truthy = || self.truthy.evaluate(df, &state);
95
let op_falsy = || self.falsy.evaluate(df, &state);
96
let (truthy, falsy) = if self.run_par {
97
POOL.install(|| rayon::join(op_truthy, op_falsy))
98
} else {
99
(op_truthy(), op_falsy())
100
};
101
let truthy = truthy?;
102
let falsy = falsy?;
103
104
truthy.zip_with(&mask, &falsy)
105
}
106
107
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
108
self.truthy.to_field(input_schema)
109
}
110
111
#[allow(clippy::ptr_arg)]
112
fn evaluate_on_groups<'a>(
113
&self,
114
df: &DataFrame,
115
groups: &'a GroupPositions,
116
state: &ExecutionState,
117
) -> PolarsResult<AggregationContext<'a>> {
118
let op_mask = || self.predicate.evaluate_on_groups(df, groups, state);
119
let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state);
120
let op_falsy = || self.falsy.evaluate_on_groups(df, groups, state);
121
let (ac_mask, (ac_truthy, ac_falsy)) = if self.run_par {
122
POOL.install(|| rayon::join(op_mask, || rayon::join(op_truthy, op_falsy)))
123
} else {
124
(op_mask(), (op_truthy(), op_falsy()))
125
};
126
127
let mut ac_mask = ac_mask?;
128
let mut ac_truthy = ac_truthy?;
129
let mut ac_falsy = ac_falsy?;
130
131
use AggState::*;
132
133
// Check if there are any:
134
// - non-unit literals
135
// - AggregatedScalar or AggregatedList
136
let mut has_non_unit_literal = false;
137
let mut has_aggregated = false;
138
// If the length has changed then we must not apply on the flat values
139
// as ternary broadcasting is length-sensitive.
140
let mut non_aggregated_len_modified = false;
141
142
for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
143
match ac.agg_state() {
144
LiteralScalar(s) => {
145
has_non_unit_literal = s.len() != 1;
146
147
if has_non_unit_literal {
148
break;
149
}
150
},
151
NotAggregated(_) => {
152
non_aggregated_len_modified |= !ac.original_len;
153
},
154
AggregatedScalar(_) | AggregatedList(_) => {
155
has_aggregated = true;
156
},
157
}
158
}
159
160
if has_non_unit_literal {
161
// finish_as_iters for non-unit literals to avoid materializing the
162
// literal inputs per-group.
163
if state.verbose() {
164
eprintln!("ternary agg: finish as iters due to non-unit literal")
165
}
166
return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
167
}
168
169
if !has_aggregated && !non_aggregated_len_modified {
170
// Everything is flat (either NotAggregated or a unit literal).
171
if state.verbose() {
172
eprintln!("ternary agg: finish all not-aggregated or unit literal");
173
}
174
175
let out = ac_truthy
176
.get_values()
177
.zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;
178
179
for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
180
if matches!(ac.agg_state(), NotAggregated(_)) {
181
let ac_target = ac;
182
183
return Ok(AggregationContext {
184
state: NotAggregated(out),
185
groups: ac_target.groups.clone(),
186
update_groups: ac_target.update_groups,
187
original_len: ac_target.original_len,
188
});
189
}
190
}
191
192
ac_truthy.with_agg_state(LiteralScalar(out));
193
194
return Ok(ac_truthy);
195
}
196
197
for ac in [&mut ac_mask, &mut ac_truthy, &mut ac_falsy].into_iter() {
198
if matches!(ac.agg_state(), NotAggregated(_)) {
199
let _ = ac.aggregated();
200
}
201
}
202
203
// At this point the input agg states are one of the following:
204
// * `Literal` where `s.len() == 1`
205
// * `AggregatedList`
206
// * `AggregatedScalar`
207
208
let mut non_literal_acs = Vec::<&AggregationContext>::with_capacity(3);
209
210
// non_literal_acs will have at least 1 item because has_aggregated was
211
// true from above.
212
for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
213
if !matches!(ac.agg_state(), LiteralScalar(_)) {
214
non_literal_acs.push(ac);
215
}
216
}
217
218
for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {
219
if std::mem::discriminant(ac_l.agg_state()) != std::mem::discriminant(ac_r.agg_state())
220
{
221
// Mix of AggregatedScalar and AggregatedList is done per group,
222
// as every row of the AggregatedScalar must be broadcasted to a
223
// list of the same length as the corresponding AggregatedList
224
// row.
225
if state.verbose() {
226
eprintln!(
227
"ternary agg: finish as iters due to mix of AggregatedScalar and AggregatedList"
228
)
229
}
230
return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
231
}
232
}
233
234
// At this point, the possible combinations are:
235
// * mix of unit literals and AggregatedScalar
236
// * `zip_with` can be called directly with the series
237
// * mix of unit literals and AggregatedList
238
// * `zip_with` can be called with the flat values after the offsets
239
// have been checked for alignment
240
let ac_target = non_literal_acs.first().unwrap();
241
242
let agg_state_out = match ac_target.agg_state() {
243
AggregatedList(_) => {
244
// Ternary can be applied directly on the flattened series,
245
// given that their offsets have been checked to be equal.
246
if state.verbose() {
247
eprintln!("ternary agg: finish AggregatedList")
248
}
249
250
for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {
251
match (ac_l.agg_state(), ac_r.agg_state()) {
252
(AggregatedList(s_l), AggregatedList(s_r)) => {
253
let check = s_l.list().unwrap().offsets()?.as_slice()
254
== s_r.list().unwrap().offsets()?.as_slice();
255
256
polars_ensure!(
257
check,
258
ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation"
259
);
260
},
261
_ => unreachable!(),
262
}
263
}
264
265
let truthy = if let AggregatedList(s) = ac_truthy.agg_state() {
266
s.list().unwrap().get_inner().into_column()
267
} else {
268
ac_truthy.get_values().clone()
269
};
270
271
let falsy = if let AggregatedList(s) = ac_falsy.agg_state() {
272
s.list().unwrap().get_inner().into_column()
273
} else {
274
ac_falsy.get_values().clone()
275
};
276
277
let mask = if let AggregatedList(s) = ac_mask.agg_state() {
278
s.list().unwrap().get_inner().into_column()
279
} else {
280
ac_mask.get_values().clone()
281
};
282
283
let out = truthy.zip_with(mask.bool()?, &falsy)?;
284
285
// The output series is guaranteed to be aligned with expected
286
// offsets buffer of the result, so we construct the result
287
// ListChunked directly from the 2.
288
let out = out.rechunk();
289
// @scalar-opt
290
// @partition-opt
291
let values = out.as_materialized_series().array_ref(0);
292
let offsets = ac_target.get_values().list().unwrap().offsets()?;
293
let inner_type = out.dtype();
294
let dtype = LargeListArray::default_datatype(values.dtype().clone());
295
296
// SAFETY: offsets are correct.
297
let out = LargeListArray::new(dtype, offsets, values.clone(), None);
298
299
let mut out = ListChunked::with_chunk(truthy.name().clone(), out);
300
unsafe { out.to_logical(inner_type.clone()) };
301
302
if ac_target.get_values().list().unwrap()._can_fast_explode() {
303
out.set_fast_explode();
304
};
305
306
let out = out.into_column();
307
308
AggregatedList(out)
309
},
310
AggregatedScalar(_) => {
311
if state.verbose() {
312
eprintln!("ternary agg: finish AggregatedScalar")
313
}
314
315
let out = ac_truthy
316
.get_values()
317
.zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;
318
AggregatedScalar(out)
319
},
320
_ => {
321
unreachable!()
322
},
323
};
324
325
Ok(AggregationContext {
326
state: agg_state_out,
327
groups: ac_target.groups.clone(),
328
update_groups: ac_target.update_groups,
329
original_len: ac_target.original_len,
330
})
331
}
332
333
fn is_scalar(&self) -> bool {
334
self.returns_scalar
335
}
336
}
337
338