Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-expr/src/dispatch/groups_dispatch.rs
7884 views
1
use std::borrow::Cow;
2
use std::sync::Arc;
3
4
use arrow::array::PrimitiveArray;
5
use arrow::bitmap::Bitmap;
6
use arrow::bitmap::bitmask::BitMask;
7
use arrow::trusted_len::TrustMyLength;
8
use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
9
use polars_core::POOL;
10
use polars_core::error::{PolarsResult, polars_bail, polars_ensure};
11
use polars_core::frame::DataFrame;
12
use polars_core::prelude::row_encode::encode_rows_unordered;
13
use polars_core::prelude::{
14
AnyValue, ChunkCast, Column, CompatLevel, Float64Chunked, GroupPositions, GroupsType,
15
IDX_DTYPE, IntoColumn,
16
};
17
use polars_core::scalar::Scalar;
18
use polars_core::series::{ChunkCompareEq, Series};
19
use polars_utils::itertools::Itertools;
20
use polars_utils::pl_str::PlSmallStr;
21
use polars_utils::{IdxSize, UnitVec};
22
use rayon::iter::{IntoParallelIterator, ParallelIterator};
23
24
use crate::prelude::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
25
use crate::state::ExecutionState;
26
27
pub fn reverse<'a>(
28
inputs: &[Arc<dyn PhysicalExpr>],
29
df: &DataFrame,
30
groups: &'a GroupPositions,
31
state: &ExecutionState,
32
) -> PolarsResult<AggregationContext<'a>> {
33
assert_eq!(inputs.len(), 1);
34
35
let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
36
37
// Length preserving operation on scalars keeps scalar.
38
if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &ac.agg_state() {
39
return Ok(ac);
40
}
41
42
POOL.install(|| {
43
let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
44
GroupsType::Idx(idx) => idx
45
.into_par_iter()
46
.map(|(first, idx)| {
47
(
48
idx.last().copied().unwrap_or(first),
49
idx.iter().copied().rev().collect(),
50
)
51
})
52
.collect(),
53
GroupsType::Slice {
54
groups,
55
overlapping: _,
56
monotonic: _,
57
} => groups
58
.into_par_iter()
59
.map(|[start, len]| {
60
(
61
start + len.saturating_sub(1),
62
(*start..*start + *len).rev().collect(),
63
)
64
})
65
.collect(),
66
})
67
.into_sliceable();
68
ac.with_groups(positions);
69
});
70
71
Ok(ac)
72
}
73
74
pub fn null_count<'a>(
75
inputs: &[Arc<dyn PhysicalExpr>],
76
df: &DataFrame,
77
groups: &'a GroupPositions,
78
state: &ExecutionState,
79
) -> PolarsResult<AggregationContext<'a>> {
80
assert_eq!(inputs.len(), 1);
81
82
let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
83
84
if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
85
*s = s.is_null().cast(&IDX_DTYPE).unwrap().into_column();
86
return Ok(ac);
87
}
88
89
ac.groups();
90
let values = ac.flat_naive();
91
let name = values.name().clone();
92
let Some(validity) = values.rechunk_validity() else {
93
ac.state = AggState::AggregatedScalar(Column::new_scalar(
94
name,
95
(0 as IdxSize).into(),
96
groups.len(),
97
));
98
return Ok(ac);
99
};
100
101
POOL.install(|| {
102
let validity = BitMask::from_bitmap(&validity);
103
let null_count: Vec<IdxSize> = match &**ac.groups.as_ref() {
104
GroupsType::Idx(idx) => idx
105
.into_par_iter()
106
.map(|(_, idx)| {
107
idx.iter()
108
.map(|i| IdxSize::from(!unsafe { validity.get_bit_unchecked(*i as usize) }))
109
.sum::<IdxSize>()
110
})
111
.collect(),
112
GroupsType::Slice {
113
groups,
114
overlapping: _,
115
monotonic: _,
116
} => groups
117
.into_par_iter()
118
.map(|[start, length]| {
119
unsafe { validity.sliced_unchecked(*start as usize, *length as usize) }
120
.unset_bits() as IdxSize
121
})
122
.collect(),
123
};
124
125
ac.state = AggState::AggregatedScalar(Column::new(name, null_count));
126
});
127
128
Ok(ac)
129
}
130
131
pub fn any<'a>(
132
inputs: &[Arc<dyn PhysicalExpr>],
133
df: &DataFrame,
134
groups: &'a GroupPositions,
135
state: &ExecutionState,
136
ignore_nulls: bool,
137
) -> PolarsResult<AggregationContext<'a>> {
138
assert_eq!(inputs.len(), 1);
139
140
let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
141
142
if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
143
if ignore_nulls {
144
*s = s
145
.equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
146
.unwrap()
147
.into_column();
148
} else {
149
*s = s
150
.equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
151
.unwrap()
152
.into_column();
153
}
154
return Ok(ac);
155
}
156
157
ac.groups();
158
let values = ac.flat_naive();
159
let values = values.bool()?;
160
let out = unsafe { values.agg_any(ac.groups.as_ref(), ignore_nulls) };
161
ac.state = AggState::AggregatedScalar(out.into_column());
162
163
Ok(ac)
164
}
165
166
pub fn all<'a>(
167
inputs: &[Arc<dyn PhysicalExpr>],
168
df: &DataFrame,
169
groups: &'a GroupPositions,
170
state: &ExecutionState,
171
ignore_nulls: bool,
172
) -> PolarsResult<AggregationContext<'a>> {
173
assert_eq!(inputs.len(), 1);
174
175
let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
176
177
if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
178
if ignore_nulls {
179
*s = s
180
.equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
181
.unwrap()
182
.into_column();
183
} else {
184
*s = s
185
.equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
186
.unwrap()
187
.into_column();
188
}
189
return Ok(ac);
190
}
191
192
ac.groups();
193
let values = ac.flat_naive();
194
let values = values.bool()?;
195
let out = unsafe { values.agg_all(ac.groups.as_ref(), ignore_nulls) };
196
ac.state = AggState::AggregatedScalar(out.into_column());
197
198
Ok(ac)
199
}
200
201
/// Shared driver for the per-group bitwise reductions.
///
/// * `op` is the operation name used in the error raised for unsupported dtypes.
/// * `f` performs the actual reduction of the flat values over the given groups.
///
/// A scalar state is returned unchanged after its dtype has been validated.
///
/// # Errors
/// Returns an error when a scalar input is neither boolean nor a primitive numeric.
///
/// # Panics
/// Panics if `inputs` does not contain exactly one expression.
#[cfg(feature = "bitwise")]
pub fn bitwise_agg<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    op: &'static str,
    f: impl Fn(&Column, &GroupsType) -> Column,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);

    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    if let AggState::AggregatedScalar(scalar) | AggState::LiteralScalar(scalar) = &ac.state {
        let dtype = scalar.dtype();
        polars_ensure!(
            dtype.is_bool() || dtype.is_primitive_numeric(),
            op = op,
            dtype
        );
        return Ok(ac);
    }

    ac.groups();
    let flat = ac.flat_naive();
    let reduced = f(flat.as_ref(), ac.groups.as_ref());
    ac.state = AggState::AggregatedScalar(reduced.into_column());

    Ok(ac)
}
231
232
/// Per-group bitwise AND: runs `bitwise_agg` with `Column::agg_and` as the reduction
/// and `"and_reduce"` as the operation name for error reporting.
#[cfg(feature = "bitwise")]
pub fn bitwise_and<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    bitwise_agg(inputs, df, groups, state, "and_reduce", |column, positions| unsafe {
        column.agg_and(positions)
    })
}
248
249
/// Per-group bitwise OR: runs `bitwise_agg` with `Column::agg_or` as the reduction
/// and `"or_reduce"` as the operation name for error reporting.
#[cfg(feature = "bitwise")]
pub fn bitwise_or<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    bitwise_agg(
        inputs,
        df,
        groups,
        state,
        "or_reduce",
        |column, positions| unsafe { column.agg_or(positions) },
    )
}
260
261
/// Per-group bitwise XOR: runs `bitwise_agg` with `Column::agg_xor` as the reduction
/// and `"xor_reduce"` as the operation name for error reporting.
#[cfg(feature = "bitwise")]
pub fn bitwise_xor<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    bitwise_agg(inputs, df, groups, state, "xor_reduce", |column, positions| unsafe {
        column.agg_xor(positions)
    })
}
277
278
pub fn drop_items<'a>(
279
mut ac: AggregationContext<'a>,
280
predicate: &Bitmap,
281
) -> PolarsResult<AggregationContext<'a>> {
282
// No elements are filtered out.
283
if predicate.unset_bits() == 0 {
284
if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
285
*c = c.as_list().into_column();
286
if c.len() == 1 && ac.groups.len() != 1 {
287
*c = c.new_from_index(0, ac.groups.len());
288
}
289
ac.state = AggState::AggregatedList(std::mem::take(c));
290
ac.update_groups = UpdateGroups::WithSeriesLen;
291
}
292
return Ok(ac);
293
}
294
295
ac.set_original_len(false);
296
297
// All elements are filtered out.
298
if predicate.set_bits() == 0 {
299
let name = ac.agg_state().name();
300
let dtype = ac.agg_state().flat_dtype();
301
302
ac.state = AggState::AggregatedList(Column::new_scalar(
303
name.clone(),
304
Scalar::new(
305
dtype.clone().implode(),
306
AnyValue::List(Series::new_empty(PlSmallStr::EMPTY, dtype)),
307
),
308
ac.groups.len(),
309
));
310
ac.with_update_groups(UpdateGroups::WithSeriesLen);
311
return Ok(ac);
312
}
313
314
if let AggState::AggregatedScalar(c) = &mut ac.state {
315
ac.state = AggState::NotAggregated(std::mem::take(c));
316
ac.groups = Cow::Owned(
317
{
318
let groups = predicate
319
.iter()
320
.enumerate_idx()
321
.map(|(i, p)| [i, IdxSize::from(p)])
322
.collect();
323
GroupsType::new_slice(groups, false, true)
324
}
325
.into_sliceable(),
326
);
327
ac.update_groups = UpdateGroups::No;
328
return Ok(ac);
329
}
330
331
ac.groups();
332
let predicate = BitMask::from_bitmap(predicate);
333
POOL.install(|| {
334
let positions = GroupsType::Idx(match &**ac.groups.as_ref() {
335
GroupsType::Idx(idxs) => idxs
336
.into_par_iter()
337
.map(|(fst, idxs)| {
338
let out = idxs
339
.iter()
340
.copied()
341
.filter(|i| unsafe { predicate.get_bit_unchecked(*i as usize) })
342
.collect::<UnitVec<IdxSize>>();
343
(out.first().copied().unwrap_or(fst), out)
344
})
345
.collect(),
346
GroupsType::Slice {
347
groups,
348
overlapping: _,
349
monotonic: _,
350
} => groups
351
.into_par_iter()
352
.map(|[start, length]| {
353
let predicate =
354
unsafe { predicate.sliced_unchecked(*start as usize, *length as usize) };
355
let num_values = predicate.set_bits();
356
357
if num_values == 0 {
358
(*start, UnitVec::new())
359
} else if num_values == 1 {
360
let item = *start + predicate.leading_zeros() as IdxSize;
361
let mut out = UnitVec::with_capacity(1);
362
out.push(item);
363
(item, out)
364
} else if num_values == *length as usize {
365
(*start, (*start..*start + *length).collect())
366
} else {
367
let out = unsafe {
368
TrustMyLength::new(
369
(0..*length)
370
.filter(|i| predicate.get_bit_unchecked(*i as usize))
371
.map(|i| i + *start),
372
num_values,
373
)
374
};
375
let out = out.collect::<UnitVec<IdxSize>>();
376
377
(out.first().copied().unwrap(), out)
378
}
379
})
380
.collect(),
381
})
382
.into_sliceable();
383
ac.with_groups(positions);
384
});
385
386
Ok(ac)
387
}
388
389
pub fn drop_nans<'a>(
390
inputs: &[Arc<dyn PhysicalExpr>],
391
df: &DataFrame,
392
groups: &'a GroupPositions,
393
state: &ExecutionState,
394
) -> PolarsResult<AggregationContext<'a>> {
395
assert_eq!(inputs.len(), 1);
396
let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
397
ac.groups();
398
let predicate = if ac.agg_state().flat_dtype().is_float() {
399
let values = ac.flat_naive();
400
let mut values = values.is_nan().unwrap();
401
values.rechunk_mut();
402
values.downcast_as_array().values().clone()
403
} else {
404
Bitmap::new_with_value(false, 1)
405
};
406
let predicate = !&predicate;
407
drop_items(ac, &predicate)
408
}
409
410
pub fn drop_nulls<'a>(
411
inputs: &[Arc<dyn PhysicalExpr>],
412
df: &DataFrame,
413
groups: &'a GroupPositions,
414
state: &ExecutionState,
415
) -> PolarsResult<AggregationContext<'a>> {
416
assert_eq!(inputs.len(), 1);
417
let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
418
ac.groups();
419
let predicate = ac.flat_naive().as_ref().clone();
420
let predicate = predicate.rechunk_to_arrow(CompatLevel::newest());
421
let predicate = predicate
422
.validity()
423
.cloned()
424
.unwrap_or(Bitmap::new_with_value(true, 1));
425
drop_items(ac, &predicate)
426
}
427
428
/// Shared driver for per-group statistical moment aggregations over `f64` values.
///
/// `S` is the accumulator type; a fresh one is created with `S::default()`.
///
/// * `insert_one` feeds a single value into an accumulator.
/// * `new_from_slice` builds an accumulator from `(array, start, length)`; it is used
///   for slice-based groups.
/// * `finalize` turns an accumulator into the optional result for one group.
///
/// # Errors
/// Returns an error when the input values cannot be read as an `f64` column.
///
/// # Panics
/// Panics if `inputs` does not contain exactly one expression.
#[cfg(feature = "moment")]
pub fn moment_agg<'a, S: Default>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,

    insert_one: impl Fn(&mut S, f64) + Send + Sync,
    new_from_slice: impl Fn(&PrimitiveArray<f64>, usize, usize) -> S + Send + Sync,
    finalize: impl Fn(S) -> Option<f64> + Send + Sync,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);
    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    // Scalar state: every non-null value is aggregated on its own; nulls stay null.
    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
        let ca = s.f64()?;
        let per_value = ca
            .iter()
            .map(|value| {
                let value = value?;
                let mut acc = S::default();
                insert_one(&mut acc, value);
                finalize(acc)
            })
            .collect::<Float64Chunked>()
            .with_name(ca.name().clone())
            .into_column();
        *s = per_value;
        return Ok(ac);
    }

    ac.groups();

    let name = ac.get_values().name().clone();
    let flat = ac.flat_naive();
    let floats = flat.f64()?;
    let floats = floats.rechunk();
    let arr = floats.downcast_as_array();

    let aggregated = POOL.install(|| match &**ac.groups.as_ref() {
        GroupsType::Idx(all_groups) => {
            // The validity bitmap is only consulted when it actually contains nulls.
            match arr.validity().filter(|v| v.unset_bits() > 0) {
                Some(validity) => all_groups
                    .into_par_iter()
                    .map(|(_, members)| {
                        let mut acc = S::default();
                        for &row in members.iter() {
                            if unsafe { validity.get_bit_unchecked(row as usize) } {
                                insert_one(&mut acc, arr.values()[row as usize]);
                            }
                        }
                        finalize(acc)
                    })
                    .collect::<Float64Chunked>(),
                None => all_groups
                    .into_par_iter()
                    .map(|(_, members)| {
                        let mut acc = S::default();
                        for &row in members.iter() {
                            insert_one(&mut acc, arr.values()[row as usize]);
                        }
                        finalize(acc)
                    })
                    .collect::<Float64Chunked>(),
            }
        },
        GroupsType::Slice { groups: slices, .. } => slices
            .into_par_iter()
            .map(|[start, length]| {
                let acc = new_from_slice(arr, *start as usize, *length as usize);
                finalize(acc)
            })
            .collect::<Float64Chunked>(),
    });

    ac.state = AggState::AggregatedScalar(aggregated.with_name(name).into_column());
    Ok(ac)
}
506
507
/// Per-group skewness: runs `moment_agg` with `SkewState` as the accumulator and
/// forwards `bias` to `SkewState::finalize`.
#[cfg(feature = "moment")]
pub fn skew<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::SkewState;

    moment_agg::<SkewState>(
        inputs,
        df,
        groups,
        state,
        SkewState::insert_one,
        SkewState::from_array,
        |acc| acc.finalize(bias),
    )
}
526
527
/// Per-group kurtosis: runs `moment_agg` with `KurtosisState` as the accumulator and
/// forwards `fisher` and `bias` to `KurtosisState::finalize`.
#[cfg(feature = "moment")]
pub fn kurtosis<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    fisher: bool,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::KurtosisState;

    moment_agg::<KurtosisState>(
        inputs,
        df,
        groups,
        state,
        KurtosisState::insert_one,
        KurtosisState::from_array,
        |acc| acc.finalize(fisher, bias),
    )
}
547
548
pub fn unique<'a>(
549
inputs: &[Arc<dyn PhysicalExpr>],
550
df: &DataFrame,
551
groups: &'a GroupPositions,
552
state: &ExecutionState,
553
stable: bool,
554
) -> PolarsResult<AggregationContext<'a>> {
555
_ = stable;
556
557
assert_eq!(inputs.len(), 1);
558
let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
559
ac.groups();
560
561
if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
562
*c = c.as_list().into_column();
563
if c.len() == 1 && ac.groups.len() != 1 {
564
*c = c.new_from_index(0, ac.groups.len());
565
}
566
ac.state = AggState::AggregatedList(std::mem::take(c));
567
ac.update_groups = UpdateGroups::WithSeriesLen;
568
return Ok(ac);
569
}
570
571
let values = ac.flat_naive().to_physical_repr();
572
let dtype = values.dtype();
573
let values = if dtype.contains_objects() {
574
polars_bail!(opq = unique, dtype);
575
} else if let Some(ca) = values.try_str() {
576
ca.as_binary().into_column()
577
} else if dtype.is_nested() {
578
encode_rows_unordered(&[values])?.into_column()
579
} else {
580
values
581
};
582
583
let values = values.rechunk_to_arrow(CompatLevel::newest());
584
let values = values.as_ref();
585
let state = amortized_unique_from_dtype(values.dtype());
586
587
struct CloneWrapper(Box<dyn AmortizedUnique>);
588
impl Clone for CloneWrapper {
589
fn clone(&self) -> Self {
590
Self(self.0.new_empty())
591
}
592
}
593
594
POOL.install(|| {
595
let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
596
GroupsType::Idx(idx) => idx
597
.into_par_iter()
598
.map_with(CloneWrapper(state), |state, (first, idx)| {
599
let mut idx = idx.clone();
600
unsafe { state.0.retain_unique(values, &mut idx) };
601
(idx.first().copied().unwrap_or(first), idx)
602
})
603
.collect(),
604
GroupsType::Slice {
605
groups,
606
overlapping: _,
607
monotonic: _,
608
} => groups
609
.into_par_iter()
610
.map_with(CloneWrapper(state), |state, [start, len]| {
611
let mut idx = UnitVec::new();
612
state.0.arg_unique(values, &mut idx, *start, *len);
613
(idx.first().copied().unwrap_or(*start), idx)
614
})
615
.collect(),
616
})
617
.into_sliceable();
618
ac.with_groups(positions);
619
});
620
621
Ok(ac)
622
}
623
624
fn fw_bw_fill_null<'a>(
625
inputs: &[Arc<dyn PhysicalExpr>],
626
df: &DataFrame,
627
groups: &'a GroupPositions,
628
state: &ExecutionState,
629
f_idx: impl Fn(
630
std::iter::Copied<std::slice::Iter<'_, IdxSize>>,
631
BitMask<'_>,
632
usize,
633
) -> UnitVec<IdxSize>
634
+ Send
635
+ Sync,
636
f_range: impl Fn(std::ops::Range<IdxSize>, BitMask<'_>, usize) -> UnitVec<IdxSize> + Send + Sync,
637
) -> PolarsResult<AggregationContext<'a>> {
638
assert_eq!(inputs.len(), 1);
639
let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
640
ac.groups();
641
642
if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &mut ac.state {
643
return Ok(ac);
644
}
645
646
let values = ac.flat_naive();
647
let Some(validity) = values.rechunk_validity() else {
648
return Ok(ac);
649
};
650
651
let validity = BitMask::from_bitmap(&validity);
652
POOL.install(|| {
653
let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
654
GroupsType::Idx(idx) => idx
655
.into_par_iter()
656
.map(|(first, idx)| {
657
let idx = f_idx(idx.iter().copied(), validity, idx.len());
658
(idx.first().copied().unwrap_or(first), idx)
659
})
660
.collect(),
661
GroupsType::Slice {
662
groups,
663
overlapping: _,
664
monotonic: _,
665
} => groups
666
.into_par_iter()
667
.map(|[start, len]| {
668
let idx = f_range(*start..*start + *len, validity, *len as usize);
669
(idx.first().copied().unwrap_or(*start), idx)
670
})
671
.collect(),
672
})
673
.into_sliceable();
674
ac.with_groups(positions);
675
});
676
677
Ok(ac)
678
}
679
680
pub fn forward_fill_null<'a>(
681
inputs: &[Arc<dyn PhysicalExpr>],
682
df: &DataFrame,
683
groups: &'a GroupPositions,
684
state: &ExecutionState,
685
limit: Option<IdxSize>,
686
) -> PolarsResult<AggregationContext<'a>> {
687
let limit = limit.unwrap_or(IdxSize::MAX);
688
macro_rules! arg_forward_fill {
689
(
690
$iter:ident,
691
$validity:ident,
692
$length:ident
693
) => {{
694
|$iter, $validity, $length| {
695
let Some(start) = $iter
696
.clone()
697
.position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
698
else {
699
return $iter.collect();
700
};
701
702
let mut idx = UnitVec::with_capacity($length);
703
let mut iter = $iter;
704
idx.extend((&mut iter).take(start));
705
706
let mut current_limit = limit;
707
let mut value = iter.next().unwrap();
708
idx.push(value);
709
710
idx.extend(iter.map(|i| {
711
if unsafe { $validity.get_bit_unchecked(i as usize) } {
712
current_limit = limit;
713
value = i;
714
i
715
} else if current_limit == 0 {
716
i
717
} else {
718
current_limit -= 1;
719
value
720
}
721
}));
722
idx
723
}
724
}};
725
}
726
727
fw_bw_fill_null(
728
inputs,
729
df,
730
groups,
731
state,
732
arg_forward_fill!(iter, validity, length),
733
arg_forward_fill!(iter, validity, length),
734
)
735
}
736
737
pub fn backward_fill_null<'a>(
738
inputs: &[Arc<dyn PhysicalExpr>],
739
df: &DataFrame,
740
groups: &'a GroupPositions,
741
state: &ExecutionState,
742
limit: Option<IdxSize>,
743
) -> PolarsResult<AggregationContext<'a>> {
744
let limit = limit.unwrap_or(IdxSize::MAX);
745
macro_rules! arg_backward_fill {
746
(
747
$iter:ident,
748
$validity:ident,
749
$length:ident
750
) => {{
751
|$iter, $validity, $length| {
752
let Some(start) = $iter
753
.clone()
754
.rev()
755
.position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
756
else {
757
return $iter.collect();
758
};
759
760
let mut idx = UnitVec::from_iter($iter);
761
let mut current_limit = limit;
762
let mut value = idx[$length - start - 1];
763
for i in idx[..$length - start].iter_mut().rev() {
764
if unsafe { $validity.get_bit_unchecked(*i as usize) } {
765
current_limit = limit;
766
value = *i;
767
} else if current_limit != 0 {
768
current_limit -= 1;
769
*i = value;
770
}
771
}
772
773
idx
774
}
775
}};
776
}
777
778
fw_bw_fill_null(
779
inputs,
780
df,
781
groups,
782
state,
783
arg_backward_fill!(iter, validity, length),
784
arg_backward_fill!(iter, validity, length),
785
)
786
}
787
788