Path: blob/main/crates/polars-expr/src/reduce/min_max_by.rs
8424 views
#![allow(unsafe_op_in_unsafe_fn)]1use std::borrow::Cow;2use std::marker::PhantomData;34use num_traits::Bounded;5use polars_core::chunked_array::arg_min_max::{6arg_max_binary, arg_max_bool, arg_max_numeric, arg_min_binary, arg_min_bool, arg_min_numeric,7};8use polars_core::with_match_physical_integer_polars_type;9use polars_utils::arg_min_max::ArgMinMax;10use polars_utils::float::IsFloat;11use polars_utils::min_max::MinMax;1213use super::*;14use crate::reduce::first_last::new_last_reduction;1516pub fn new_min_by_reduction(17dtype: DataType,18by_dtype: DataType,19) -> PolarsResult<Box<dyn GroupedReduction>> {20// TODO: Move the error checks up and make this function infallible21use DataType::*;22use SelectPayloadGroupedReduction as SPGR;23let payload = new_last_reduction(dtype.clone());24Ok(match &by_dtype {25Boolean => Box::new(SPGR::new(by_dtype, BooleanMinSelector, payload)),26#[cfg(all(feature = "dtype-f16", feature = "propagate_nans"))]27#[cfg(feature = "dtype-f16")]28Float16 => Box::new(SPGR::new(29by_dtype,30MinSelector::<Float16Type>(PhantomData),31payload,32)),33Float32 => Box::new(SPGR::new(34by_dtype,35MinSelector::<Float32Type>(PhantomData),36payload,37)),38Float64 => Box::new(SPGR::new(39by_dtype,40MinSelector::<Float64Type>(PhantomData),41payload,42)),43Null => Box::new(NullGroupedReduction::new(Scalar::null(dtype))),44String | Binary => Box::new(SPGR::new(by_dtype, BinaryMinSelector, payload)),45_ if by_dtype.is_integer() || by_dtype.is_temporal() || by_dtype.is_enum() => {46with_match_physical_integer_polars_type!(by_dtype.to_physical(), |$T| {47Box::new(SPGR::new(by_dtype, MinSelector::<$T>(PhantomData), payload))48})49},50#[cfg(feature = "dtype-decimal")]51Decimal(_, _) => Box::new(SPGR::new(52by_dtype,53MinSelector::<Int128Type>(PhantomData),54payload,55)),56#[cfg(feature = "dtype-categorical")]57Categorical(cats, map) => with_match_categorical_physical_type!(cats.physical(), |$C| {58let map = map.clone();59Box::new(SPGR::new(by_dtype, CatMinSelector::<$C>(map, PhantomData), payload))60}),61_ => {62polars_bail!(InvalidOperation: "`min_by` operation not supported for by dtype `{by_dtype}`")63},64})65}6667pub fn new_max_by_reduction(68dtype: DataType,69by_dtype: DataType,70) -> PolarsResult<Box<dyn GroupedReduction>> {71// TODO: Move the error checks up and make this function infallible72use DataType::*;73use SelectPayloadGroupedReduction as SPGR;74let payload = new_last_reduction(dtype.clone());75Ok(match &by_dtype {76Boolean => Box::new(SPGR::new(by_dtype, BooleanMaxSelector, payload)),77#[cfg(all(feature = "dtype-f16", feature = "propagate_nans"))]78#[cfg(feature = "dtype-f16")]79Float16 => Box::new(SPGR::new(80by_dtype,81MaxSelector::<Float16Type>(PhantomData),82payload,83)),84Float32 => Box::new(SPGR::new(85by_dtype,86MaxSelector::<Float32Type>(PhantomData),87payload,88)),89Float64 => Box::new(SPGR::new(90by_dtype,91MaxSelector::<Float64Type>(PhantomData),92payload,93)),94Null => Box::new(NullGroupedReduction::new(Scalar::null(dtype))),95String | Binary => Box::new(SPGR::new(by_dtype, BinaryMaxSelector, payload)),96_ if by_dtype.is_integer() || by_dtype.is_temporal() || by_dtype.is_enum() => {97with_match_physical_integer_polars_type!(by_dtype.to_physical(), |$T| {98Box::new(SPGR::new(by_dtype, MaxSelector::<$T>(PhantomData), payload))99})100},101#[cfg(feature = "dtype-decimal")]102Decimal(_, _) => Box::new(SPGR::new(103by_dtype,104MaxSelector::<Int128Type>(PhantomData),105payload,106)),107#[cfg(feature = "dtype-categorical")]108Categorical(cats, map) => with_match_categorical_physical_type!(cats.physical(), |$C| {109let map = map.clone();110Box::new(SPGR::new(by_dtype, CatMaxSelector::<$C>(map, PhantomData), payload))111}),112_ => {113polars_bail!(InvalidOperation: "`max_by` operation not supported for by dtype `{by_dtype}`")114},115})116}117118trait SelectReducer: Clone + Send + Sync + 'static {119type Value: Clone + Send + Sync + 'static;120type Dtype: PolarsPhysicalType;121122fn init(&self) -> Self::Value;123124fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {125Cow::Borrowed(s)126}127128fn select_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>) -> Option<usize>;129130fn select_one(131&self,132a: &mut Self::Value,133b: <Self::Dtype as PolarsDataType>::Physical<'_>,134) -> bool;135136fn select_combine(&self, a: &mut Self::Value, b: &Self::Value) -> bool;137}138139struct MinSelector<T>(PhantomData<T>);140struct MaxSelector<T>(PhantomData<T>);141142impl<T> Clone for MinSelector<T> {143fn clone(&self) -> Self {144Self(PhantomData)145}146}147148impl<T> SelectReducer for MinSelector<T>149where150T: PolarsNumericType,151ChunkedArray<T>: ChunkAgg<T::Native>,152for<'b> &'b [T::Native]: ArgMinMax,153{154type Value = T::Native;155type Dtype = T;156157fn init(&self) -> Self::Value {158if T::Native::is_float() {159T::Native::nan_value()160} else {161T::Native::max_value()162}163}164165fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {166s.to_physical_repr()167}168169fn select_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>) -> Option<usize> {170arg_min_numeric(ca).filter(|idx| {171let val = unsafe { ca.value_unchecked(*idx) };172self.select_one(v, val)173})174}175176fn select_one(177&self,178a: &mut Self::Value,179b: <Self::Dtype as PolarsDataType>::Physical<'_>,180) -> bool {181let better = b.nan_max_lt(a);182if better {183*a = b;184}185better186}187188fn select_combine(&self, a: &mut Self::Value, b: &Self::Value) -> bool {189self.select_one(a, *b)190}191}192193impl<T> Clone for MaxSelector<T> {194fn clone(&self) -> Self {195Self(PhantomData)196}197}198199impl<T> SelectReducer for MaxSelector<T>200where201T: PolarsNumericType,202ChunkedArray<T>: ChunkAgg<T::Native>,203for<'b> &'b [T::Native]: ArgMinMax,204{205type Value = T::Native;206type Dtype = T;207208fn init(&self) -> Self::Value {209if T::Native::is_float() {210T::Native::nan_value()211} else {212T::Native::min_value()213}214}215216fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {217s.to_physical_repr()218}219220fn select_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>) -> Option<usize> {221arg_max_numeric(ca).filter(|idx| {222let val = unsafe { ca.value_unchecked(*idx) };223self.select_one(v, val)224})225}226227fn select_one(228&self,229a: &mut Self::Value,230b: <Self::Dtype as PolarsDataType>::Physical<'_>,231) -> bool {232let better = b.nan_min_gt(a);233if better {234*a = b;235}236better237}238239fn select_combine(&self, a: &mut Self::Value, b: &Self::Value) -> bool {240self.select_one(a, *b)241}242}243244#[derive(Clone)]245struct BinaryMinSelector;246#[derive(Clone)]247struct BinaryMaxSelector;248249impl SelectReducer for BinaryMinSelector {250type Dtype = BinaryType;251type Value = Option<Vec<u8>>;252253fn init(&self) -> Self::Value {254// There's no "maximum string" initializer.255None256}257258#[inline(always)]259fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {260Cow::Owned(s.cast(&DataType::Binary).unwrap())261}262263fn select_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>) -> Option<usize> {264arg_min_binary(ca).filter(|idx| {265let val = unsafe { ca.value_unchecked(*idx) };266self.select_one(v, val)267})268}269270fn select_one(271&self,272a: &mut Self::Value,273b: <Self::Dtype as PolarsDataType>::Physical<'_>,274) -> bool {275if let Some(av) = a {276if b < av.as_slice() {277av.clear();278av.extend_from_slice(b);279true280} else {281false282}283} else {284*a = Some(b.to_vec());285true286}287}288289fn select_combine(&self, a: &mut Self::Value, b: &Self::Value) -> bool {290if let Some(bv) = b {291self.select_one(a, bv)292} else {293false294}295}296}297298impl SelectReducer for BinaryMaxSelector {299type Dtype = BinaryType;300type Value = Vec<u8>;301302fn init(&self) -> Self::Value {303// Empty string is <= any other string, so can initialize max with it.304Vec::new()305}306307#[inline(always)]308fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {309Cow::Owned(s.cast(&DataType::Binary).unwrap())310}311312fn select_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>) -> Option<usize> {313arg_max_binary(ca).filter(|idx| {314let val = unsafe { ca.value_unchecked(*idx) };315self.select_one(v, val)316})317}318319fn select_one(320&self,321a: &mut Self::Value,322b: <Self::Dtype as PolarsDataType>::Physical<'_>,323) -> bool {324let better = b > a.as_slice();325if better {326a.clear();327a.extend_from_slice(b);328}329better330}331332fn select_combine(&self, a: &mut Self::Value, b: &Self::Value) -> bool {333self.select_one(a, b)334}335}336337#[derive(Clone)]338struct BooleanMinSelector;339#[derive(Clone)]340struct BooleanMaxSelector;341342impl SelectReducer for BooleanMinSelector {343type Value = bool;344type Dtype = BooleanType;345346fn init(&self) -> Self::Value {347true348}349350fn select_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>) -> Option<usize> {351arg_min_bool(ca).filter(|idx| {352let val = unsafe { ca.value_unchecked(*idx) };353self.select_one(v, val)354})355}356357fn select_one(358&self,359a: &mut Self::Value,360b: <Self::Dtype as PolarsDataType>::Physical<'_>,361) -> bool {362#[allow(clippy::bool_comparison)]363let better = b < *a;364if better {365*a = b;366}367better368}369370fn select_combine(&self, a: &mut Self::Value, b: &Self::Value) -> bool {371self.select_one(a, *b)372}373}374375impl SelectReducer for BooleanMaxSelector {376type Value = bool;377type Dtype = BooleanType;378379fn init(&self) -> Self::Value {380false381}382383fn select_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>) -> Option<usize> {384arg_max_bool(ca).filter(|idx| {385let val = unsafe { ca.value_unchecked(*idx) };386self.select_one(v, val)387})388}389390fn select_one(391&self,392a: &mut Self::Value,393b: <Self::Dtype as PolarsDataType>::Physical<'_>,394) -> bool {395#[allow(clippy::bool_comparison)]396let better = b > *a;397if better {398*a = b;399}400better401}402403fn select_combine(&self, a: &mut Self::Value, b: &Self::Value) -> bool {404self.select_one(a, *b)405}406}407408#[cfg(feature = "dtype-categorical")]409struct CatMinSelector<T>(Arc<CategoricalMapping>, PhantomData<T>);410411#[cfg(feature = "dtype-categorical")]412impl<T> Clone for CatMinSelector<T> {413fn clone(&self) -> Self {414Self(self.0.clone(), PhantomData)415}416}417418#[cfg(feature = "dtype-categorical")]419impl<T: PolarsCategoricalType> SelectReducer for CatMinSelector<T> {420type Dtype = T::PolarsPhysical;421type Value = T::Native;422423fn init(&self) -> Self::Value {424T::Native::max_value() // Ensures it's invalid, preferring the other value.425}426427#[inline(always)]428fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {429s.to_physical_repr()430}431432fn select_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>) -> Option<usize> {433use polars_core::chunked_array::arg_min_max::arg_min_opt_iter;434let arg_min = arg_min_opt_iter(ca.iter().map(|cat| self.0.cat_to_str(cat?.as_cat())));435arg_min.filter(|idx| {436let val = unsafe { ca.value_unchecked(*idx) };437self.select_one(v, val)438})439}440441fn select_one(442&self,443a: &mut Self::Value,444b: <Self::Dtype as PolarsDataType>::Physical<'_>,445) -> bool {446let Some(b_s) = self.0.cat_to_str(b.as_cat()) else {447return false;448};449let Some(a_s) = self.0.cat_to_str(a.as_cat()) else {450*a = b;451return true;452};453454let better = b_s < a_s;455if better {456*a = b;457}458better459}460461fn select_combine(&self, a: &mut Self::Value, b: &Self::Value) -> bool {462self.select_one(a, *b)463}464}465466#[cfg(feature = "dtype-categorical")]467struct CatMaxSelector<T>(Arc<CategoricalMapping>, PhantomData<T>);468469#[cfg(feature = "dtype-categorical")]470impl<T> Clone for CatMaxSelector<T> {471fn clone(&self) -> Self {472Self(self.0.clone(), PhantomData)473}474}475476#[cfg(feature = "dtype-categorical")]477impl<T: PolarsCategoricalType> SelectReducer for CatMaxSelector<T> {478type Dtype = T::PolarsPhysical;479type Value = T::Native;480481fn init(&self) -> Self::Value {482T::Native::max_value() // Ensures it's invalid, preferring the other value.483}484485#[inline(always)]486fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {487s.to_physical_repr()488}489490fn select_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>) -> Option<usize> {491use polars_core::chunked_array::arg_min_max::arg_max_opt_iter;492let arg_max = arg_max_opt_iter(ca.iter().map(|cat| self.0.cat_to_str(cat?.as_cat())));493arg_max.filter(|idx| {494let val = unsafe { ca.value_unchecked(*idx) };495self.select_one(v, val)496})497}498499fn select_one(500&self,501a: &mut Self::Value,502b: <Self::Dtype as PolarsDataType>::Physical<'_>,503) -> bool {504let Some(b_s) = self.0.cat_to_str(b.as_cat()) else {505return false;506};507let Some(a_s) = self.0.cat_to_str(a.as_cat()) else {508*a = b;509return true;510};511512let better = b_s > a_s;513if better {514*a = b;515}516better517}518519fn select_combine(&self, a: &mut Self::Value, b: &Self::Value) -> bool {520self.select_one(a, *b)521}522}523524struct SelectPayloadGroupedReduction<R: SelectReducer> {525values: Vec<R::Value>,526mask: MutableBitmap,527evicted_values: Vec<R::Value>,528evicted_mask: BitmapBuilder,529in_dtype: DataType,530reducer: R,531payload: Box<dyn GroupedReduction>,532533tmp_subset: Vec<IdxSize>,534tmp_group_idxs: Vec<IdxSize>,535}536537impl<R: SelectReducer> SelectPayloadGroupedReduction<R> {538fn new(in_dtype: DataType, reducer: R, payload: Box<dyn GroupedReduction>) -> Self {539Self {540values: Vec::new(),541mask: MutableBitmap::new(),542evicted_values: Vec::new(),543evicted_mask: BitmapBuilder::new(),544in_dtype,545reducer,546payload,547tmp_subset: Vec::new(),548tmp_group_idxs: Vec::new(),549}550}551}552553impl<R> GroupedReduction for SelectPayloadGroupedReduction<R>554where555R: SelectReducer,556{557fn new_empty(&self) -> Box<dyn GroupedReduction> {558Box::new(Self::new(559self.in_dtype.clone(),560self.reducer.clone(),561self.payload.new_empty(),562))563}564565fn reserve(&mut self, additional: usize) {566self.values.reserve(additional);567self.mask.reserve(additional);568self.payload.reserve(additional);569}570571fn resize(&mut self, num_groups: IdxSize) {572self.values.resize(num_groups as usize, self.reducer.init());573self.mask.resize(num_groups as usize, false);574self.payload.resize(num_groups);575}576577fn update_group(578&mut self,579values: &[&Column],580group_idx: IdxSize,581_seq_id: u64,582) -> PolarsResult<()> {583assert!(values.len() == 2);584let payload_values = values[0];585let ord_values = values[1];586assert_eq!(ord_values.dtype(), &self.in_dtype);587588let ord_values = self589.reducer590.cast_series(ord_values.as_materialized_series());591let ca: &ChunkedArray<R::Dtype> = ord_values.as_ref().as_ref().as_ref();592593if let Some(selected) = self594.reducer595.select_ca(&mut self.values[group_idx as usize], ca)596{597self.mask.set(group_idx as usize, true);598let selected_val = payload_values.new_from_index(selected, 1);599self.payload.update_group(&[&selected_val], group_idx, 0)?;600}601602Ok(())603}604605unsafe fn update_groups_while_evicting(606&mut self,607values: &[&Column],608subset: &[IdxSize],609group_idxs: &[EvictIdx],610_seq_id: u64,611) -> PolarsResult<()> {612assert!(values.len() == 2);613let payload_values = values[0];614let ord_values = values[1];615assert!(ord_values.dtype() == &self.in_dtype);616assert!(subset.len() == group_idxs.len());617618// TODO: @scalar-opt619let ord_values = self620.reducer621.cast_series(ord_values.as_materialized_series());622let ca: &ChunkedArray<R::Dtype> = ord_values.as_ref().as_ref().as_ref();623let arr = ca.downcast_as_array();624unsafe {625self.tmp_subset.clear();626self.tmp_group_idxs.clear();627628// SAFETY: indices are in-bounds guaranteed by trait.629for (i, g) in subset.iter().zip(group_idxs) {630let ov = arr.get_unchecked(*i as usize);631let grp = self.values.get_unchecked_mut(g.idx());632if g.should_evict() {633self.evicted_values634.push(core::mem::replace(grp, self.reducer.init()));635self.evicted_mask.push(self.mask.get_unchecked(g.idx()));636self.mask.set_unchecked(g.idx(), ov.is_some());637if let Some(v) = ov {638self.reducer.select_one(grp, v);639}640self.tmp_subset.push(*i);641self.tmp_group_idxs.push(g.0);642} else if let Some(v) = ov {643if self.mask.get_unchecked(g.idx()) {644if self.reducer.select_one(grp, v) {645self.tmp_subset.push(*i);646self.tmp_group_idxs.push(g.0);647}648} else {649self.mask.set_unchecked(g.idx(), true);650self.reducer.select_one(grp, v);651self.tmp_subset.push(*i);652self.tmp_group_idxs.push(g.0);653}654}655}656657self.payload.update_groups_while_evicting(658&[payload_values],659&self.tmp_subset,660EvictIdx::cast_slice(&self.tmp_group_idxs),6610, // seq_id is unused662)?;663}664Ok(())665}666667unsafe fn combine_subset(668&mut self,669other: &dyn GroupedReduction,670subset: &[IdxSize],671group_idxs: &[IdxSize],672) -> PolarsResult<()> {673let other = other.as_any().downcast_ref::<Self>().unwrap();674assert!(self.in_dtype == other.in_dtype);675assert!(subset.len() == group_idxs.len());676unsafe {677self.tmp_subset.clear();678self.tmp_group_idxs.clear();679680// SAFETY: indices are in-bounds guaranteed by trait.681for (i, g) in subset.iter().zip(group_idxs) {682let o = other.mask.get_unchecked(*i as usize);683if o {684let v = other.values.get_unchecked(*i as usize);685let grp = self.values.get_unchecked_mut(*g as usize);686if self.reducer.select_combine(grp, v) | !self.mask.get_unchecked(*g as usize) {687self.tmp_subset.push(*i);688self.tmp_group_idxs.push(*g);689}690self.mask.set_unchecked(*g as usize, true);691}692}693694self.payload.combine_subset(695other.payload.as_ref(),696&self.tmp_subset,697&self.tmp_group_idxs,698)?;699}700Ok(())701}702703fn take_evictions(&mut self) -> Box<dyn GroupedReduction> {704Box::new(Self {705values: core::mem::take(&mut self.evicted_values),706mask: core::mem::take(&mut self.evicted_mask).into_mut(),707evicted_values: Vec::new(),708evicted_mask: BitmapBuilder::new(),709in_dtype: self.in_dtype.clone(),710reducer: self.reducer.clone(),711payload: self.payload.take_evictions(),712tmp_group_idxs: Vec::new(),713tmp_subset: Vec::new(),714})715}716717fn finalize(&mut self) -> PolarsResult<Series> {718let mask = core::mem::take(&mut self.mask);719drop(core::mem::take(&mut self.values));720drop(core::mem::take(&mut self.tmp_group_idxs));721drop(core::mem::take(&mut self.tmp_subset));722723// TODO @ minmax-by: better way to combine payload and mask.724let data = self.payload.finalize()?;725let mca = BooleanChunked::from_bitmap(PlSmallStr::EMPTY, mask.freeze());726let nulls = Series::full_null(data.name().clone(), 1, data.dtype());727data.zip_with(&mca, &nulls)728}729730fn as_any(&self) -> &dyn Any {731self732}733}734735736