Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/sortby.rs
6940 views
1
use polars_core::POOL;
2
use polars_core::chunked_array::from_iterator_par::ChunkedCollectParIterExt;
3
use polars_core::prelude::*;
4
use polars_utils::idx_vec::IdxVec;
5
use rayon::prelude::*;
6
7
use super::*;
8
use crate::expressions::{
9
AggregationContext, PhysicalExpr, UpdateGroups, map_sorted_indices_to_group_idx,
10
map_sorted_indices_to_group_slice,
11
};
12
13
pub struct SortByExpr {
14
pub(crate) input: Arc<dyn PhysicalExpr>,
15
pub(crate) by: Vec<Arc<dyn PhysicalExpr>>,
16
pub(crate) expr: Expr,
17
pub(crate) sort_options: SortMultipleOptions,
18
}
19
20
impl SortByExpr {
21
pub fn new(
22
input: Arc<dyn PhysicalExpr>,
23
by: Vec<Arc<dyn PhysicalExpr>>,
24
expr: Expr,
25
sort_options: SortMultipleOptions,
26
) -> Self {
27
Self {
28
input,
29
by,
30
expr,
31
sort_options,
32
}
33
}
34
}
35
36
/// Normalises a per-key boolean option (e.g. `descending`) to exactly `by_len` entries.
///
/// * When `values` already has `by_len` entries it is copied unchanged.
/// * When `values` is empty, every entry becomes `false`.
/// * Otherwise the first entry of `values` is repeated `by_len` times.
fn prepare_bool_vec(values: &[bool], by_len: usize) -> Vec<bool> {
    if values.len() == by_len {
        return values.to_vec();
    }
    // Length mismatch: broadcast the first flag, or `false` when none was given.
    let fill = values.first().copied().unwrap_or(false);
    vec![fill; by_len]
}
46
47
/// Shared `ComputeError` message for `sort_by` group problems: raised when the
/// groups of two expressions differ in length, and when a sorted group turns
/// out to have no first index.
static ERR_MSG: &str = "expressions in 'sort_by' must have matching group lengths";
48
49
/// Verifies that the groups of `a` and `b` have pairwise equal lengths.
///
/// Groups are compared position by position via `zip`, so only as many pairs
/// as the shorter of the two inputs provides are inspected.
///
/// # Errors
/// Returns a `ComputeError` carrying `ERR_MSG` when any pair differs in length.
fn check_groups(a: &GroupsType, b: &GroupsType) -> PolarsResult<()> {
    let lengths_match = a
        .iter()
        .zip(b.iter())
        .all(|(lhs, rhs)| lhs.len() == rhs.len());
    polars_ensure!(lengths_match, ComputeError: ERR_MSG);
    Ok(())
}
55
56
pub(super) fn update_groups_sort_by(
57
groups: &GroupsType,
58
sort_by_s: &Series,
59
options: &SortOptions,
60
) -> PolarsResult<GroupsType> {
61
// Will trigger a gather for every group, so rechunk before.
62
let sort_by_s = sort_by_s.rechunk();
63
let groups = POOL.install(|| {
64
groups
65
.par_iter()
66
.map(|indicator| sort_by_groups_single_by(indicator, &sort_by_s, options))
67
.collect::<PolarsResult<_>>()
68
})?;
69
70
Ok(GroupsType::Idx(groups))
71
}
72
73
fn sort_by_groups_single_by(
74
indicator: GroupsIndicator,
75
sort_by_s: &Series,
76
options: &SortOptions,
77
) -> PolarsResult<(IdxSize, IdxVec)> {
78
let options = SortOptions {
79
descending: options.descending,
80
nulls_last: options.nulls_last,
81
// We are already in par iter.
82
multithreaded: false,
83
..Default::default()
84
};
85
let new_idx = match indicator {
86
GroupsIndicator::Idx((_, idx)) => {
87
// SAFETY: group tuples are always in bounds.
88
let group = unsafe { sort_by_s.take_slice_unchecked(idx) };
89
90
let sorted_idx = group.arg_sort(options);
91
map_sorted_indices_to_group_idx(&sorted_idx, idx)
92
},
93
GroupsIndicator::Slice([first, len]) => {
94
let group = sort_by_s.slice(first as i64, len as usize);
95
let sorted_idx = group.arg_sort(options);
96
map_sorted_indices_to_group_slice(&sorted_idx, first)
97
},
98
};
99
let first = new_idx
100
.first()
101
.ok_or_else(|| polars_err!(ComputeError: "{}", ERR_MSG))?;
102
103
Ok((*first, new_idx))
104
}
105
106
fn sort_by_groups_no_match_single<'a>(
107
mut ac_in: AggregationContext<'a>,
108
mut ac_by: AggregationContext<'a>,
109
descending: bool,
110
expr: &Expr,
111
) -> PolarsResult<AggregationContext<'a>> {
112
let s_in = ac_in.aggregated();
113
let s_by = ac_by.aggregated();
114
let mut s_in = s_in.list().unwrap().clone();
115
let mut s_by = s_by.list().unwrap().clone();
116
117
let dtype = s_in.dtype().clone();
118
let ca: PolarsResult<ListChunked> = POOL.install(|| {
119
s_in.par_iter_indexed()
120
.zip(s_by.par_iter_indexed())
121
.map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) {
122
(Some(s), Some(s_sort_by)) => {
123
polars_ensure!(s.len() == s_sort_by.len(), ComputeError: "series lengths don't match in 'sort_by' expression");
124
let idx = s_sort_by.arg_sort(SortOptions {
125
descending,
126
// We are already in par iter.
127
multithreaded: false,
128
..Default::default()
129
});
130
Ok(Some(unsafe { s.take_unchecked(&idx) }))
131
},
132
_ => Ok(None),
133
})
134
.collect_ca_with_dtype(PlSmallStr::EMPTY, dtype)
135
});
136
let c = ca?.with_name(s_in.name().clone()).into_column();
137
ac_in.with_values(c, true, Some(expr))?;
138
Ok(ac_in)
139
}
140
141
fn sort_by_groups_multiple_by(
142
indicator: GroupsIndicator,
143
sort_by_s: &[Series],
144
descending: &[bool],
145
nulls_last: &[bool],
146
multithreaded: bool,
147
maintain_order: bool,
148
) -> PolarsResult<(IdxSize, IdxVec)> {
149
let new_idx = match indicator {
150
GroupsIndicator::Idx((_first, idx)) => {
151
// SAFETY: group tuples are always in bounds.
152
let groups = sort_by_s
153
.iter()
154
.map(|s| unsafe { s.take_slice_unchecked(idx) })
155
.map(Column::from)
156
.collect::<Vec<_>>();
157
158
let options = SortMultipleOptions {
159
descending: descending.to_owned(),
160
nulls_last: nulls_last.to_owned(),
161
multithreaded,
162
maintain_order,
163
limit: None,
164
};
165
166
let sorted_idx = groups[0]
167
.as_materialized_series()
168
.arg_sort_multiple(&groups[1..], &options)
169
.unwrap();
170
map_sorted_indices_to_group_idx(&sorted_idx, idx)
171
},
172
GroupsIndicator::Slice([first, len]) => {
173
let groups = sort_by_s
174
.iter()
175
.map(|s| s.slice(first as i64, len as usize))
176
.map(Column::from)
177
.collect::<Vec<_>>();
178
179
let options = SortMultipleOptions {
180
descending: descending.to_owned(),
181
nulls_last: nulls_last.to_owned(),
182
multithreaded,
183
maintain_order,
184
limit: None,
185
};
186
let sorted_idx = groups[0]
187
.as_materialized_series()
188
.arg_sort_multiple(&groups[1..], &options)
189
.unwrap();
190
map_sorted_indices_to_group_slice(&sorted_idx, first)
191
},
192
};
193
let first = new_idx
194
.first()
195
.ok_or_else(|| polars_err!(ComputeError: "{}", ERR_MSG))?;
196
197
Ok((*first, new_idx))
198
}
199
200
impl PhysicalExpr for SortByExpr {
201
fn as_expression(&self) -> Option<&Expr> {
202
Some(&self.expr)
203
}
204
205
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
206
let series_f = || self.input.evaluate(df, state);
207
if self.by.is_empty() {
208
// Sorting by 0 columns returns input unchanged.
209
return series_f();
210
}
211
let (series, sorted_idx) = if self.by.len() == 1 {
212
let sorted_idx_f = || {
213
let s_sort_by = self.by[0].evaluate(df, state)?;
214
Ok(s_sort_by.arg_sort(SortOptions::from(&self.sort_options)))
215
};
216
POOL.install(|| rayon::join(series_f, sorted_idx_f))
217
} else {
218
let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len());
219
let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len());
220
221
let sorted_idx_f = || {
222
let mut needs_broadcast = false;
223
let mut broadcast_length = 1;
224
225
let mut s_sort_by = self
226
.by
227
.iter()
228
.enumerate()
229
.map(|(i, e)| {
230
let column = e.evaluate(df, state).map(|c| match c.dtype() {
231
#[cfg(feature = "dtype-categorical")]
232
DataType::Categorical(_, _) | DataType::Enum(_, _) => c,
233
_ => c.to_physical_repr(),
234
})?;
235
236
if column.len() == 1 && broadcast_length != 1 {
237
polars_ensure!(
238
e.is_scalar(),
239
ShapeMismatch: "non-scalar expression produces broadcasting column",
240
);
241
242
return Ok(column.new_from_index(0, broadcast_length));
243
}
244
245
if broadcast_length != column.len() {
246
polars_ensure!(
247
broadcast_length == 1, ShapeMismatch:
248
"`sort_by` produced different length ({}) than earlier Series' length in `by` ({})",
249
broadcast_length, column.len()
250
);
251
252
needs_broadcast |= i > 0;
253
broadcast_length = column.len();
254
}
255
256
Ok(column)
257
})
258
.collect::<PolarsResult<Vec<_>>>()?;
259
260
if needs_broadcast {
261
for c in s_sort_by.iter_mut() {
262
if c.len() != broadcast_length {
263
*c = c.new_from_index(0, broadcast_length);
264
}
265
}
266
}
267
268
let options = self
269
.sort_options
270
.clone()
271
.with_order_descending_multi(descending)
272
.with_nulls_last_multi(nulls_last);
273
274
s_sort_by[0]
275
.as_materialized_series()
276
.arg_sort_multiple(&s_sort_by[1..], &options)
277
};
278
POOL.install(|| rayon::join(series_f, sorted_idx_f))
279
};
280
let (sorted_idx, series) = (sorted_idx?, series?);
281
polars_ensure!(
282
sorted_idx.len() == series.len(),
283
expr = self.expr, ShapeMismatch:
284
"`sort_by` produced different length ({}) than the Series that has to be sorted ({})",
285
sorted_idx.len(), series.len()
286
);
287
288
// SAFETY: sorted index are within bounds.
289
unsafe { Ok(series.take_unchecked(&sorted_idx)) }
290
}
291
292
#[allow(clippy::ptr_arg)]
293
fn evaluate_on_groups<'a>(
294
&self,
295
df: &DataFrame,
296
groups: &'a GroupPositions,
297
state: &ExecutionState,
298
) -> PolarsResult<AggregationContext<'a>> {
299
let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?;
300
let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len());
301
let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len());
302
303
let mut ac_sort_by = self
304
.by
305
.iter()
306
.map(|e| e.evaluate_on_groups(df, groups, state))
307
.collect::<PolarsResult<Vec<_>>>()?;
308
309
assert!(
310
ac_sort_by
311
.iter()
312
.all(|ac_sort_by| ac_sort_by.groups.len() == ac_in.groups.len())
313
);
314
315
// If every input is a LiteralScalar, we return a LiteralScalar.
316
// Otherwise, we convert any LiteralScalar to AggregatedList.
317
let all_literal = matches!(ac_in.state, AggState::LiteralScalar(_))
318
|| ac_sort_by
319
.iter()
320
.all(|ac| matches!(ac.state, AggState::LiteralScalar(_)));
321
322
if all_literal {
323
return Ok(ac_in);
324
} else {
325
if matches!(ac_in.state, AggState::LiteralScalar(_)) {
326
ac_in.aggregated();
327
}
328
for ac in ac_sort_by.iter_mut() {
329
if matches!(ac.state, AggState::LiteralScalar(_)) {
330
ac.aggregated();
331
}
332
}
333
}
334
335
let mut sort_by_s = ac_sort_by
336
.iter()
337
.map(|c| {
338
let c = c.flat_naive();
339
match c.dtype() {
340
#[cfg(feature = "dtype-categorical")]
341
DataType::Categorical(_, _) | DataType::Enum(_, _) => {
342
c.as_materialized_series().clone()
343
},
344
// @scalar-opt
345
// @partition-opt
346
_ => c.to_physical_repr().take_materialized_series(),
347
}
348
})
349
.collect::<Vec<_>>();
350
351
let ordered_by_group_operation = matches!(
352
ac_sort_by[0].update_groups,
353
UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
354
);
355
356
let groups = if self.by.len() == 1 {
357
let mut ac_sort_by = ac_sort_by.pop().unwrap();
358
359
// The groups of the lhs of the expressions do not match the series values,
360
// we must take the slower path.
361
if !matches!(ac_in.update_groups, UpdateGroups::No) {
362
return sort_by_groups_no_match_single(
363
ac_in,
364
ac_sort_by,
365
self.sort_options.descending[0],
366
&self.expr,
367
);
368
};
369
370
let sort_by_s = sort_by_s.pop().unwrap();
371
let groups = ac_sort_by.groups();
372
373
let (check, groups) = POOL.join(
374
|| check_groups(groups, ac_in.groups()),
375
|| {
376
update_groups_sort_by(
377
groups,
378
&sort_by_s,
379
&SortOptions {
380
descending: descending[0],
381
nulls_last: nulls_last[0],
382
..Default::default()
383
},
384
)
385
},
386
);
387
check?;
388
389
groups?
390
} else {
391
let groups = ac_sort_by[0].groups();
392
393
let groups = POOL.install(|| {
394
groups
395
.par_iter()
396
.map(|indicator| {
397
sort_by_groups_multiple_by(
398
indicator,
399
&sort_by_s,
400
&descending,
401
&nulls_last,
402
self.sort_options.multithreaded,
403
self.sort_options.maintain_order,
404
)
405
})
406
.collect::<PolarsResult<_>>()
407
});
408
GroupsType::Idx(groups?)
409
};
410
411
// If the rhs is already aggregated once, it is reordered by the
412
// group_by operation - we must ensure that we are as well.
413
if ordered_by_group_operation {
414
let s = ac_in.aggregated();
415
ac_in.with_values(s.explode(false).unwrap(), false, None)?;
416
}
417
418
ac_in.with_groups(groups.into_sliceable());
419
Ok(ac_in)
420
}
421
422
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
423
self.input.to_field(input_schema)
424
}
425
426
fn is_scalar(&self) -> bool {
427
false
428
}
429
}
430
431