Path: blob/main/crates/polars-ops/src/series/ops/arg_min_max.rs
6939 views
use argminmax::ArgMinMax;1use arrow::array::Array;2use polars_core::chunked_array::ops::float_sorted_arg_max::{3float_arg_max_sorted_ascending, float_arg_max_sorted_descending,4};5use polars_core::series::IsSorted;6use polars_core::with_match_categorical_physical_type;78use super::*;910/// Argmin/ Argmax11pub trait ArgAgg {12/// Get the index of the minimal value13fn arg_min(&self) -> Option<usize>;14/// Get the index of the maximal value15fn arg_max(&self) -> Option<usize>;16}1718macro_rules! with_match_physical_numeric_polars_type {(19$key_type:expr, | $_:tt $T:ident | $($body:tt)*20) => ({21macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}22use DataType::*;23match $key_type {24#[cfg(feature = "dtype-i8")]25Int8 => __with_ty__! { Int8Type },26#[cfg(feature = "dtype-i16")]27Int16 => __with_ty__! { Int16Type },28Int32 => __with_ty__! { Int32Type },29Int64 => __with_ty__! { Int64Type },30#[cfg(feature = "dtype-u8")]31UInt8 => __with_ty__! { UInt8Type },32#[cfg(feature = "dtype-u16")]33UInt16 => __with_ty__! { UInt16Type },34UInt32 => __with_ty__! { UInt32Type },35UInt64 => __with_ty__! { UInt64Type },36Float32 => __with_ty__! { Float32Type },37Float64 => __with_ty__! { Float64Type },38dt => panic!("not implemented for dtype {:?}", dt),39}40})}4142impl ArgAgg for Series {43fn arg_min(&self) -> Option<usize> {44use DataType::*;45let phys_s = self.to_physical_repr();46match self.dtype() {47#[cfg(feature = "dtype-categorical")]48Categorical(cats, _) => {49with_match_categorical_physical_type!(cats.physical(), |$C| {50let ca = self.cat::<$C>().unwrap();51if ca.null_count() == ca.len() {52return None;53}54ca.iter_str()55.enumerate()56.flat_map(|(idx, val)| val.map(|val| (idx, val)))57.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })58.map(|tpl| tpl.0)59})60},61#[cfg(feature = "dtype-categorical")]62Enum(_, _) => phys_s.arg_min(),63Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(),64String => {65let ca = self.str().unwrap();66arg_min_str(ca)67},68Boolean => {69let ca = self.bool().unwrap();70arg_min_bool(ca)71},72dt if dt.is_primitive_numeric() => {73with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {74let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();75arg_min_numeric_dispatch(ca)76})77},78_ => None,79}80}8182fn arg_max(&self) -> Option<usize> {83use DataType::*;84let phys_s = self.to_physical_repr();85match self.dtype() {86#[cfg(feature = "dtype-categorical")]87Categorical(cats, _) => {88with_match_categorical_physical_type!(cats.physical(), |$C| {89let ca = self.cat::<$C>().unwrap();90if ca.null_count() == ca.len() {91return None;92}93ca.iter_str()94.enumerate()95.flat_map(|(idx, val)| val.map(|val| (idx, val)))96.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })97.map(|tpl| tpl.0)98})99},100#[cfg(feature = "dtype-categorical")]101Enum(_, _) => phys_s.arg_max(),102Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(),103String => {104let ca = self.str().unwrap();105arg_max_str(ca)106},107Boolean => {108let ca = self.bool().unwrap();109arg_max_bool(ca)110},111dt if dt.is_primitive_numeric() => {112with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {113let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();114arg_max_numeric_dispatch(ca)115})116},117_ => None,118}119}120}121122fn arg_max_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>123where124T: PolarsNumericType,125for<'b> &'b [T::Native]: ArgMinMax,126{127if ca.null_count() == ca.len() {128None129} else if T::get_static_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {130arg_max_float_sorted(ca)131} else if let Ok(vals) = ca.cont_slice() {132arg_max_numeric_slice(vals, ca.is_sorted_flag())133} else {134arg_max_numeric(ca)135}136}137138fn arg_min_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>139where140T: PolarsNumericType,141for<'b> &'b [T::Native]: ArgMinMax,142{143if ca.null_count() == ca.len() {144None145} else if let Ok(vals) = ca.cont_slice() {146arg_min_numeric_slice(vals, ca.is_sorted_flag())147} else {148arg_min_numeric(ca)149}150}151152fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {153ca.first_true_idx().or_else(|| ca.first_false_idx())154}155156/// # Safety157/// `ca` has a float dtype, has at least one non-null value and is sorted.158fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>159where160T: PolarsNumericType,161{162let out = match ca.is_sorted_flag() {163IsSorted::Ascending => float_arg_max_sorted_ascending(ca),164IsSorted::Descending => float_arg_max_sorted_descending(ca),165_ => unreachable!(),166};167Some(out)168}169170fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {171ca.first_false_idx().or_else(|| ca.first_true_idx())172}173174fn arg_min_str(ca: &StringChunked) -> Option<usize> {175if ca.null_count() == ca.len() {176return None;177}178match ca.is_sorted_flag() {179IsSorted::Ascending => ca.first_non_null(),180IsSorted::Descending => ca.last_non_null(),181IsSorted::Not => ca182.iter()183.enumerate()184.flat_map(|(idx, val)| val.map(|val| (idx, val)))185.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })186.map(|tpl| tpl.0),187}188}189190fn arg_max_str(ca: &StringChunked) -> Option<usize> {191if ca.null_count() == ca.len() {192return None;193}194match ca.is_sorted_flag() {195IsSorted::Ascending => ca.last_non_null(),196IsSorted::Descending => ca.first_non_null(),197IsSorted::Not => ca198.iter()199.enumerate()200.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })201.map(|tpl| tpl.0),202}203}204205fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>206where207T: PolarsNumericType,208for<'b> &'b [T::Native]: ArgMinMax,209{210match ca.is_sorted_flag() {211IsSorted::Ascending => ca.first_non_null(),212IsSorted::Descending => ca.last_non_null(),213IsSorted::Not => {214ca.downcast_iter()215.fold((None, None, 0), |acc, arr| {216if arr.len() == 0 {217return acc;218}219let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {220arr.into_iter()221.enumerate()222.flat_map(|(idx, val)| val.map(|val| (idx, *val)))223.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })224} else {225// When no nulls & array not empty => we can use fast argminmax226let min_idx: usize = arr.values().as_slice().argmin();227Some((min_idx, arr.value(min_idx)))228};229230let new_offset: usize = acc.2 + arr.len();231match acc {232(Some(_), Some(acc_v), offset) => match chunk_min {233Some((idx, val)) if val < acc_v => {234(Some(idx + offset), Some(val), new_offset)235},236_ => (acc.0, acc.1, new_offset),237},238(None, None, offset) => match chunk_min {239Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),240None => (None, None, new_offset),241},242_ => unreachable!(),243}244})245.0246},247}248}249250fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>251where252T: PolarsNumericType,253for<'b> &'b [T::Native]: ArgMinMax,254{255match ca.is_sorted_flag() {256IsSorted::Ascending => ca.last_non_null(),257IsSorted::Descending => ca.first_non_null(),258IsSorted::Not => {259ca.downcast_iter()260.fold((None, None, 0), |acc, arr| {261if arr.len() == 0 {262return acc;263}264let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {265// When there are nulls, we should compare Option<T::Native>266arr.into_iter()267.enumerate()268.flat_map(|(idx, val)| val.map(|val| (idx, *val)))269.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })270} else {271// When no nulls & array not empty => we can use fast argminmax272let max_idx: usize = arr.values().as_slice().argmax();273Some((max_idx, arr.value(max_idx)))274};275276let new_offset: usize = acc.2 + arr.len();277match acc {278(Some(_), Some(acc_v), offset) => match chunk_max {279Some((idx, val)) if acc_v < val => {280(Some(idx + offset), Some(val), new_offset)281},282_ => (acc.0, acc.1, new_offset),283},284(None, None, offset) => match chunk_max {285Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),286None => (None, None, new_offset),287},288_ => unreachable!(),289}290})291.0292},293}294}295296fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>297where298for<'a> &'a [T]: ArgMinMax,299{300match is_sorted {301// all vals are not null guarded by cont_slice302IsSorted::Ascending => Some(0),303// all vals are not null guarded by cont_slice304IsSorted::Descending => Some(vals.len() - 1),305IsSorted::Not => Some(vals.argmin()), // assumes not empty306}307}308309fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>310where311for<'a> &'a [T]: ArgMinMax,312{313match is_sorted {314// all vals are not null guarded by cont_slice315IsSorted::Ascending => Some(vals.len() - 1),316// all vals are not null guarded by cont_slice317IsSorted::Descending => Some(0),318IsSorted::Not => Some(vals.argmax()), // assumes not empty319}320}321322323