Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/arg_min_max.rs
6939 views
1
use argminmax::ArgMinMax;
2
use arrow::array::Array;
3
use polars_core::chunked_array::ops::float_sorted_arg_max::{
4
float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
5
};
6
use polars_core::series::IsSorted;
7
use polars_core::with_match_categorical_physical_type;
8
9
use super::*;
10
11
/// Aggregations that report the *position* of an extreme value rather than
/// the value itself (arg-min / arg-max).
pub trait ArgAgg {
    /// Returns the index of the minimal value, or `None` when the
    /// implementation has no index to report.
    fn arg_min(&self) -> Option<usize>;

    /// Returns the index of the maximal value, or `None` when the
    /// implementation has no index to report.
    fn arg_max(&self) -> Option<usize>;
}
18
19
/// Expands `$body` once for the Polars numeric type that corresponds to the
/// given physical `DataType`.
///
/// Call shape: `with_match_physical_numeric_polars_type!(dtype, |$T| { ... })`.
/// The `$_:tt` fragment captures the literal `$` written at the call site so
/// that the inner helper macro can declare its own `$T` metavariable; inside
/// the body `$T` is then replaced by e.g. `Int32Type`.
///
/// The `Int8`/`Int16`/`UInt8`/`UInt16` arms only exist when the matching
/// `dtype-*` feature is enabled.
///
/// # Panics
/// Panics with "not implemented for dtype ..." for every dtype that has no
/// arm below.
macro_rules! with_match_physical_numeric_polars_type {(
    $dtype:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    // Inner macro: binds the caller-chosen identifier to a concrete type name.
    macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
    use DataType::*;
    match $dtype {
        #[cfg(feature = "dtype-i8")]
        Int8 => __with_ty__! { Int8Type },
        #[cfg(feature = "dtype-i16")]
        Int16 => __with_ty__! { Int16Type },
        Int32 => __with_ty__! { Int32Type },
        Int64 => __with_ty__! { Int64Type },
        #[cfg(feature = "dtype-u8")]
        UInt8 => __with_ty__! { UInt8Type },
        #[cfg(feature = "dtype-u16")]
        UInt16 => __with_ty__! { UInt16Type },
        UInt32 => __with_ty__! { UInt32Type },
        UInt64 => __with_ty__! { UInt64Type },
        Float32 => __with_ty__! { Float32Type },
        Float64 => __with_ty__! { Float64Type },
        unsupported => panic!("not implemented for dtype {:?}", unsupported),
    }
})}
42
43
impl ArgAgg for Series {
44
fn arg_min(&self) -> Option<usize> {
45
use DataType::*;
46
let phys_s = self.to_physical_repr();
47
match self.dtype() {
48
#[cfg(feature = "dtype-categorical")]
49
Categorical(cats, _) => {
50
with_match_categorical_physical_type!(cats.physical(), |$C| {
51
let ca = self.cat::<$C>().unwrap();
52
if ca.null_count() == ca.len() {
53
return None;
54
}
55
ca.iter_str()
56
.enumerate()
57
.flat_map(|(idx, val)| val.map(|val| (idx, val)))
58
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
59
.map(|tpl| tpl.0)
60
})
61
},
62
#[cfg(feature = "dtype-categorical")]
63
Enum(_, _) => phys_s.arg_min(),
64
Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(),
65
String => {
66
let ca = self.str().unwrap();
67
arg_min_str(ca)
68
},
69
Boolean => {
70
let ca = self.bool().unwrap();
71
arg_min_bool(ca)
72
},
73
dt if dt.is_primitive_numeric() => {
74
with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
75
let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
76
arg_min_numeric_dispatch(ca)
77
})
78
},
79
_ => None,
80
}
81
}
82
83
fn arg_max(&self) -> Option<usize> {
84
use DataType::*;
85
let phys_s = self.to_physical_repr();
86
match self.dtype() {
87
#[cfg(feature = "dtype-categorical")]
88
Categorical(cats, _) => {
89
with_match_categorical_physical_type!(cats.physical(), |$C| {
90
let ca = self.cat::<$C>().unwrap();
91
if ca.null_count() == ca.len() {
92
return None;
93
}
94
ca.iter_str()
95
.enumerate()
96
.flat_map(|(idx, val)| val.map(|val| (idx, val)))
97
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
98
.map(|tpl| tpl.0)
99
})
100
},
101
#[cfg(feature = "dtype-categorical")]
102
Enum(_, _) => phys_s.arg_max(),
103
Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(),
104
String => {
105
let ca = self.str().unwrap();
106
arg_max_str(ca)
107
},
108
Boolean => {
109
let ca = self.bool().unwrap();
110
arg_max_bool(ca)
111
},
112
dt if dt.is_primitive_numeric() => {
113
with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
114
let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
115
arg_max_numeric_dispatch(ca)
116
})
117
},
118
_ => None,
119
}
120
}
121
}
122
123
fn arg_max_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
124
where
125
T: PolarsNumericType,
126
for<'b> &'b [T::Native]: ArgMinMax,
127
{
128
if ca.null_count() == ca.len() {
129
None
130
} else if T::get_static_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
131
arg_max_float_sorted(ca)
132
} else if let Ok(vals) = ca.cont_slice() {
133
arg_max_numeric_slice(vals, ca.is_sorted_flag())
134
} else {
135
arg_max_numeric(ca)
136
}
137
}
138
139
fn arg_min_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
140
where
141
T: PolarsNumericType,
142
for<'b> &'b [T::Native]: ArgMinMax,
143
{
144
if ca.null_count() == ca.len() {
145
None
146
} else if let Ok(vals) = ca.cont_slice() {
147
arg_min_numeric_slice(vals, ca.is_sorted_flag())
148
} else {
149
arg_min_numeric(ca)
150
}
151
}
152
153
fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
154
ca.first_true_idx().or_else(|| ca.first_false_idx())
155
}
156
157
/// # Safety
158
/// `ca` has a float dtype, has at least one non-null value and is sorted.
159
fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
160
where
161
T: PolarsNumericType,
162
{
163
let out = match ca.is_sorted_flag() {
164
IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
165
IsSorted::Descending => float_arg_max_sorted_descending(ca),
166
_ => unreachable!(),
167
};
168
Some(out)
169
}
170
171
fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
172
ca.first_false_idx().or_else(|| ca.first_true_idx())
173
}
174
175
fn arg_min_str(ca: &StringChunked) -> Option<usize> {
176
if ca.null_count() == ca.len() {
177
return None;
178
}
179
match ca.is_sorted_flag() {
180
IsSorted::Ascending => ca.first_non_null(),
181
IsSorted::Descending => ca.last_non_null(),
182
IsSorted::Not => ca
183
.iter()
184
.enumerate()
185
.flat_map(|(idx, val)| val.map(|val| (idx, val)))
186
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
187
.map(|tpl| tpl.0),
188
}
189
}
190
191
fn arg_max_str(ca: &StringChunked) -> Option<usize> {
192
if ca.null_count() == ca.len() {
193
return None;
194
}
195
match ca.is_sorted_flag() {
196
IsSorted::Ascending => ca.last_non_null(),
197
IsSorted::Descending => ca.first_non_null(),
198
IsSorted::Not => ca
199
.iter()
200
.enumerate()
201
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
202
.map(|tpl| tpl.0),
203
}
204
}
205
206
fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
207
where
208
T: PolarsNumericType,
209
for<'b> &'b [T::Native]: ArgMinMax,
210
{
211
match ca.is_sorted_flag() {
212
IsSorted::Ascending => ca.first_non_null(),
213
IsSorted::Descending => ca.last_non_null(),
214
IsSorted::Not => {
215
ca.downcast_iter()
216
.fold((None, None, 0), |acc, arr| {
217
if arr.len() == 0 {
218
return acc;
219
}
220
let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
221
arr.into_iter()
222
.enumerate()
223
.flat_map(|(idx, val)| val.map(|val| (idx, *val)))
224
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
225
} else {
226
// When no nulls & array not empty => we can use fast argminmax
227
let min_idx: usize = arr.values().as_slice().argmin();
228
Some((min_idx, arr.value(min_idx)))
229
};
230
231
let new_offset: usize = acc.2 + arr.len();
232
match acc {
233
(Some(_), Some(acc_v), offset) => match chunk_min {
234
Some((idx, val)) if val < acc_v => {
235
(Some(idx + offset), Some(val), new_offset)
236
},
237
_ => (acc.0, acc.1, new_offset),
238
},
239
(None, None, offset) => match chunk_min {
240
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
241
None => (None, None, new_offset),
242
},
243
_ => unreachable!(),
244
}
245
})
246
.0
247
},
248
}
249
}
250
251
fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
252
where
253
T: PolarsNumericType,
254
for<'b> &'b [T::Native]: ArgMinMax,
255
{
256
match ca.is_sorted_flag() {
257
IsSorted::Ascending => ca.last_non_null(),
258
IsSorted::Descending => ca.first_non_null(),
259
IsSorted::Not => {
260
ca.downcast_iter()
261
.fold((None, None, 0), |acc, arr| {
262
if arr.len() == 0 {
263
return acc;
264
}
265
let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
266
// When there are nulls, we should compare Option<T::Native>
267
arr.into_iter()
268
.enumerate()
269
.flat_map(|(idx, val)| val.map(|val| (idx, *val)))
270
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
271
} else {
272
// When no nulls & array not empty => we can use fast argminmax
273
let max_idx: usize = arr.values().as_slice().argmax();
274
Some((max_idx, arr.value(max_idx)))
275
};
276
277
let new_offset: usize = acc.2 + arr.len();
278
match acc {
279
(Some(_), Some(acc_v), offset) => match chunk_max {
280
Some((idx, val)) if acc_v < val => {
281
(Some(idx + offset), Some(val), new_offset)
282
},
283
_ => (acc.0, acc.1, new_offset),
284
},
285
(None, None, offset) => match chunk_max {
286
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
287
None => (None, None, new_offset),
288
},
289
_ => unreachable!(),
290
}
291
})
292
.0
293
},
294
}
295
}
296
297
fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
298
where
299
for<'a> &'a [T]: ArgMinMax,
300
{
301
match is_sorted {
302
// all vals are not null guarded by cont_slice
303
IsSorted::Ascending => Some(0),
304
// all vals are not null guarded by cont_slice
305
IsSorted::Descending => Some(vals.len() - 1),
306
IsSorted::Not => Some(vals.argmin()), // assumes not empty
307
}
308
}
309
310
fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
311
where
312
for<'a> &'a [T]: ArgMinMax,
313
{
314
match is_sorted {
315
// all vals are not null guarded by cont_slice
316
IsSorted::Ascending => Some(vals.len() - 1),
317
// all vals are not null guarded by cont_slice
318
IsSorted::Descending => Some(0),
319
IsSorted::Not => Some(vals.argmax()), // assumes not empty
320
}
321
}
322
323