GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/expressions/window.rs
1
use std::cmp::Ordering;
2
use std::fmt::Write;
3
4
use arrow::array::PrimitiveArray;
5
use arrow::bitmap::Bitmap;
6
use arrow::trusted_len::TrustMyLength;
7
use polars_core::error::feature_gated;
8
use polars_core::prelude::row_encode::encode_rows_unordered;
9
use polars_core::prelude::sort::perfect_sort;
10
use polars_core::prelude::*;
11
use polars_core::series::IsSorted;
12
use polars_core::utils::_split_offsets;
13
use polars_core::{POOL, downcast_as_macro_arg_physical};
14
use polars_ops::frame::SeriesJoin;
15
use polars_ops::frame::join::{ChunkJoinOptIds, private_left_join_multiple_keys};
16
use polars_ops::prelude::*;
17
use polars_plan::prelude::*;
18
use polars_utils::UnitVec;
19
use polars_utils::sync::SyncPtr;
20
use polars_utils::vec::PushUnchecked;
21
use rayon::prelude::*;
22
23
use super::*;
24
25
pub struct WindowExpr {
26
/// The root column that the function will be applied on.
/// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index.
28
pub(crate) group_by: Vec<Arc<dyn PhysicalExpr>>,
29
pub(crate) order_by: Option<(Arc<dyn PhysicalExpr>, SortOptions)>,
30
pub(crate) apply_columns: Vec<PlSmallStr>,
31
pub(crate) phys_function: Arc<dyn PhysicalExpr>,
32
pub(crate) mapping: WindowMapping,
33
pub(crate) expr: Expr,
34
pub(crate) has_different_group_sources: bool,
35
pub(crate) output_field: Field,
36
37
pub(crate) all_group_by_are_elementwise: bool,
38
pub(crate) order_by_is_elementwise: bool,
39
}
40
41
#[cfg_attr(debug_assertions, derive(Debug))]
42
enum MapStrategy {
43
// Join by key, this is the most expensive
44
// for reduced aggregations
45
Join,
46
// explode now
47
Explode,
48
// Use an arg_sort to map the values back
49
Map,
50
Nothing,
51
}
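// Illustrative examples per strategy, inferred from `determine_map_strategy` below:
// - Join:    `col("foo").sum().over("groups")` -> one value per group, joined back onto every row
// - Explode: the `explode` window mapping -> per-group results written out back-to-back
// - Map:     `(col("x").sum() * col("y")).over("groups")` -> group-length output mapped back to the original row order
// - Nothing: `col("foo").over("groups")` or a literal -> values already line up with the rows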
52
53
impl WindowExpr {
54
fn map_list_agg_by_arg_sort(
55
&self,
56
out_column: Column,
57
flattened: &Column,
58
mut ac: AggregationContext,
59
gb: GroupBy,
60
) -> PolarsResult<IdxCa> {
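// Builds the gather indices that map the exploded per-group output back onto the
// rows of the original DataFrame: `idx_mapping` collects index pairs per group and
// `perfect_sort` turns them into the final `take_idx` permutation.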
61
// idx (new-idx, original-idx)
62
let mut idx_mapping = Vec::with_capacity(out_column.len());
63
64
// this buffer is filled later; by reusing the `original_idx` buffer for it
// we save an allocation
66
let mut take_idx = vec![];
67
68
// groups are not changed, we can map by doing a standard arg_sort.
69
if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) {
70
let mut iter = 0..flattened.len() as IdxSize;
71
match ac.groups().as_ref().as_ref() {
72
GroupsType::Idx(groups) => {
73
for g in groups.all() {
74
idx_mapping.extend(g.iter().copied().zip(&mut iter));
75
}
76
},
77
GroupsType::Slice { groups, .. } => {
78
for &[first, len] in groups {
79
idx_mapping.extend((first..first + len).zip(&mut iter));
80
}
81
},
82
}
83
}
84
// groups are changed, we use the new group indexes as arguments of the arg_sort
85
// and sort by the old indexes
86
else {
87
let mut original_idx = Vec::with_capacity(out_column.len());
88
match gb.get_groups().as_ref() {
89
GroupsType::Idx(groups) => {
90
for g in groups.all() {
91
original_idx.extend_from_slice(g)
92
}
93
},
94
GroupsType::Slice { groups, .. } => {
95
for &[first, len] in groups {
96
original_idx.extend(first..first + len)
97
}
98
},
99
};
100
101
let mut original_idx_iter = original_idx.iter().copied();
102
103
match ac.groups().as_ref().as_ref() {
104
GroupsType::Idx(groups) => {
105
for g in groups.all() {
106
idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter));
107
}
108
},
109
GroupsType::Slice { groups, .. } => {
110
for &[first, len] in groups {
111
idx_mapping.extend((first..first + len).zip(&mut original_idx_iter));
112
}
113
},
114
}
115
original_idx.clear();
116
take_idx = original_idx;
117
}
118
// SAFETY:
119
// we only have unique indices ranging from 0..len
120
unsafe { perfect_sort(&idx_mapping, &mut take_idx) };
121
Ok(IdxCa::from_vec(PlSmallStr::EMPTY, take_idx))
122
}
123
124
#[allow(clippy::too_many_arguments)]
125
fn map_by_arg_sort(
126
&self,
127
df: &DataFrame,
128
out_column: Column,
129
flattened: &Column,
130
mut ac: AggregationContext,
131
group_by_columns: &[Column],
132
gb: GroupBy,
133
cache_key: String,
134
state: &ExecutionState,
135
) -> PolarsResult<Column> {
136
// we use an arg_sort to map the values back
137
138
// This is a bit more complicated because the final group tuples may differ from the original
139
// so we use the original indices as idx values to arg_sort the original column
140
//
141
// The example below shows the naive version without group tuple mapping
142
143
// columns
144
// a b a a
145
//
146
// agg list
147
// [0, 2, 3]
148
// [1]
149
//
150
// flatten
151
//
152
// [0, 2, 3, 1]
153
//
154
// arg_sort
155
//
156
// [0, 3, 1, 2]
157
//
158
// take by arg_sorted indexes and voila groups mapped
159
// [0, 1, 2, 3]
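// When the group tuples *have* changed (e.g. an `order_by` re-sorted rows within each
// group), `map_list_agg_by_arg_sort` applies the same idea, except that the original
// row indices gathered from the `GroupBy` are used instead of `0..len` when building
// the mapping.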
160
161
if flattened.len() != df.height() {
162
let ca = out_column.list().unwrap();
163
let non_matching_group =
164
ca.into_iter()
165
.zip(ac.groups().iter())
166
.find(|(output, group)| {
167
if let Some(output) = output {
168
output.as_ref().len() != group.len()
169
} else {
170
false
171
}
172
});
173
174
if let Some((output, group)) = non_matching_group {
175
let first = group.first();
176
let group = group_by_columns
177
.iter()
178
.map(|s| format!("{}", s.get(first as usize).unwrap()))
179
.collect::<Vec<_>>();
180
polars_bail!(
181
expr = self.expr, ShapeMismatch:
182
"the length of the window expression did not match that of the group\
183
\n> group: {}\n> group length: {}\n> output: '{:?}'",
184
comma_delimited(String::new(), &group), group.len(), output.unwrap()
185
);
186
} else {
187
polars_bail!(
188
expr = self.expr, ShapeMismatch:
189
"the length of the window expression did not match that of the group"
190
);
191
};
192
}
193
194
let idx = if state.cache_window() {
195
if let Some(idx) = state.window_cache.get_map(&cache_key) {
196
idx
197
} else {
198
let idx = Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?);
199
state.window_cache.insert_map(cache_key, idx.clone());
200
idx
201
}
202
} else {
203
Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?)
204
};
205
206
// SAFETY:
207
// groups should always be in bounds.
208
unsafe { Ok(flattened.take_unchecked(&idx)) }
209
}
210
211
fn run_aggregation<'a>(
212
&self,
213
df: &DataFrame,
214
state: &ExecutionState,
215
gb: &'a GroupBy,
216
) -> PolarsResult<AggregationContext<'a>> {
217
let ac = self
218
.phys_function
219
.evaluate_on_groups(df, gb.get_groups(), state)?;
220
Ok(ac)
221
}
222
223
fn is_explicit_list_agg(&self) -> bool {
224
// col("foo").implode()
225
// col("foo").implode().alias()
226
// ..
227
// col("foo").implode().alias().alias()
228
//
229
// but not:
230
// col("foo").implode().sum().alias()
231
// ..
232
// col("foo").min()
233
let mut explicit_list = false;
234
for e in &self.expr {
235
if let Expr::Over { function, .. } = e {
236
// or list().alias
237
let mut finishes_list = false;
238
for e in &**function {
239
match e {
240
Expr::Agg(AggExpr::Implode(_)) => {
241
finishes_list = true;
242
},
243
Expr::Alias(_, _) => {},
244
_ => break,
245
}
246
}
247
explicit_list = finishes_list;
248
}
249
}
250
251
explicit_list
252
}
253
254
fn is_simple_column_expr(&self) -> bool {
255
// col()
256
// or col().alias()
257
let mut simple_col = false;
258
for e in &self.expr {
259
if let Expr::Over { function, .. } = e {
260
// or list().alias
261
for e in &**function {
262
match e {
263
Expr::Column(_) => {
264
simple_col = true;
265
},
266
Expr::Alias(_, _) => {},
267
_ => break,
268
}
269
}
270
}
271
}
272
simple_col
273
}
274
275
fn is_aggregation(&self) -> bool {
276
// col()
277
// or col().agg()
278
let mut agg_col = false;
279
for e in &self.expr {
280
if let Expr::Over { function, .. } = e {
281
// or list().alias
282
for e in &**function {
283
match e {
284
Expr::Agg(_) => {
285
agg_col = true;
286
},
287
Expr::Alias(_, _) => {},
288
_ => break,
289
}
290
}
291
}
292
}
293
agg_col
294
}
295
296
fn determine_map_strategy(
297
&self,
298
ac: &mut AggregationContext,
299
gb: &GroupBy,
300
) -> PolarsResult<MapStrategy> {
301
match (self.mapping, ac.agg_state()) {
302
// Explode
303
// `(col("x").sum() * col("y")).list().over("groups").flatten()`
304
(WindowMapping::Explode, _) => Ok(MapStrategy::Explode),
305
// // explicit list
306
// // `(col("x").sum() * col("y")).list().over("groups")`
307
// (false, false, _) => Ok(MapStrategy::Join),
308
// aggregations
309
//`sum("foo").over("groups")`
310
(_, AggState::AggregatedScalar(_)) => Ok(MapStrategy::Join),
311
// no explicit aggregations, map over the groups
312
//`(col("x").sum() * col("y")).over("groups")`
313
(WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join),
314
// no explicit aggregations, map over the groups
315
//`(col("x").sum() * col("y")).over("groups")`
316
(WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => {
317
if let GroupsType::Slice { .. } = gb.get_groups().as_ref() {
318
// Result can be directly exploded if the input was sorted.
319
ac.groups().as_ref().check_lengths(gb.get_groups())?;
320
Ok(MapStrategy::Explode)
321
} else {
322
Ok(MapStrategy::Map)
323
}
324
},
325
// no aggregations, just return column
326
// or an aggregation that has been flattened
327
// we have to check which one
328
//`col("foo").over("groups")`
329
(WindowMapping::GroupsToRows, AggState::NotAggregated(_)) => {
330
// col()
331
// or col().alias()
332
if self.is_simple_column_expr() {
333
Ok(MapStrategy::Nothing)
334
} else {
335
Ok(MapStrategy::Map)
336
}
337
},
338
(WindowMapping::Join, AggState::NotAggregated(_)) => Ok(MapStrategy::Join),
339
// literals, do nothing and let broadcast
340
(_, AggState::LiteralScalar(_)) => Ok(MapStrategy::Nothing),
341
}
342
}
343
}
344
345
// Utility to create partitions and cache keys
346
pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) {
347
write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap();
348
}
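// For example (illustrative): an `order_by` with `descending: true, nulls_last: false`
// appends `_PL_` followed by the expression's `Debug` representation and then
// `true_false` to the window cache key assembled in `evaluate`.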
349
350
impl PhysicalExpr for WindowExpr {
351
// Note: this was first implemented with expression evaluation, but that performed really badly.
// Therefore we chose the group_by -> apply -> self-join approach.

// An earlier version cached the group_by and the join tuples, but rayon under a mutex leads to deadlocks:
// https://github.com/rayon-rs/rayon/issues/592
356
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
357
// This method does the following:
358
// 1. determine group_by tuples based on the group_column
359
// 2. apply an aggregation function
360
// 3. join the results back to the original dataframe
361
// this stores all group values at the original DataFrame height
362
//
363
// we have several strategies for this
364
// - 3.1 JOIN
365
// Use a join for aggregations like
366
// `sum("foo").over("groups")`
367
// and explicit `list` aggregations
368
// `(col("x").sum() * col("y")).list().over("groups")`
369
//
370
// - 3.2 EXPLODE
371
// Explicit list aggregations that are followed by `over().flatten()`
372
// # the fastest method to do things over groups when the groups are sorted.
373
// # note that it will require an explicit `list()` call from now on.
374
// `(col("x").sum() * col("y")).list().over("groups").flatten()`
375
//
376
// - 3.3. MAP to original locations
377
// This will be done for list aggregations that are not explicitly aggregated as list
378
// `(col("x").sum() * col("y")).over("groups")
379
// This can be used to reverse, sort, shuffle etc. the values in a group
380
381
// 4. select the final column and return
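// As a rough sketch of the user-facing expressions this serves (illustrative only):
//   `col("foo").sum().over("groups")`             -> JOIN (one value per group, broadcast to its rows)
//   `(col("x") - col("x").mean()).over("groups")` -> MAP (group-length result mapped back to row order)
//   `col("foo").over("groups")` with the explode mapping strategy -> EXPLODE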
382
383
if df.height() == 0 {
384
let field = self.phys_function.to_field(df.schema())?;
385
match self.mapping {
386
WindowMapping::Join => {
387
return Ok(Column::full_null(
388
field.name().clone(),
389
0,
390
&DataType::List(Box::new(field.dtype().clone())),
391
));
392
},
393
_ => {
394
return Ok(Column::full_null(field.name().clone(), 0, field.dtype()));
395
},
396
}
397
}
398
399
let mut group_by_columns = self
400
.group_by
401
.iter()
402
.map(|e| e.evaluate(df, state))
403
.collect::<PolarsResult<Vec<_>>>()?;
404
405
// if the keys are sorted
406
let sorted_keys = group_by_columns.iter().all(|s| {
407
matches!(
408
s.is_sorted_flag(),
409
IsSorted::Ascending | IsSorted::Descending
410
)
411
});
412
let explicit_list_agg = self.is_explicit_list_agg();
413
414
// if we flatten this column we need to make sure the groups are sorted.
415
let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) ||
416
// if not
417
// `col().over()`
418
// and not
419
// `col().list().over`
420
// and not
421
// `col().sum()`
422
// and keys are sorted
423
// we may optimize with explode call
424
(!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());
425
426
// overwrite sort_groups for some expressions
427
// TODO: fully understand the rationale here.
428
if self.has_different_group_sources {
429
sort_groups = true
430
}
431
432
let create_groups = || {
433
let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?;
434
let mut groups = gb.into_groups();
435
436
if let Some((order_by, options)) = &self.order_by {
437
let order_by = order_by.evaluate(df, state)?;
438
polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height());
439
groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)?
440
.into_sliceable()
441
}
442
443
let out: PolarsResult<GroupPositions> = Ok(groups);
444
out
445
};
446
447
// Try to get cached group tuples
448
let (mut groups, cache_key) = if state.cache_window() {
449
let mut cache_key = String::with_capacity(32 * group_by_columns.len());
450
write!(&mut cache_key, "{}", state.branch_idx).unwrap();
451
for s in &group_by_columns {
452
cache_key.push_str(s.name());
453
}
454
if let Some((e, options)) = &self.order_by {
455
let e = match e.as_expression() {
456
Some(e) => e,
457
None => {
458
polars_bail!(InvalidOperation: "cannot order by this expression in window function")
459
},
460
};
461
window_function_format_order_by(&mut cache_key, e, options)
462
}
463
464
let groups = match state.window_cache.get_groups(&cache_key) {
465
Some(groups) => groups,
466
None => create_groups()?,
467
};
468
(groups, cache_key)
469
} else {
470
(create_groups()?, "".to_string())
471
};
472
473
// 2. create GroupBy object and apply aggregation
474
let apply_columns = self.apply_columns.clone();
475
476
// some window expressions need sorted groups
477
// to make sure that the caches align we sort
478
// the groups, so that the cached groups and join keys
479
// are consistent among all windows
480
if sort_groups || state.cache_window() {
481
groups.sort();
482
state
483
.window_cache
484
.insert_groups(cache_key.clone(), groups.clone());
485
}
486
487
// broadcast if required
488
for col in group_by_columns.iter_mut() {
489
if col.len() != df.height() {
490
polars_ensure!(
491
col.len() == 1,
492
ShapeMismatch: "columns used as `partition_by` must have the same length as the DataFrame"
493
);
494
*col = col.new_from_index(0, df.height())
495
}
496
}
497
498
let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns));
499
500
let mut ac = self.run_aggregation(df, state, &gb)?;
501
502
use MapStrategy::*;
503
504
match self.determine_map_strategy(&mut ac, &gb)? {
505
Nothing => {
506
let mut out = ac.flat_naive().into_owned();
507
508
if ac.is_literal() {
509
out = out.new_from_index(0, df.height())
510
}
511
Ok(out.into_column())
512
},
513
Explode => {
514
let out = if self.phys_function.is_scalar() {
515
ac.get_values().clone()
516
} else {
517
ac.aggregated().explode(ExplodeOptions {
518
empty_as_null: true,
519
keep_nulls: true,
520
})?
521
};
522
Ok(out.into_column())
523
},
524
Map => {
525
// TODO!
526
// investigate if sorted arrays can be returned directly
527
let out_column = ac.aggregated();
528
let flattened = out_column.explode(ExplodeOptions {
529
empty_as_null: true,
530
keep_nulls: true,
531
})?;
532
// we extend the lifetime as we must convince the compiler that ac lives
533
// long enough. We drop `GroupBy` when we are done with `ac`.
534
let ac = unsafe {
535
std::mem::transmute::<AggregationContext<'_>, AggregationContext<'static>>(ac)
536
};
537
self.map_by_arg_sort(
538
df,
539
out_column,
540
&flattened,
541
ac,
542
&group_by_columns,
543
gb,
544
cache_key,
545
state,
546
)
547
},
548
Join => {
549
let out_column = ac.aggregated();
550
// we try to flatten/extend the array by repeating the aggregated value n times
551
// where n is the number of members in that group. That way we can try to reuse
552
// the same map by arg_sort logic as done for listed aggregations
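// For example (illustrative): with groups [[0, 2], [1]] and per-group sums [3, 7],
// `set_by_groups` scatters the values straight to their row positions, giving [3, 7, 3].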
553
let update_groups = !matches!(&ac.update_groups, UpdateGroups::No);
554
match (
555
&ac.update_groups,
556
set_by_groups(&out_column, &ac, df.height(), update_groups),
557
) {
558
// for aggregations that reduce like sum, mean, first and are numeric
559
// we take the group locations to directly map them to the right place
560
(UpdateGroups::No, Some(out)) => Ok(out.into_column()),
561
(_, _) => {
562
let keys = gb.keys();
563
564
let get_join_tuples = || {
565
if group_by_columns.len() == 1 {
566
let mut left = group_by_columns[0].clone();
567
// group key from right column
568
let mut right = keys[0].clone();
569
570
let (left, right) = if left.dtype().is_nested() {
571
(
572
ChunkedArray::<BinaryOffsetType>::with_chunk(
573
"".into(),
574
row_encode::_get_rows_encoded_unordered(&[
575
left.clone()
576
])?
577
.into_array(),
578
)
579
.into_series(),
580
ChunkedArray::<BinaryOffsetType>::with_chunk(
581
"".into(),
582
row_encode::_get_rows_encoded_unordered(&[
583
right.clone()
584
])?
585
.into_array(),
586
)
587
.into_series(),
588
)
589
} else {
590
(
591
left.into_materialized_series().clone(),
592
right.into_materialized_series().clone(),
593
)
594
};
595
596
PolarsResult::Ok(Arc::new(
597
left.hash_join_left(&right, JoinValidation::ManyToMany, true)
598
.unwrap()
599
.1,
600
))
601
} else {
602
let df_right =
603
unsafe { DataFrame::new_unchecked_infer_height(keys) };
604
let df_left = unsafe {
605
DataFrame::new_unchecked_infer_height(group_by_columns)
606
};
607
Ok(Arc::new(
608
private_left_join_multiple_keys(&df_left, &df_right, true)?.1,
609
))
610
}
611
};
612
613
// try to get cached join_tuples
614
let join_opt_ids = if state.cache_window() {
615
if let Some(jt) = state.window_cache.get_join(&cache_key) {
616
jt
617
} else {
618
let jt = get_join_tuples()?;
619
state.window_cache.insert_join(cache_key, jt.clone());
620
jt
621
}
622
} else {
623
get_join_tuples()?
624
};
625
626
let out = materialize_column(&join_opt_ids, &out_column);
627
Ok(out.into_column())
628
},
629
}
630
},
631
}
632
}
633
634
fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
635
Ok(self.output_field.clone())
636
}
637
638
fn is_scalar(&self) -> bool {
639
false
640
}
641
642
#[allow(clippy::ptr_arg)]
643
fn evaluate_on_groups<'a>(
644
&self,
645
df: &DataFrame,
646
groups: &'a GroupPositions,
647
state: &ExecutionState,
648
) -> PolarsResult<AggregationContext<'a>> {
649
if self.group_by.is_empty()
650
|| !self.all_group_by_are_elementwise
651
|| (self.order_by.is_some() && !self.order_by_is_elementwise)
652
{
653
polars_bail!(
654
InvalidOperation:
655
"window expression with non-elementwise `partition_by` or `order_by` not allowed in aggregation context"
656
);
657
}
658
659
let length_preserving_height = if let Some((c, _)) = state.element.as_ref() {
660
c.len()
661
} else {
662
df.height()
663
};
664
665
let function_is_scalar = self.phys_function.is_scalar();
666
let needs_remap_to_rows =
667
matches!(self.mapping, WindowMapping::GroupsToRows) && !function_is_scalar;
668
669
let partition_by_columns = self
670
.group_by
671
.iter()
672
.map(|e| {
673
let mut e = e.evaluate(df, state)?;
674
if e.len() == 1 {
675
e = e.new_from_index(0, length_preserving_height);
676
}
677
// Sanity check: Length Preserving.
678
assert_eq!(e.len(), length_preserving_height);
679
Ok(e)
680
})
681
.collect::<PolarsResult<Vec<_>>>()?;
682
let order_by = match &self.order_by {
683
None => None,
684
Some((e, options)) => {
685
let mut e = e.evaluate(df, state)?;
686
if e.len() == 1 {
687
e = e.new_from_index(0, length_preserving_height);
688
}
689
// Sanity check: Length Preserving.
690
assert_eq!(e.len(), length_preserving_height);
691
let arr: Option<PrimitiveArray<IdxSize>> = if needs_remap_to_rows {
692
feature_gated!("rank", {
693
// Performance: precompute the rank here, so we can avoid dispatching per group
694
// later.
695
use polars_ops::series::SeriesRank;
696
let arr = e.as_materialized_series().rank(
697
RankOptions {
698
method: RankMethod::Ordinal,
699
descending: false,
700
},
701
None,
702
);
703
let arr = arr.idx()?;
704
let arr = arr.rechunk();
705
Some(arr.downcast_as_array().clone())
706
})
707
} else {
708
None
709
};
710
711
Some((e.clone(), arr, *options))
712
},
713
};
714
715
let (num_unique_ids, unique_ids) = if partition_by_columns.len() == 1 {
716
partition_by_columns[0].unique_id()?
717
} else {
718
ChunkUnique::unique_id(&encode_rows_unordered(&partition_by_columns)?)?
719
};
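// Every row now has an id in `0..num_unique_ids` identifying its `partition_by` key;
// for multiple keys the rows are row-encoded first so a single `unique_id` pass suffices.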
720
721
// All the sub-groups (one per `partition_by` key) found within each of the existing groups.
722
let subgroups_approx_capacity = groups.len();
723
let mut subgroups: Vec<(IdxSize, UnitVec<IdxSize>)> =
724
Vec::with_capacity(subgroups_approx_capacity);
725
726
// Indices for the output groups. Not used with `WindowMapping::Explode`.
727
let mut gather_indices_offset = 0;
728
let mut gather_indices: Vec<(IdxSize, UnitVec<IdxSize>)> =
729
Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
730
0
731
} else {
732
groups.len()
733
});
734
// Slices for the output groups. Only used with `WindowMapping::Explode`.
735
let mut strategy_explode_groups: Vec<[IdxSize; 2]> =
736
Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
737
groups.len()
738
} else {
739
0
740
});
741
742
// Amortized vectors to reorder based on `order_by`.
743
let mut amort_arg_sort = Vec::new();
744
let mut amort_offsets = Vec::new();
745
746
// Amortized vectors to gather per group data.
747
let mut amort_subgroups_order = Vec::with_capacity(num_unique_ids as usize);
748
let mut amort_subgroups_sizes = Vec::with_capacity(num_unique_ids as usize);
749
let mut amort_subgroups_indices = (0..num_unique_ids)
750
.map(|_| (0, UnitVec::new()))
751
.collect::<Vec<(IdxSize, UnitVec<IdxSize>)>>();
752
753
macro_rules! map_window_groups {
754
($iter:expr, $get:expr) => {
755
let mut subgroup_gather_indices =
756
UnitVec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
757
0
758
} else {
759
$iter.len()
760
});
761
762
amort_subgroups_order.clear();
763
amort_subgroups_sizes.clear();
764
amort_subgroups_sizes.resize(num_unique_ids as usize, 0);
765
766
// Determine sizes per subgroup.
767
for i in $iter.clone() {
768
let id = *unsafe { unique_ids.get_unchecked(i as usize) };
769
let size = unsafe { amort_subgroups_sizes.get_unchecked_mut(id as usize) };
770
if *size == 0 {
771
unsafe { amort_subgroups_order.push_unchecked(id) };
772
}
773
*size += 1;
774
}
775
776
if matches!(self.mapping, WindowMapping::Explode) {
777
strategy_explode_groups.push([
778
subgroups.len() as IdxSize,
779
amort_subgroups_order.len() as IdxSize,
780
]);
781
}
782
783
// Set starting gather indices and reserve capacity per subgroup.
784
let mut offset = if needs_remap_to_rows {
785
gather_indices_offset
786
} else {
787
subgroups.len() as IdxSize
788
};
789
for &id in &amort_subgroups_order {
790
let size = *unsafe { amort_subgroups_sizes.get_unchecked(id as usize) };
791
let (next_gather_idx, indices) =
792
unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
793
indices.reserve(size as usize);
794
*next_gather_idx = offset;
795
offset += if needs_remap_to_rows { size } else { 1 };
796
}
797
798
// Collect gather indices.
799
if matches!(self.mapping, WindowMapping::Explode) {
800
for i in $iter {
801
let id = *unsafe { unique_ids.get_unchecked(i as usize) };
802
let (_, indices) =
803
unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
804
unsafe { indices.push_unchecked(i) };
805
}
806
} else {
807
// If we are remapping exploded rows back to rows and are reordering, we need
808
// to ensure we reorder the gather indices as well. Reordering the `subgroup`
809
// indices is done later.
810
//
811
// Having precalculated both the `unique_ids` and `order_by_ranks` in
// efficient kernels, we can now do a relatively efficient arg_sort per group. This
813
// is still horrendously slow, but at least not as bad as it would be if you
814
// did this naively.
815
if needs_remap_to_rows && let Some((_, arr, options)) = &order_by {
816
let arr = arr.as_ref().unwrap();
817
amort_arg_sort.clear();
818
amort_arg_sort.extend(0..$iter.len() as IdxSize);
819
match arr.validity() {
820
None => {
821
let arr = arr.values().as_slice();
822
amort_arg_sort.sort_by(|a, b| {
823
let in_group_idx_a = $get(*a as usize) as usize;
824
let in_group_idx_b = $get(*b as usize) as usize;
825
826
let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
827
let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };
828
829
let mut cmp = order_a.cmp(&order_b);
830
// Performance: This can generally be handled branchlessly.
831
if options.descending {
832
cmp = cmp.reverse();
833
}
834
cmp
835
});
836
},
837
Some(validity) => {
838
let arr = arr.values().as_slice();
839
amort_arg_sort.sort_by(|a, b| {
840
let in_group_idx_a = $get(*a as usize) as usize;
841
let in_group_idx_b = $get(*b as usize) as usize;
842
843
let is_valid_a =
844
unsafe { validity.get_bit_unchecked(in_group_idx_a) };
845
let is_valid_b =
846
unsafe { validity.get_bit_unchecked(in_group_idx_b) };
847
let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
848
let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };
849
850
if !is_valid_a & !is_valid_b {
851
return Ordering::Equal;
852
}
853
854
let mut cmp = order_a.cmp(&order_b);
855
if !is_valid_a {
856
cmp = Ordering::Less;
857
}
858
if !is_valid_b {
859
cmp = Ordering::Greater;
860
}
861
if options.descending
862
| ((!is_valid_a | !is_valid_b) & options.nulls_last)
863
{
864
cmp = cmp.reverse();
865
}
866
cmp
867
});
868
},
869
}
870
871
amort_offsets.clear();
872
amort_offsets.resize($iter.len(), 0);
873
for &id in &amort_subgroups_order {
874
amort_subgroups_sizes[id as usize] = 0;
875
}
876
877
for &idx in &amort_arg_sort {
878
let in_group_idx = $get(idx as usize);
879
let id = *unsafe { unique_ids.get_unchecked(in_group_idx as usize) };
880
amort_offsets[idx as usize] = amort_subgroups_sizes[id as usize];
881
amort_subgroups_sizes[id as usize] += 1;
882
}
883
884
for (i, offset) in $iter.zip(&amort_offsets) {
885
let id = *unsafe { unique_ids.get_unchecked(i as usize) };
886
let (next_gather_idx, indices) =
887
unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
888
unsafe {
889
subgroup_gather_indices.push_unchecked(*next_gather_idx + *offset)
890
};
891
unsafe { indices.push_unchecked(i) };
892
}
893
} else {
894
for i in $iter {
895
let id = *unsafe { unique_ids.get_unchecked(i as usize) };
896
let (next_gather_idx, indices) =
897
unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
898
unsafe { subgroup_gather_indices.push_unchecked(*next_gather_idx) };
899
*next_gather_idx += IdxSize::from(needs_remap_to_rows);
900
unsafe { indices.push_unchecked(i) };
901
}
902
}
903
}
904
905
// Push groups into nested_groups.
906
subgroups.extend(amort_subgroups_order.iter().map(|&id| {
907
let (_, indices) =
908
unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
909
let indices = std::mem::take(indices);
910
(*unsafe { indices.get_unchecked(0) }, indices)
911
}));
912
913
if !matches!(self.mapping, WindowMapping::Explode) {
914
gather_indices_offset += subgroup_gather_indices.len() as IdxSize;
915
gather_indices.push((
916
subgroup_gather_indices.first().copied().unwrap_or(0),
917
subgroup_gather_indices,
918
));
919
}
920
};
921
}
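// The macro above is instantiated once per outer group: it splits that group into
// per-`partition_by` subgroups (pushed into `subgroups`) and records where each row's
// result must be gathered from (`gather_indices` / `strategy_explode_groups`),
// honouring `order_by` when an exploded result has to be remapped back to rows.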
922
match groups.as_ref() {
923
GroupsType::Idx(idxs) => {
924
for g in idxs.all() {
925
map_window_groups!(g.iter().copied(), (|i: usize| g[i]));
926
}
927
},
928
GroupsType::Slice {
929
groups,
930
overlapping: _,
931
monotonic: _,
932
} => {
933
for [s, l] in groups.iter() {
934
let s = *s;
935
let l = *l;
936
let iter = unsafe { TrustMyLength::new(s..s + l, l as usize) };
937
map_window_groups!(iter, (|i: usize| s + i as IdxSize));
938
}
939
},
940
}
941
942
let mut subgroups = GroupsType::Idx(subgroups.into());
943
if let Some((order_by, _, options)) = order_by {
944
subgroups =
945
update_groups_sort_by(&subgroups, order_by.as_materialized_series(), &options)?;
946
}
947
let subgroups = subgroups.into_sliceable();
948
let mut data = self
949
.phys_function
950
.evaluate_on_groups(df, &subgroups, state)?
951
.finalize();
952
953
let final_groups = if matches!(self.mapping, WindowMapping::Explode) {
954
if !function_is_scalar {
955
let (data_s, offsets) = data.list()?.explode_and_offsets(ExplodeOptions {
956
empty_as_null: false,
957
keep_nulls: false,
958
})?;
959
data = data_s.into_column();
960
961
let mut exploded_offset = 0;
962
for [start, length] in strategy_explode_groups.iter_mut() {
963
let exploded_start = exploded_offset;
964
let exploded_length = offsets
965
.lengths()
966
.skip(*start as usize)
967
.take(*length as usize)
968
.sum::<usize>() as IdxSize;
969
exploded_offset += exploded_length;
970
*start = exploded_start;
971
*length = exploded_length;
972
}
973
}
974
GroupsType::new_slice(strategy_explode_groups, false, true)
975
} else {
976
if needs_remap_to_rows {
977
let data_l = data.list()?;
978
assert_eq!(data_l.len(), subgroups.len());
979
let lengths = data_l.lst_lengths();
980
let length_mismatch = match subgroups.as_ref() {
981
GroupsType::Idx(idx) => idx
982
.all()
983
.iter()
984
.zip(&lengths)
985
.any(|(i, l)| i.len() as IdxSize != l.unwrap()),
986
GroupsType::Slice {
987
groups,
988
overlapping: _,
989
monotonic: _,
990
} => groups
991
.iter()
992
.zip(&lengths)
993
.any(|([_, i], l)| *i != l.unwrap()),
994
};
995
996
polars_ensure!(
997
!length_mismatch,
998
expr = self.expr, ShapeMismatch:
999
"the length of the window expression did not match that of the group"
1000
);
1001
1002
data = data_l
1003
.explode(ExplodeOptions {
1004
empty_as_null: false,
1005
keep_nulls: true,
1006
})?
1007
.into_column();
1008
}
1009
GroupsType::Idx(gather_indices.into())
1010
}
1011
.into_sliceable();
1012
1013
Ok(AggregationContext {
1014
state: AggState::NotAggregated(data),
1015
groups: Cow::Owned(final_groups),
1016
update_groups: UpdateGroups::No,
1017
original_len: false,
1018
})
1019
}
1020
1021
fn as_expression(&self) -> Option<&Expr> {
1022
Some(&self.expr)
1023
}
1024
}
1025
1026
fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Column) -> Column {
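// `join_opt_ids` is the right-hand side of the left join computed in `evaluate`:
// `Either::Left` holds plain (nullable) row indices, while `Either::Right` holds
// chunk-aware ids (assumption based on the two take paths below).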
1027
{
1028
use arrow::Either;
1029
use polars_ops::chunked_array::TakeChunked;
1030
1031
match join_opt_ids {
1032
Either::Left(ids) => unsafe {
1033
IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx))
1034
},
1035
Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids, false) },
1036
}
1037
}
1038
}
1039
1040
/// Simple reducing aggregations can be written back directly by the group locations
1041
fn set_by_groups(
1042
s: &Column,
1043
ac: &AggregationContext,
1044
len: usize,
1045
update_groups: bool,
1046
) -> Option<Column> {
1047
if update_groups || !ac.original_len {
1048
return None;
1049
}
1050
if s.dtype().to_physical().is_primitive_numeric() {
1051
let dtype = s.dtype();
1052
let s = s.to_physical_repr();
1053
1054
macro_rules! dispatch {
1055
($ca:expr) => {{ Some(set_numeric($ca, &ac.groups, len)) }};
1056
}
1057
downcast_as_macro_arg_physical!(&s, dispatch)
1058
.map(|s| unsafe { s.from_physical_unchecked(dtype) }.unwrap())
1059
.map(Column::from)
1060
} else {
1061
None
1062
}
1063
}
1064
1065
fn set_numeric<T: PolarsNumericType>(
1066
ca: &ChunkedArray<T>,
1067
groups: &GroupsType,
1068
len: usize,
1069
) -> Series {
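// Scatter each group's (single) aggregated value to every row of that group, e.g.
// groups [[0, 2, 3], [1]] with aggregated values [10, 20] fill `values` as
// [10, 20, 10, 10] (illustrative).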
1070
let mut values = Vec::with_capacity(len);
1071
let ptr: *mut T::Native = values.as_mut_ptr();
1072
// SAFETY:
1073
// we will write from different threads but we will never alias.
1074
let sync_ptr_values = unsafe { SyncPtr::new(ptr) };
1075
1076
if ca.null_count() == 0 {
1077
let ca = ca.rechunk();
1078
match groups {
1079
GroupsType::Idx(groups) => {
1080
let agg_vals = ca.cont_slice().expect("rechunked");
1081
POOL.install(|| {
1082
agg_vals
1083
.par_iter()
1084
.zip(groups.all().par_iter())
1085
.for_each(|(v, g)| {
1086
let ptr = sync_ptr_values.get();
1087
for idx in g.as_slice() {
1088
debug_assert!((*idx as usize) < len);
1089
unsafe { *ptr.add(*idx as usize) = *v }
1090
}
1091
})
1092
})
1093
},
1094
GroupsType::Slice { groups, .. } => {
1095
let agg_vals = ca.cont_slice().expect("rechunked");
1096
POOL.install(|| {
1097
agg_vals
1098
.par_iter()
1099
.zip(groups.par_iter())
1100
.for_each(|(v, [start, g_len])| {
1101
let ptr = sync_ptr_values.get();
1102
let start = *start as usize;
1103
let end = start + *g_len as usize;
1104
for idx in start..end {
1105
debug_assert!(idx < len);
1106
unsafe { *ptr.add(idx) = *v }
1107
}
1108
})
1109
});
1110
},
1111
}
1112
1113
// SAFETY: we have written all slots
1114
unsafe { values.set_len(len) }
1115
ChunkedArray::<T>::new_vec(ca.name().clone(), values).into_series()
1116
} else {
1117
// We don't use a mutable bitmap as bits will have race conditions!
1118
// A single byte might alias if we write from multiple threads.
1119
let mut validity: Vec<bool> = vec![false; len];
1120
let validity_ptr = validity.as_mut_ptr();
1121
let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) };
1122
1123
let n_threads = POOL.current_num_threads();
1124
let offsets = _split_offsets(ca.len(), n_threads);
1125
1126
match groups {
1127
GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| {
1128
let offset = *offset;
1129
let offset_len = *offset_len;
1130
let ca = ca.slice(offset as i64, offset_len);
1131
let groups = &groups.all()[offset..offset + offset_len];
1132
let values_ptr = sync_ptr_values.get();
1133
let validity_ptr = sync_ptr_validity.get();
1134
1135
ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| {
1136
for idx in g.as_slice() {
1137
let idx = *idx as usize;
1138
debug_assert!(idx < len);
1139
unsafe {
1140
match opt_v {
1141
Some(v) => {
1142
*values_ptr.add(idx) = v;
1143
*validity_ptr.add(idx) = true;
1144
},
1145
None => {
1146
*values_ptr.add(idx) = T::Native::default();
1147
*validity_ptr.add(idx) = false;
1148
},
1149
};
1150
}
1151
}
1152
})
1153
}),
1154
GroupsType::Slice { groups, .. } => {
1155
offsets.par_iter().for_each(|(offset, offset_len)| {
1156
let offset = *offset;
1157
let offset_len = *offset_len;
1158
let ca = ca.slice(offset as i64, offset_len);
1159
let groups = &groups[offset..offset + offset_len];
1160
let values_ptr = sync_ptr_values.get();
1161
let validity_ptr = sync_ptr_validity.get();
1162
1163
for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) {
1164
let start = *start as usize;
1165
let end = start + *g_len as usize;
1166
for idx in start..end {
1167
debug_assert!(idx < len);
1168
unsafe {
1169
match opt_v {
1170
Some(v) => {
1171
*values_ptr.add(idx) = v;
1172
*validity_ptr.add(idx) = true;
1173
},
1174
None => {
1175
*values_ptr.add(idx) = T::Native::default();
1176
*validity_ptr.add(idx) = false;
1177
},
1178
};
1179
}
1180
}
1181
}
1182
})
1183
},
1184
}
1185
// SAFETY: we have written all slots
1186
unsafe { values.set_len(len) }
1187
let validity = Bitmap::from(validity);
1188
let arr = PrimitiveArray::new(
1189
T::get_static_dtype()
1190
.to_physical()
1191
.to_arrow(CompatLevel::newest()),
1192
values.into(),
1193
Some(validity),
1194
);
1195
Series::try_from((ca.name().clone(), arr.boxed())).unwrap()
1196
}
1197
}
1198
1199