Path: blob/main/crates/polars-expr/src/reduce/first_last.rs
8420 views
#![allow(unsafe_op_in_unsafe_fn)]1use std::fmt::Debug;2use std::marker::PhantomData;34use polars_core::frame::row::AnyValueBufferTrusted;5use polars_core::with_match_physical_numeric_polars_type;67use super::*;89pub fn new_first_reduction(dtype: DataType) -> Box<dyn GroupedReduction> {10new_reduction_with_policy(dtype, First)11}1213pub fn new_last_reduction(dtype: DataType) -> Box<dyn GroupedReduction> {14new_reduction_with_policy(dtype, Last)15}1617pub fn new_item_reduction(dtype: DataType, allow_empty: bool) -> Box<dyn GroupedReduction> {18new_reduction_with_policy(dtype, Item { allow_empty })19}2021fn new_reduction_with_policy<P: Policy + 'static>(22dtype: DataType,23policy: P,24) -> Box<dyn GroupedReduction> {25use DataType::*;26use VecGroupedReduction as VGR;27match dtype {28Boolean => Box::new(VecGroupedReduction::new(29dtype,30BoolFirstLastReducer(policy),31)),32_ if dtype.is_primitive_numeric()33|| dtype.is_temporal()34|| dtype.is_decimal()35|| dtype.is_categorical() =>36{37with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| {38Box::new(VGR::new(dtype, NumFirstLastReducer::<_, $T>(policy, PhantomData)))39})40},41String | Binary => Box::new(VecGroupedReduction::new(42dtype,43BinaryFirstLastReducer(policy),44)),45_ => Box::new(GenericFirstLastGroupedReduction::new(dtype, policy)),46}47}4849trait Policy: Copy + Send + Sync + 'static {50type Count: Default + Clone + Copy + Send + Sync + 'static;5152fn add_count(_a: &mut Self::Count, _b: usize) {}53fn combine_count(_a: &mut Self::Count, _b: &Self::Count) {}54fn check_count(_count: Self::Count, _allow_empty: bool) -> PolarsResult<()> {55Ok(())56}5758fn index(self, len: usize) -> usize;59fn should_replace(self, new: u64, old: u64) -> bool;6061#[inline(always)]62fn item_policy(self) -> Option<bool> {63None64}65}6667#[derive(Clone, Copy)]68pub struct First;69impl Policy for First {70type Count = ();7172fn index(self, _len: usize) -> usize {73074}7576fn should_replace(self, new: u64, old: u64) -> bool {77// Subtracting 1 with wrapping leaves all order unchanged, except it78// makes 0 (no value) the largest possible.79new.wrapping_sub(1) < old.wrapping_sub(1)80}81}8283#[derive(Clone, Copy)]84pub struct Last;85impl Policy for Last {86type Count = ();8788fn index(self, len: usize) -> usize {89len - 190}9192fn should_replace(self, new: u64, old: u64) -> bool {93new >= old94}95}9697#[derive(Clone, Copy)]98struct Item {99allow_empty: bool,100}101impl Policy for Item {102type Count = u8;103104fn add_count(a: &mut Self::Count, b: usize) {105*a = a.saturating_add(b.min(255) as u8);106}107108fn combine_count(a: &mut Self::Count, b: &Self::Count) {109*a = a.saturating_add(*b);110}111112fn index(self, _len: usize) -> usize {1130114}115116fn should_replace(self, _new: u64, old: u64) -> bool {117old == 0118}119120fn item_policy(self) -> Option<bool> {121Some(self.allow_empty)122}123124fn check_count(count: Self::Count, allow_empty: bool) -> PolarsResult<()> {125polars_ensure!(126(allow_empty && count == 0) || count == 1,127item_agg_count_not_one = count,128allow_empty = allow_empty129);130Ok(())131}132}133134struct NumFirstLastReducer<P, T>(P, PhantomData<T>);135136#[derive(Clone, Debug, Default)]137struct Value<T, C> {138value: Option<T>,139seq: u64,140count: C,141}142143impl<P: Policy, T> Clone for NumFirstLastReducer<P, T> {144fn clone(&self) -> Self {145Self(self.0, PhantomData)146}147}148149impl<P, T> Reducer for NumFirstLastReducer<P, T>150where151P: Policy,152T: PolarsNumericType,153{154type Dtype = T;155type Value = Value<T::Native, P::Count>;156157fn init(&self) -> Self::Value {158Value::default()159}160161fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {162s.to_physical_repr()163}164165fn combine(&self, a: &mut Self::Value, b: &Self::Value) {166if self.0.should_replace(b.seq, a.seq) {167a.value = b.value;168a.seq = b.seq;169}170P::combine_count(&mut a.count, &b.count);171}172173fn reduce_one(&self, a: &mut Self::Value, b: Option<T::Native>, seq_id: u64) {174if self.0.should_replace(seq_id, a.seq) {175a.value = b;176a.seq = seq_id;177}178P::add_count(&mut a.count, b.is_some() as usize);179}180181fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>, seq_id: u64) {182if !ca.is_empty() && self.0.should_replace(seq_id, v.seq) {183let val = ca.get(self.0.index(ca.len()));184v.value = val;185v.seq = seq_id;186}187P::add_count(&mut v.count, ca.len());188}189190fn finish(191&self,192v: Vec<Self::Value>,193m: Option<Bitmap>,194dtype: &DataType,195) -> PolarsResult<Series> {196assert!(m.is_none()); // This should only be used with VecGroupedReduction.197if let Some(allow_empty) = self.0.item_policy() {198check_item_count_is_one::<_, P>(&v, allow_empty)?;199}200let ca: ChunkedArray<T> = v201.into_iter()202.map(|red_val| red_val.value)203.collect_ca(PlSmallStr::EMPTY);204let s = ca.into_series();205unsafe { s.from_physical_unchecked(dtype) }206}207}208209struct BinaryFirstLastReducer<P>(P);210211impl<P: Policy> Clone for BinaryFirstLastReducer<P> {212fn clone(&self) -> Self {213Self(self.0)214}215}216217pub fn replace_opt_bytes(l: &mut Option<Vec<u8>>, r: Option<&[u8]>) {218match (l, r) {219(Some(l), Some(r)) => {220l.clear();221l.extend_from_slice(r);222},223(l, r) => *l = r.map(|s| s.to_owned()),224}225}226227impl<P> Reducer for BinaryFirstLastReducer<P>228where229P: Policy,230{231type Dtype = BinaryType;232type Value = Value<Vec<u8>, P::Count>;233234fn init(&self) -> Self::Value {235Value::default()236}237238fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {239Cow::Owned(s.cast(&DataType::Binary).unwrap())240}241242fn combine(&self, a: &mut Self::Value, b: &Self::Value) {243if self.0.should_replace(b.seq, a.seq) {244a.value.clone_from(&b.value);245a.seq = b.seq;246}247P::combine_count(&mut a.count, &b.count);248}249250fn reduce_one(&self, a: &mut Self::Value, b: Option<&[u8]>, seq_id: u64) {251if self.0.should_replace(seq_id, a.seq) {252replace_opt_bytes(&mut a.value, b);253a.seq = seq_id;254}255P::add_count(&mut a.count, b.is_some() as usize);256}257258fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>, seq_id: u64) {259if !ca.is_empty() && self.0.should_replace(seq_id, v.seq) {260replace_opt_bytes(&mut v.value, ca.get(self.0.index(ca.len())));261v.seq = seq_id;262}263P::add_count(&mut v.count, ca.len());264}265266fn finish(267&self,268v: Vec<Self::Value>,269m: Option<Bitmap>,270dtype: &DataType,271) -> PolarsResult<Series> {272assert!(m.is_none()); // This should only be used with VecGroupedReduction.273if let Some(allow_empty) = self.0.item_policy() {274check_item_count_is_one::<_, P>(&v, allow_empty)?;275}276let ca: BinaryChunked = v277.into_iter()278.map(|Value { value, .. }| value)279.collect_ca(PlSmallStr::EMPTY);280ca.into_series().cast(dtype)281}282}283284#[derive(Clone)]285struct BoolFirstLastReducer<P: Policy>(P);286287impl<P> Reducer for BoolFirstLastReducer<P>288where289P: Policy,290{291type Dtype = BooleanType;292type Value = Value<bool, P::Count>;293294fn init(&self) -> Self::Value {295Value::default()296}297298fn combine(&self, a: &mut Self::Value, b: &Self::Value) {299if self.0.should_replace(b.seq, a.seq) {300a.value = b.value;301a.seq = b.seq;302}303P::combine_count(&mut a.count, &b.count);304}305306fn reduce_one(&self, a: &mut Self::Value, b: Option<bool>, seq_id: u64) {307if self.0.should_replace(seq_id, a.seq) {308a.value = b;309a.seq = seq_id;310}311P::add_count(&mut a.count, b.is_some() as usize);312}313314fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>, seq_id: u64) {315if !ca.is_empty() && self.0.should_replace(seq_id, v.seq) {316v.value = ca.get(self.0.index(ca.len()));317v.seq = seq_id;318}319P::add_count(&mut v.count, ca.len());320}321322fn finish(323&self,324v: Vec<Self::Value>,325m: Option<Bitmap>,326_dtype: &DataType,327) -> PolarsResult<Series> {328assert!(m.is_none()); // This should only be used with VecGroupedReduction.329if let Some(allow_empty) = self.0.item_policy() {330check_item_count_is_one::<_, P>(&v, allow_empty)?;331}332let ca: BooleanChunked = v333.into_iter()334.map(|Value { value, .. }| value)335.collect_ca(PlSmallStr::EMPTY);336Ok(ca.into_series())337}338}339340struct GenericFirstLastGroupedReduction<P: Policy> {341in_dtype: DataType,342policy: P,343values: Vec<AnyValue<'static>>,344seqs: Vec<u64>,345counts: Vec<P::Count>,346evicted_values: Vec<AnyValue<'static>>,347evicted_seqs: Vec<u64>,348evicted_counts: Vec<P::Count>,349}350351impl<P: Policy> GenericFirstLastGroupedReduction<P> {352fn new(in_dtype: DataType, policy: P) -> Self {353Self {354in_dtype,355policy,356values: Vec::new(),357seqs: Vec::new(),358counts: Vec::new(),359evicted_values: Vec::new(),360evicted_seqs: Vec::new(),361evicted_counts: Vec::new(),362}363}364}365366impl<P: Policy + 'static> GroupedReduction for GenericFirstLastGroupedReduction<P> {367fn new_empty(&self) -> Box<dyn GroupedReduction> {368Box::new(Self::new(self.in_dtype.clone(), self.policy))369}370371fn reserve(&mut self, additional: usize) {372self.values.reserve(additional);373self.seqs.reserve(additional);374self.counts.reserve(additional);375}376377fn resize(&mut self, num_groups: IdxSize) {378self.values.resize(num_groups as usize, AnyValue::Null);379self.seqs.resize(num_groups as usize, 0);380self.counts.resize(num_groups as usize, P::Count::default());381}382383fn update_group(384&mut self,385values: &[&Column],386group_idx: IdxSize,387seq_id: u64,388) -> PolarsResult<()> {389let &[values] = values else { unreachable!() };390assert!(values.dtype() == &self.in_dtype);391if !values.is_empty() {392let seq_id = seq_id + 1; // We use 0 for 'no value'.393if self394.policy395.should_replace(seq_id, self.seqs[group_idx as usize])396{397self.values[group_idx as usize] =398values.get(self.policy.index(values.len()))?.into_static();399self.seqs[group_idx as usize] = seq_id;400}401P::add_count(&mut self.counts[group_idx as usize], values.len());402}403Ok(())404}405406unsafe fn update_groups_while_evicting(407&mut self,408values: &[&Column],409subset: &[IdxSize],410group_idxs: &[EvictIdx],411seq_id: u64,412) -> PolarsResult<()> {413let &[values] = values else { unreachable!() };414assert!(values.dtype() == &self.in_dtype);415assert!(subset.len() == group_idxs.len());416let seq_id = seq_id + 1; // We use 0 for 'no value'.417for (i, g) in subset.iter().zip(group_idxs) {418let grp_val = self.values.get_unchecked_mut(g.idx());419let grp_seq = self.seqs.get_unchecked_mut(g.idx());420let grp_count = self.counts.get_unchecked_mut(g.idx());421if g.should_evict() {422self.evicted_values423.push(core::mem::replace(grp_val, AnyValue::Null));424self.evicted_seqs.push(core::mem::replace(grp_seq, 0));425self.evicted_counts.push(core::mem::take(grp_count));426}427if self.policy.should_replace(seq_id, *grp_seq) {428*grp_val = values.get_unchecked(*i as usize).into_static();429*grp_seq = seq_id;430}431P::add_count(self.counts.get_unchecked_mut(g.idx()), 1);432}433Ok(())434}435436unsafe fn combine_subset(437&mut self,438other: &dyn GroupedReduction,439subset: &[IdxSize],440group_idxs: &[IdxSize],441) -> PolarsResult<()> {442let other = other.as_any().downcast_ref::<Self>().unwrap();443assert!(self.in_dtype == other.in_dtype);444assert!(subset.len() == group_idxs.len());445for (i, g) in group_idxs.iter().enumerate() {446let si = *subset.get_unchecked(i) as usize;447if self.policy.should_replace(448*other.seqs.get_unchecked(si),449*self.seqs.get_unchecked(*g as usize),450) {451*self.values.get_unchecked_mut(*g as usize) =452other.values.get_unchecked(si).clone();453*self.seqs.get_unchecked_mut(*g as usize) = *other.seqs.get_unchecked(si);454}455P::combine_count(456self.counts.get_unchecked_mut(*g as usize),457other.counts.get_unchecked(si),458);459}460Ok(())461}462463fn take_evictions(&mut self) -> Box<dyn GroupedReduction> {464Box::new(Self {465in_dtype: self.in_dtype.clone(),466policy: self.policy,467values: core::mem::take(&mut self.evicted_values),468seqs: core::mem::take(&mut self.evicted_seqs),469counts: core::mem::take(&mut self.evicted_counts),470evicted_values: Vec::new(),471evicted_seqs: Vec::new(),472evicted_counts: Vec::new(),473})474}475476fn finalize(&mut self) -> PolarsResult<Series> {477self.seqs.clear();478if let Some(allow_empty) = self.policy.item_policy() {479for count in self.counts.iter() {480P::check_count(*count, allow_empty)?;481}482}483let phys_type = self.in_dtype.to_physical();484let mut buf = AnyValueBufferTrusted::new(&phys_type, self.values.len());485for v in core::mem::take(&mut self.values) {486// SAFETY: v is cast to physical.487unsafe { buf.add_unchecked_owned_physical(&v.to_physical()) };488}489// SAFETY: dtype is valid for series.490unsafe { buf.into_series().from_physical_unchecked(&self.in_dtype) }491}492493fn as_any(&self) -> &dyn Any {494self495}496}497498fn check_item_count_is_one<T, P: Policy>(499values: &[Value<T, P::Count>],500allow_empty: bool,501) -> PolarsResult<()> {502for v in values {503P::check_count(v.count, allow_empty)?;504}505Ok(())506}507508509