Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-utils/src/arg_min_max.rs
7884 views
1
use crate::float16::pf16;
2
3
/// Locates extreme values of a collection by position.
///
/// Both methods report an index into the collection rather than the value
/// itself. How ties are broken and what is returned for empty input are left
/// to each implementation.
pub trait ArgMinMax {
    /// Returns the index of a minimum element.
    fn argmin(&self) -> usize;

    /// Returns the index of a maximum element.
    fn argmax(&self) -> usize;
}
8
9
// Implements this module's `ArgMinMax` for `$slice` by forwarding both
// methods to the `argminmax` crate's trait of the same name.
// `$slice` must be a type the `argminmax` crate implements its trait for
// (every invocation in this module passes a slice reference such as `&[u8]`).
macro_rules! impl_argminmax {
    ($slice:ty) => {
        impl ArgMinMax for $slice {
            fn argmin(&self) -> usize {
                // Qualified call: a plain `self.argmin()` would resolve to this
                // very impl and recurse instead of reaching the external crate.
                <Self as argminmax::ArgMinMax>::argmin(self)
            }

            fn argmax(&self) -> usize {
                <Self as argminmax::ArgMinMax>::argmax(self)
            }
        }
    };
}
22
23
impl_argminmax!(&[u8]);
24
impl_argminmax!(&[u16]);
25
impl_argminmax!(&[u32]);
26
impl_argminmax!(&[u64]);
27
impl_argminmax!(&[i8]);
28
impl_argminmax!(&[i16]);
29
impl_argminmax!(&[i32]);
30
impl_argminmax!(&[i64]);
31
impl_argminmax!(&[f32]);
32
impl_argminmax!(&[f64]);
33
34
impl ArgMinMax for &[i128] {
35
fn argmin(&self) -> usize {
36
let mut min_val = i128::MAX;
37
let mut min_idx = 0;
38
for (idx, val) in self.iter().enumerate() {
39
if *val < min_val {
40
min_val = *val;
41
min_idx = idx;
42
}
43
}
44
min_idx
45
}
46
47
fn argmax(&self) -> usize {
48
let mut max_val = i128::MIN;
49
let mut max_idx = 0;
50
for (idx, val) in self.iter().enumerate() {
51
if *val > max_val {
52
max_val = *val;
53
max_idx = idx;
54
}
55
}
56
max_idx
57
}
58
}
59
60
impl ArgMinMax for &[u128] {
61
fn argmin(&self) -> usize {
62
let mut min_val = u128::MAX;
63
let mut min_idx = 0;
64
for (idx, val) in self.iter().enumerate() {
65
if *val < min_val {
66
min_val = *val;
67
min_idx = idx;
68
}
69
}
70
min_idx
71
}
72
73
fn argmax(&self) -> usize {
74
let mut max_val = u128::MIN;
75
let mut max_idx = 0;
76
for (idx, val) in self.iter().enumerate() {
77
if *val > max_val {
78
max_val = *val;
79
max_idx = idx;
80
}
81
}
82
max_idx
83
}
84
}
85
86
impl ArgMinMax for &[pf16] {
87
fn argmin(&self) -> usize {
88
let transmuted: &&[half::f16] = unsafe { std::mem::transmute(self) };
89
argminmax::ArgMinMax::argmin(transmuted)
90
}
91
92
fn argmax(&self) -> usize {
93
let transmuted: &&[half::f16] = unsafe { std::mem::transmute(self) };
94
argminmax::ArgMinMax::argmax(transmuted)
95
}
96
}
97
98