Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::array::Array;
3
use arrow::legacy::kernels::take_agg::{
4
take_agg_no_null_primitive_iter_unchecked, take_agg_primitive_iter_unchecked,
5
};
6
use polars_compute::rolling;
7
use polars_compute::rolling::no_nulls::{MaxWindow, MinWindow};
8
use polars_core::frame::group_by::aggregations::{
9
_agg_helper_idx, _agg_helper_slice, _rolling_apply_agg_window_no_nulls,
10
_rolling_apply_agg_window_nulls, _slice_from_offsets, _use_rolling_kernels,
11
};
12
use polars_core::prelude::*;
13
use polars_utils::min_max::MinMax;
14
15
pub fn ca_nan_agg<T, Agg>(ca: &ChunkedArray<T>, min_or_max_fn: Agg) -> Option<T::Native>
16
where
17
T: PolarsFloatType,
18
Agg: Fn(T::Native, T::Native) -> T::Native + Copy,
19
{
20
ca.downcast_iter()
21
.filter_map(|arr| {
22
if arr.null_count() == 0 {
23
arr.values().iter().copied().reduce(min_or_max_fn)
24
} else {
25
arr.iter()
26
.unwrap_optional()
27
.filter_map(|opt| opt.copied())
28
.reduce(min_or_max_fn)
29
}
30
})
31
.reduce(min_or_max_fn)
32
}
33
34
pub fn nan_min_s(s: &Series, name: PlSmallStr) -> Series {
35
match s.dtype() {
36
DataType::Float32 => {
37
let ca = s.f32().unwrap();
38
Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])
39
},
40
DataType::Float64 => {
41
let ca = s.f64().unwrap();
42
Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])
43
},
44
_ => panic!("expected float"),
45
}
46
}
47
48
pub fn nan_max_s(s: &Series, name: PlSmallStr) -> Series {
49
match s.dtype() {
50
DataType::Float32 => {
51
let ca = s.f32().unwrap();
52
Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])
53
},
54
DataType::Float64 => {
55
let ca = s.f64().unwrap();
56
Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])
57
},
58
_ => panic!("expected float"),
59
}
60
}
61
62
unsafe fn group_nan_max<T: PolarsFloatType>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series {
63
match groups {
64
GroupsType::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {
65
debug_assert!(idx.len() <= ca.len());
66
if idx.is_empty() {
67
None
68
} else if idx.len() == 1 {
69
ca.get(first as usize)
70
} else {
71
match (ca.has_nulls(), ca.chunks().len()) {
72
(false, 1) => take_agg_no_null_primitive_iter_unchecked(
73
ca.downcast_iter().next().unwrap(),
74
idx.iter().map(|i| *i as usize),
75
MinMax::max_propagate_nan,
76
),
77
(_, 1) => take_agg_primitive_iter_unchecked(
78
ca.downcast_iter().next().unwrap(),
79
idx.iter().map(|i| *i as usize),
80
MinMax::max_propagate_nan,
81
),
82
_ => {
83
let take = { ca.take_unchecked(idx) };
84
ca_nan_agg(&take, MinMax::max_propagate_nan)
85
},
86
}
87
}
88
}),
89
GroupsType::Slice {
90
groups: groups_slice,
91
..
92
} => {
93
if _use_rolling_kernels(groups_slice, ca.chunks()) {
94
let arr = ca.downcast_iter().next().unwrap();
95
let values = arr.values().as_slice();
96
let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
97
let arr = match arr.validity() {
98
None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _>(
99
values,
100
offset_iter,
101
None,
102
),
103
Some(validity) => _rolling_apply_agg_window_nulls::<
104
rolling::nulls::MaxWindow<_>,
105
_,
106
_,
107
>(values, validity, offset_iter, None),
108
};
109
ChunkedArray::<T>::from(arr).into_series()
110
} else {
111
_agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
112
debug_assert!(len <= ca.len() as IdxSize);
113
match len {
114
0 => None,
115
1 => ca.get(first as usize),
116
_ => {
117
let arr_group = _slice_from_offsets(ca, first, len);
118
ca_nan_agg(&arr_group, MinMax::max_propagate_nan)
119
},
120
}
121
})
122
}
123
},
124
}
125
}
126
127
unsafe fn group_nan_min<T: PolarsFloatType>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series {
128
match groups {
129
GroupsType::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {
130
debug_assert!(idx.len() <= ca.len());
131
if idx.is_empty() {
132
None
133
} else if idx.len() == 1 {
134
ca.get(first as usize)
135
} else {
136
match (ca.has_nulls(), ca.chunks().len()) {
137
(false, 1) => take_agg_no_null_primitive_iter_unchecked(
138
ca.downcast_iter().next().unwrap(),
139
idx.iter().map(|i| *i as usize),
140
MinMax::min_propagate_nan,
141
),
142
(_, 1) => take_agg_primitive_iter_unchecked(
143
ca.downcast_iter().next().unwrap(),
144
idx.iter().map(|i| *i as usize),
145
MinMax::min_propagate_nan,
146
),
147
_ => {
148
let take = { ca.take_unchecked(idx) };
149
ca_nan_agg(&take, MinMax::min_propagate_nan)
150
},
151
}
152
}
153
}),
154
GroupsType::Slice {
155
groups: groups_slice,
156
..
157
} => {
158
if _use_rolling_kernels(groups_slice, ca.chunks()) {
159
let arr = ca.downcast_iter().next().unwrap();
160
let values = arr.values().as_slice();
161
let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
162
let arr = match arr.validity() {
163
None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _>(
164
values,
165
offset_iter,
166
None,
167
),
168
Some(validity) => _rolling_apply_agg_window_nulls::<
169
rolling::nulls::MinWindow<_>,
170
_,
171
_,
172
>(values, validity, offset_iter, None),
173
};
174
ChunkedArray::<T>::from(arr).into_series()
175
} else {
176
_agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
177
debug_assert!(len <= ca.len() as IdxSize);
178
match len {
179
0 => None,
180
1 => ca.get(first as usize),
181
_ => {
182
let arr_group = _slice_from_offsets(ca, first, len);
183
ca_nan_agg(&arr_group, MinMax::min_propagate_nan)
184
},
185
}
186
})
187
}
188
},
189
}
190
}
191
192
/// # Safety
193
/// `groups` must be in bounds.
194
pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsType) -> Series {
195
match s.dtype() {
196
DataType::Float32 => {
197
let ca = s.f32().unwrap();
198
group_nan_min(ca, groups)
199
},
200
DataType::Float64 => {
201
let ca = s.f64().unwrap();
202
group_nan_min(ca, groups)
203
},
204
_ => panic!("expected float"),
205
}
206
}
207
208
/// # Safety
209
/// `groups` must be in bounds.
210
pub unsafe fn group_agg_nan_max_s(s: &Series, groups: &GroupsType) -> Series {
211
match s.dtype() {
212
DataType::Float32 => {
213
let ca = s.f32().unwrap();
214
group_nan_max(ca, groups)
215
},
216
DataType::Float64 => {
217
let ca = s.f64().unwrap();
218
group_nan_max(ca, groups)
219
},
220
_ => panic!("expected float"),
221
}
222
}
223
224