Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/sortby.rs
8422 views
1
use polars_core::POOL;
2
use polars_core::chunked_array::from_iterator_par::ChunkedCollectParIterExt;
3
use polars_core::prelude::*;
4
use polars_utils::idx_vec::IdxVec;
5
use rayon::prelude::*;
6
7
use super::*;
8
use crate::expressions::{
9
AggregationContext, PhysicalExpr, UpdateGroups, map_sorted_indices_to_group_idx,
10
map_sorted_indices_to_group_slice,
11
};
12
13
pub struct SortByExpr {
14
pub(crate) input: Arc<dyn PhysicalExpr>,
15
pub(crate) by: Vec<Arc<dyn PhysicalExpr>>,
16
pub(crate) expr: Expr,
17
pub(crate) sort_options: SortMultipleOptions,
18
}
19
20
impl SortByExpr {
21
pub fn new(
22
input: Arc<dyn PhysicalExpr>,
23
by: Vec<Arc<dyn PhysicalExpr>>,
24
expr: Expr,
25
sort_options: SortMultipleOptions,
26
) -> Self {
27
Self {
28
input,
29
by,
30
expr,
31
sort_options,
32
}
33
}
34
}
35
36
/// Normalises a per-key boolean option list to exactly `by_len` entries.
///
/// * `values` already has `by_len` entries: returned as a copy.
/// * `values` is empty: every key gets `false`.
/// * any other length: the first entry is broadcast to all keys.
fn prepare_bool_vec(values: &[bool], by_len: usize) -> Vec<bool> {
    if values.len() == by_len {
        return values.to_vec();
    }
    // Length mismatch: broadcast the first flag, or `false` when none was given.
    let fill = values.first().copied().unwrap_or(false);
    vec![fill; by_len]
}
46
47
static ERR_MSG: &str = "expressions in 'sort_by' must have matching group lengths";
48
49
fn check_groups(a: &GroupsType, b: &GroupsType) -> PolarsResult<()> {
50
polars_ensure!(a.iter().zip(b.iter()).all(|(a, b)| {
51
a.len() == b.len()
52
}), ShapeMismatch: ERR_MSG);
53
Ok(())
54
}
55
56
pub(super) fn update_groups_sort_by(
57
groups: &GroupsType,
58
sort_by_s: &Series,
59
options: &SortOptions,
60
) -> PolarsResult<GroupsType> {
61
// Will trigger a gather for every group, so rechunk before.
62
let sort_by_s = sort_by_s.rechunk();
63
let groups = POOL.install(|| {
64
groups
65
.par_iter()
66
.map(|indicator| sort_by_groups_single_by(indicator, &sort_by_s, options))
67
.collect::<PolarsResult<_>>()
68
})?;
69
70
Ok(GroupsType::Idx(groups))
71
}
72
73
fn sort_by_groups_single_by(
74
indicator: GroupsIndicator,
75
sort_by_s: &Series,
76
options: &SortOptions,
77
) -> PolarsResult<(IdxSize, IdxVec)> {
78
let options = SortOptions {
79
descending: options.descending,
80
nulls_last: options.nulls_last,
81
// We are already in par iter.
82
multithreaded: false,
83
..Default::default()
84
};
85
let new_idx = match indicator {
86
GroupsIndicator::Idx((_, idx)) => {
87
// SAFETY: group tuples are always in bounds.
88
let group = unsafe { sort_by_s.take_slice_unchecked(idx) };
89
90
let sorted_idx = group.arg_sort(options);
91
map_sorted_indices_to_group_idx(&sorted_idx, idx)
92
},
93
GroupsIndicator::Slice([first, len]) => {
94
let group = sort_by_s.slice(first as i64, len as usize);
95
let sorted_idx = group.arg_sort(options);
96
map_sorted_indices_to_group_slice(&sorted_idx, first)
97
},
98
};
99
100
let first = new_idx.first().unwrap_or(&0);
101
Ok((*first, new_idx))
102
}
103
104
fn sort_by_groups_no_match_single<'a>(
105
mut ac_in: AggregationContext<'a>,
106
mut ac_by: AggregationContext<'a>,
107
descending: bool,
108
expr: &Expr,
109
) -> PolarsResult<AggregationContext<'a>> {
110
let s_in = ac_in.aggregated();
111
let s_by = ac_by.aggregated();
112
let mut s_in = s_in.list().unwrap().clone();
113
let mut s_by = s_by.list().unwrap().clone();
114
115
let dtype = s_in.dtype().clone();
116
let ca: PolarsResult<ListChunked> = POOL.install(|| {
117
s_in.par_iter_indexed()
118
.zip(s_by.par_iter_indexed())
119
.map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) {
120
(Some(s), Some(s_sort_by)) => {
121
polars_ensure!(s.len() == s_sort_by.len(), ComputeError: "series lengths don't match in 'sort_by' expression");
122
let idx = s_sort_by.arg_sort(SortOptions {
123
descending,
124
// We are already in par iter.
125
multithreaded: false,
126
..Default::default()
127
});
128
Ok(Some(unsafe { s.take_unchecked(&idx) }))
129
},
130
_ => Ok(None),
131
})
132
.collect_ca_with_dtype(PlSmallStr::EMPTY, dtype)
133
});
134
let c = ca?.with_name(s_in.name().clone()).into_column();
135
ac_in.with_values(c, true, Some(expr))?;
136
Ok(ac_in)
137
}
138
139
fn sort_by_groups_multiple_by(
140
indicator: GroupsIndicator,
141
sort_by_s: &[Series],
142
descending: &[bool],
143
nulls_last: &[bool],
144
multithreaded: bool,
145
maintain_order: bool,
146
) -> PolarsResult<(IdxSize, IdxVec)> {
147
let new_idx = match indicator {
148
GroupsIndicator::Idx((_first, idx)) => {
149
// SAFETY: group tuples are always in bounds.
150
let groups = sort_by_s
151
.iter()
152
.map(|s| unsafe { s.take_slice_unchecked(idx) })
153
.map(Column::from)
154
.collect::<Vec<_>>();
155
156
let options = SortMultipleOptions {
157
descending: descending.to_owned(),
158
nulls_last: nulls_last.to_owned(),
159
multithreaded,
160
maintain_order,
161
limit: None,
162
};
163
164
let sorted_idx = groups[0]
165
.as_materialized_series()
166
.arg_sort_multiple(&groups[1..], &options)
167
.unwrap();
168
map_sorted_indices_to_group_idx(&sorted_idx, idx)
169
},
170
GroupsIndicator::Slice([first, len]) => {
171
let groups = sort_by_s
172
.iter()
173
.map(|s| s.slice(first as i64, len as usize))
174
.map(Column::from)
175
.collect::<Vec<_>>();
176
177
let options = SortMultipleOptions {
178
descending: descending.to_owned(),
179
nulls_last: nulls_last.to_owned(),
180
multithreaded,
181
maintain_order,
182
limit: None,
183
};
184
let sorted_idx = groups[0]
185
.as_materialized_series()
186
.arg_sort_multiple(&groups[1..], &options)
187
.unwrap();
188
map_sorted_indices_to_group_slice(&sorted_idx, first)
189
},
190
};
191
let first = new_idx
192
.first()
193
.ok_or_else(|| polars_err!(ComputeError: "{ERR_MSG}"))?;
194
195
Ok((*first, new_idx))
196
}
197
198
impl PhysicalExpr for SortByExpr {
199
fn as_expression(&self) -> Option<&Expr> {
200
Some(&self.expr)
201
}
202
203
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
204
let series_f = || self.input.evaluate(df, state);
205
if self.by.is_empty() {
206
// Sorting by 0 columns returns input unchanged.
207
return series_f();
208
}
209
let (series, sorted_idx) = if self.by.len() == 1 {
210
let sorted_idx_f = || {
211
let s_sort_by = self.by[0].evaluate(df, state)?;
212
Ok(s_sort_by.arg_sort(SortOptions::from(&self.sort_options)))
213
};
214
POOL.install(|| rayon::join(series_f, sorted_idx_f))
215
} else {
216
let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len());
217
let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len());
218
219
let sorted_idx_f = || {
220
let mut needs_broadcast = false;
221
let mut broadcast_length = 1;
222
223
let mut s_sort_by = self
224
.by
225
.iter()
226
.enumerate()
227
.map(|(i, e)| {
228
let column = e.evaluate(df, state).map(|c| match c.dtype() {
229
#[cfg(feature = "dtype-categorical")]
230
DataType::Categorical(_, _) | DataType::Enum(_, _) => c,
231
_ => c.to_physical_repr(),
232
})?;
233
234
if column.len() == 1 && broadcast_length != 1 {
235
polars_ensure!(
236
e.is_scalar(),
237
ShapeMismatch: "non-scalar expression produces broadcasting column",
238
);
239
240
return Ok(column.new_from_index(0, broadcast_length));
241
}
242
243
if broadcast_length != column.len() {
244
polars_ensure!(
245
broadcast_length == 1, ShapeMismatch:
246
"`sort_by` produced different length ({}) than earlier Series' length in `by` ({})",
247
broadcast_length, column.len()
248
);
249
250
needs_broadcast |= i > 0;
251
broadcast_length = column.len();
252
}
253
254
Ok(column)
255
})
256
.collect::<PolarsResult<Vec<_>>>()?;
257
258
if needs_broadcast {
259
for c in s_sort_by.iter_mut() {
260
if c.len() != broadcast_length {
261
*c = c.new_from_index(0, broadcast_length);
262
}
263
}
264
}
265
266
let options = self
267
.sort_options
268
.clone()
269
.with_order_descending_multi(descending)
270
.with_nulls_last_multi(nulls_last);
271
272
s_sort_by[0]
273
.as_materialized_series()
274
.arg_sort_multiple(&s_sort_by[1..], &options)
275
};
276
POOL.install(|| rayon::join(series_f, sorted_idx_f))
277
};
278
let (sorted_idx, series) = (sorted_idx?, series?);
279
polars_ensure!(
280
sorted_idx.len() == series.len(),
281
expr = self.expr, ShapeMismatch:
282
"`sort_by` produced different length ({}) than the Series that has to be sorted ({})",
283
sorted_idx.len(), series.len()
284
);
285
286
// SAFETY: sorted index are within bounds.
287
unsafe { Ok(series.take_unchecked(&sorted_idx)) }
288
}
289
290
#[allow(clippy::ptr_arg)]
291
fn evaluate_on_groups<'a>(
292
&self,
293
df: &DataFrame,
294
groups: &'a GroupPositions,
295
state: &ExecutionState,
296
) -> PolarsResult<AggregationContext<'a>> {
297
let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?;
298
let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len());
299
let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len());
300
301
let mut ac_sort_by = self
302
.by
303
.iter()
304
.map(|e| e.evaluate_on_groups(df, groups, state))
305
.collect::<PolarsResult<Vec<_>>>()?;
306
307
assert!(
308
ac_sort_by
309
.iter()
310
.all(|ac_sort_by| ac_sort_by.groups.len() == ac_in.groups.len())
311
);
312
313
// Enable reliable length checks downstream
314
ac_in.set_groups_for_undefined_agg_states();
315
ac_sort_by
316
.iter_mut()
317
.for_each(|ac| ac.set_groups_for_undefined_agg_states());
318
319
// If every input is a LiteralScalar, we return a LiteralScalar.
320
// Otherwise, we convert any LiteralScalar to AggregatedList.
321
let all_literal = matches!(ac_in.state, AggState::LiteralScalar(_))
322
|| ac_sort_by
323
.iter()
324
.all(|ac| matches!(ac.state, AggState::LiteralScalar(_)));
325
326
if all_literal {
327
return Ok(ac_in);
328
} else {
329
if matches!(ac_in.state, AggState::LiteralScalar(_)) {
330
ac_in.aggregated();
331
}
332
for ac in ac_sort_by.iter_mut() {
333
if matches!(ac.state, AggState::LiteralScalar(_)) {
334
ac.aggregated();
335
}
336
}
337
}
338
339
let mut sort_by_s = ac_sort_by
340
.iter()
341
.map(|c| {
342
let c = c.flat_naive();
343
match c.dtype() {
344
#[cfg(feature = "dtype-categorical")]
345
DataType::Categorical(_, _) | DataType::Enum(_, _) => {
346
c.as_materialized_series().clone()
347
},
348
// @scalar-opt
349
// @partition-opt
350
_ => c.to_physical_repr().take_materialized_series(),
351
}
352
})
353
.collect::<Vec<_>>();
354
355
let ordered_by_group_operation = matches!(
356
ac_sort_by[0].update_groups,
357
UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
358
);
359
360
let groups = if self.by.len() == 1 {
361
let mut ac_sort_by = ac_sort_by.pop().unwrap();
362
363
// The groups of the lhs of the expressions do not match the series values,
364
// we must take the slower path.
365
if !matches!(ac_in.update_groups, UpdateGroups::No) {
366
return sort_by_groups_no_match_single(
367
ac_in,
368
ac_sort_by,
369
self.sort_options.descending[0],
370
&self.expr,
371
);
372
};
373
374
let sort_by_s = sort_by_s.pop().unwrap();
375
let groups = ac_sort_by.groups();
376
377
let (check, groups) = POOL.join(
378
|| check_groups(groups, ac_in.groups()),
379
|| {
380
update_groups_sort_by(
381
groups,
382
&sort_by_s,
383
&SortOptions {
384
descending: descending[0],
385
nulls_last: nulls_last[0],
386
..Default::default()
387
},
388
)
389
},
390
);
391
check?;
392
393
groups?
394
} else {
395
let groups = ac_sort_by[0].groups();
396
397
let groups = POOL.install(|| {
398
groups
399
.par_iter()
400
.map(|indicator| {
401
sort_by_groups_multiple_by(
402
indicator,
403
&sort_by_s,
404
&descending,
405
&nulls_last,
406
self.sort_options.multithreaded,
407
self.sort_options.maintain_order,
408
)
409
})
410
.collect::<PolarsResult<_>>()
411
});
412
GroupsType::Idx(groups?)
413
};
414
415
// If the rhs is already aggregated once, it is reordered by the
416
// group_by operation - we must ensure that we are as well.
417
if ordered_by_group_operation {
418
let s = ac_in.aggregated();
419
ac_in.with_values(
420
s.explode(ExplodeOptions {
421
empty_as_null: true,
422
keep_nulls: true,
423
})
424
.unwrap(),
425
false,
426
None,
427
)?;
428
}
429
430
ac_in.with_groups(groups.into_sliceable());
431
Ok(ac_in)
432
}
433
434
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
435
self.input.to_field(input_schema)
436
}
437
438
fn is_scalar(&self) -> bool {
439
false
440
}
441
}
442
443