Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs
8420 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::array::Array;
3
use arrow::legacy::kernels::take_agg::{
4
take_agg_no_null_primitive_iter_unchecked, take_agg_primitive_iter_unchecked,
5
};
6
use polars_compute::rolling;
7
use polars_compute::rolling::no_nulls::{MaxWindow, MinWindow};
8
use polars_core::frame::group_by::aggregations::{
9
_agg_helper_idx, _agg_helper_slice, _rolling_apply_agg_window_no_nulls,
10
_rolling_apply_agg_window_nulls, _slice_from_offsets, _use_rolling_kernels,
11
};
12
use polars_core::prelude::*;
13
use polars_utils::min_max::MinMax;
14
15
pub fn ca_nan_agg<T, Agg>(ca: &ChunkedArray<T>, min_or_max_fn: Agg) -> Option<T::Native>
16
where
17
T: PolarsFloatType,
18
Agg: Fn(T::Native, T::Native) -> T::Native + Copy,
19
{
20
ca.downcast_iter()
21
.filter_map(|arr| {
22
if arr.null_count() == 0 {
23
arr.values().iter().copied().reduce(min_or_max_fn)
24
} else {
25
arr.iter()
26
.unwrap_optional()
27
.filter_map(|opt| opt.copied())
28
.reduce(min_or_max_fn)
29
}
30
})
31
.reduce(min_or_max_fn)
32
}
33
34
pub fn nan_min_s(s: &Series, name: PlSmallStr) -> Series {
35
match s.dtype() {
36
#[cfg(feature = "dtype-f16")]
37
DataType::Float16 => {
38
let ca = s.f16().unwrap();
39
Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])
40
},
41
DataType::Float32 => {
42
let ca = s.f32().unwrap();
43
Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])
44
},
45
DataType::Float64 => {
46
let ca = s.f64().unwrap();
47
Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])
48
},
49
_ => panic!("expected float"),
50
}
51
}
52
53
pub fn nan_max_s(s: &Series, name: PlSmallStr) -> Series {
54
match s.dtype() {
55
#[cfg(feature = "dtype-f16")]
56
DataType::Float16 => {
57
let ca = s.f16().unwrap();
58
Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])
59
},
60
DataType::Float32 => {
61
let ca = s.f32().unwrap();
62
Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])
63
},
64
DataType::Float64 => {
65
let ca = s.f64().unwrap();
66
Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])
67
},
68
_ => panic!("expected float"),
69
}
70
}
71
72
unsafe fn group_nan_max<T: PolarsFloatType>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series {
73
match groups {
74
GroupsType::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {
75
debug_assert!(idx.len() <= ca.len());
76
if idx.is_empty() {
77
None
78
} else if idx.len() == 1 {
79
ca.get(first as usize)
80
} else {
81
match (ca.has_nulls(), ca.chunks().len()) {
82
(false, 1) => take_agg_no_null_primitive_iter_unchecked(
83
ca.downcast_iter().next().unwrap(),
84
idx.iter().map(|i| *i as usize),
85
)
86
.reduce(MinMax::max_propagate_nan),
87
(_, 1) => take_agg_primitive_iter_unchecked(
88
ca.downcast_iter().next().unwrap(),
89
idx.iter().map(|i| *i as usize),
90
)
91
.reduce(MinMax::max_propagate_nan),
92
_ => {
93
let take = { ca.take_unchecked(idx) };
94
ca_nan_agg(&take, MinMax::max_propagate_nan)
95
},
96
}
97
}
98
}),
99
GroupsType::Slice {
100
groups: groups_slice,
101
overlapping,
102
monotonic,
103
} => {
104
if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, ca.chunks()) {
105
let arr = ca.downcast_iter().next().unwrap();
106
let values = arr.values().as_slice();
107
let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
108
let arr = match arr.validity() {
109
None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _, _>(
110
values,
111
offset_iter,
112
None,
113
),
114
Some(validity) => _rolling_apply_agg_window_nulls::<
115
rolling::nulls::MaxWindow<_>,
116
_,
117
_,
118
_,
119
>(values, validity, offset_iter, None),
120
};
121
ChunkedArray::<T>::from(arr).into_series()
122
} else {
123
_agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
124
debug_assert!(len <= ca.len() as IdxSize);
125
match len {
126
0 => None,
127
1 => ca.get(first as usize),
128
_ => {
129
let arr_group = _slice_from_offsets(ca, first, len);
130
ca_nan_agg(&arr_group, MinMax::max_propagate_nan)
131
},
132
}
133
})
134
}
135
},
136
}
137
}
138
139
unsafe fn group_nan_min<T: PolarsFloatType>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series {
140
match groups {
141
GroupsType::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {
142
debug_assert!(idx.len() <= ca.len());
143
if idx.is_empty() {
144
None
145
} else if idx.len() == 1 {
146
ca.get(first as usize)
147
} else {
148
match (ca.has_nulls(), ca.chunks().len()) {
149
(false, 1) => take_agg_no_null_primitive_iter_unchecked(
150
ca.downcast_iter().next().unwrap(),
151
idx.iter().map(|i| *i as usize),
152
)
153
.reduce(MinMax::min_propagate_nan),
154
(_, 1) => take_agg_primitive_iter_unchecked(
155
ca.downcast_iter().next().unwrap(),
156
idx.iter().map(|i| *i as usize),
157
)
158
.reduce(MinMax::min_propagate_nan),
159
_ => {
160
let take = { ca.take_unchecked(idx) };
161
ca_nan_agg(&take, MinMax::min_propagate_nan)
162
},
163
}
164
}
165
}),
166
GroupsType::Slice {
167
groups: groups_slice,
168
overlapping,
169
monotonic,
170
} => {
171
if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, ca.chunks()) {
172
let arr = ca.downcast_iter().next().unwrap();
173
let values = arr.values().as_slice();
174
let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
175
let arr = match arr.validity() {
176
None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _, _>(
177
values,
178
offset_iter,
179
None,
180
),
181
Some(validity) => _rolling_apply_agg_window_nulls::<
182
rolling::nulls::MinWindow<_>,
183
_,
184
_,
185
_,
186
>(values, validity, offset_iter, None),
187
};
188
ChunkedArray::<T>::from(arr).into_series()
189
} else {
190
_agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
191
debug_assert!(len <= ca.len() as IdxSize);
192
match len {
193
0 => None,
194
1 => ca.get(first as usize),
195
_ => {
196
let arr_group = _slice_from_offsets(ca, first, len);
197
ca_nan_agg(&arr_group, MinMax::min_propagate_nan)
198
},
199
}
200
})
201
}
202
},
203
}
204
}
205
206
/// # Safety
207
/// `groups` must be in bounds.
208
pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsType) -> Series {
209
match s.dtype() {
210
#[cfg(feature = "dtype-f16")]
211
DataType::Float16 => {
212
let ca = s.f16().unwrap();
213
group_nan_min(ca, groups)
214
},
215
DataType::Float32 => {
216
let ca = s.f32().unwrap();
217
group_nan_min(ca, groups)
218
},
219
DataType::Float64 => {
220
let ca = s.f64().unwrap();
221
group_nan_min(ca, groups)
222
},
223
_ => panic!("expected float"),
224
}
225
}
226
227
/// # Safety
228
/// `groups` must be in bounds.
229
pub unsafe fn group_agg_nan_max_s(s: &Series, groups: &GroupsType) -> Series {
230
match s.dtype() {
231
#[cfg(feature = "dtype-f16")]
232
DataType::Float16 => {
233
let ca = s.f16().unwrap();
234
group_nan_max(ca, groups)
235
},
236
DataType::Float32 => {
237
let ca = s.f32().unwrap();
238
group_nan_max(ca, groups)
239
},
240
DataType::Float64 => {
241
let ca = s.f64().unwrap();
242
group_nan_max(ca, groups)
243
},
244
_ => panic!("expected float"),
245
}
246
}
247
248