// Path: crates/polars-expr/src/dispatch/groups_dispatch.rs (blob/main)
// Page view count at time of capture: 7884
use std::borrow::Cow;1use std::sync::Arc;23use arrow::array::PrimitiveArray;4use arrow::bitmap::Bitmap;5use arrow::bitmap::bitmask::BitMask;6use arrow::trusted_len::TrustMyLength;7use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};8use polars_core::POOL;9use polars_core::error::{PolarsResult, polars_bail, polars_ensure};10use polars_core::frame::DataFrame;11use polars_core::prelude::row_encode::encode_rows_unordered;12use polars_core::prelude::{13AnyValue, ChunkCast, Column, CompatLevel, Float64Chunked, GroupPositions, GroupsType,14IDX_DTYPE, IntoColumn,15};16use polars_core::scalar::Scalar;17use polars_core::series::{ChunkCompareEq, Series};18use polars_utils::itertools::Itertools;19use polars_utils::pl_str::PlSmallStr;20use polars_utils::{IdxSize, UnitVec};21use rayon::iter::{IntoParallelIterator, ParallelIterator};2223use crate::prelude::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};24use crate::state::ExecutionState;2526pub fn reverse<'a>(27inputs: &[Arc<dyn PhysicalExpr>],28df: &DataFrame,29groups: &'a GroupPositions,30state: &ExecutionState,31) -> PolarsResult<AggregationContext<'a>> {32assert_eq!(inputs.len(), 1);3334let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;3536// Length preserving operation on scalars keeps scalar.37if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &ac.agg_state() {38return Ok(ac);39}4041POOL.install(|| {42let positions = GroupsType::Idx(match &**ac.groups().as_ref() {43GroupsType::Idx(idx) => idx44.into_par_iter()45.map(|(first, idx)| {46(47idx.last().copied().unwrap_or(first),48idx.iter().copied().rev().collect(),49)50})51.collect(),52GroupsType::Slice {53groups,54overlapping: _,55monotonic: _,56} => groups57.into_par_iter()58.map(|[start, len]| {59(60start + len.saturating_sub(1),61(*start..*start + *len).rev().collect(),62)63})64.collect(),65})66.into_sliceable();67ac.with_groups(positions);68});6970Ok(ac)71}7273pub fn null_count<'a>(74inputs: &[Arc<dyn 
PhysicalExpr>],75df: &DataFrame,76groups: &'a GroupPositions,77state: &ExecutionState,78) -> PolarsResult<AggregationContext<'a>> {79assert_eq!(inputs.len(), 1);8081let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;8283if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {84*s = s.is_null().cast(&IDX_DTYPE).unwrap().into_column();85return Ok(ac);86}8788ac.groups();89let values = ac.flat_naive();90let name = values.name().clone();91let Some(validity) = values.rechunk_validity() else {92ac.state = AggState::AggregatedScalar(Column::new_scalar(93name,94(0 as IdxSize).into(),95groups.len(),96));97return Ok(ac);98};99100POOL.install(|| {101let validity = BitMask::from_bitmap(&validity);102let null_count: Vec<IdxSize> = match &**ac.groups.as_ref() {103GroupsType::Idx(idx) => idx104.into_par_iter()105.map(|(_, idx)| {106idx.iter()107.map(|i| IdxSize::from(!unsafe { validity.get_bit_unchecked(*i as usize) }))108.sum::<IdxSize>()109})110.collect(),111GroupsType::Slice {112groups,113overlapping: _,114monotonic: _,115} => groups116.into_par_iter()117.map(|[start, length]| {118unsafe { validity.sliced_unchecked(*start as usize, *length as usize) }119.unset_bits() as IdxSize120})121.collect(),122};123124ac.state = AggState::AggregatedScalar(Column::new(name, null_count));125});126127Ok(ac)128}129130pub fn any<'a>(131inputs: &[Arc<dyn PhysicalExpr>],132df: &DataFrame,133groups: &'a GroupPositions,134state: &ExecutionState,135ignore_nulls: bool,136) -> PolarsResult<AggregationContext<'a>> {137assert_eq!(inputs.len(), 1);138139let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;140141if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {142if ignore_nulls {143*s = s144.equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))145.unwrap()146.into_column();147} else {148*s = s149.equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))150.unwrap()151.into_column();152}153return 
Ok(ac);154}155156ac.groups();157let values = ac.flat_naive();158let values = values.bool()?;159let out = unsafe { values.agg_any(ac.groups.as_ref(), ignore_nulls) };160ac.state = AggState::AggregatedScalar(out.into_column());161162Ok(ac)163}164165pub fn all<'a>(166inputs: &[Arc<dyn PhysicalExpr>],167df: &DataFrame,168groups: &'a GroupPositions,169state: &ExecutionState,170ignore_nulls: bool,171) -> PolarsResult<AggregationContext<'a>> {172assert_eq!(inputs.len(), 1);173174let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;175176if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {177if ignore_nulls {178*s = s179.equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))180.unwrap()181.into_column();182} else {183*s = s184.equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))185.unwrap()186.into_column();187}188return Ok(ac);189}190191ac.groups();192let values = ac.flat_naive();193let values = values.bool()?;194let out = unsafe { values.agg_all(ac.groups.as_ref(), ignore_nulls) };195ac.state = AggState::AggregatedScalar(out.into_column());196197Ok(ac)198}199200#[cfg(feature = "bitwise")]201pub fn bitwise_agg<'a>(202inputs: &[Arc<dyn PhysicalExpr>],203df: &DataFrame,204groups: &'a GroupPositions,205state: &ExecutionState,206op: &'static str,207f: impl Fn(&Column, &GroupsType) -> Column,208) -> PolarsResult<AggregationContext<'a>> {209assert_eq!(inputs.len(), 1);210211let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;212213if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &ac.state {214let dtype = s.dtype();215polars_ensure!(216dtype.is_bool() | dtype.is_primitive_numeric(),217op = op,218dtype219);220return Ok(ac);221}222223ac.groups();224let values = ac.flat_naive();225let out = f(values.as_ref(), ac.groups.as_ref());226ac.state = AggState::AggregatedScalar(out.into_column());227228Ok(ac)229}230231#[cfg(feature = "bitwise")]232pub fn bitwise_and<'a>(233inputs: &[Arc<dyn 
PhysicalExpr>],234df: &DataFrame,235groups: &'a GroupPositions,236state: &ExecutionState,237) -> PolarsResult<AggregationContext<'a>> {238bitwise_agg(239inputs,240df,241groups,242state,243"and_reduce",244|v, groups| unsafe { v.agg_and(groups) },245)246}247248#[cfg(feature = "bitwise")]249pub fn bitwise_or<'a>(250inputs: &[Arc<dyn PhysicalExpr>],251df: &DataFrame,252groups: &'a GroupPositions,253state: &ExecutionState,254) -> PolarsResult<AggregationContext<'a>> {255bitwise_agg(inputs, df, groups, state, "or_reduce", |v, groups| unsafe {256v.agg_or(groups)257})258}259260#[cfg(feature = "bitwise")]261pub fn bitwise_xor<'a>(262inputs: &[Arc<dyn PhysicalExpr>],263df: &DataFrame,264groups: &'a GroupPositions,265state: &ExecutionState,266) -> PolarsResult<AggregationContext<'a>> {267bitwise_agg(268inputs,269df,270groups,271state,272"xor_reduce",273|v, groups| unsafe { v.agg_xor(groups) },274)275}276277pub fn drop_items<'a>(278mut ac: AggregationContext<'a>,279predicate: &Bitmap,280) -> PolarsResult<AggregationContext<'a>> {281// No elements are filtered out.282if predicate.unset_bits() == 0 {283if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {284*c = c.as_list().into_column();285if c.len() == 1 && ac.groups.len() != 1 {286*c = c.new_from_index(0, ac.groups.len());287}288ac.state = AggState::AggregatedList(std::mem::take(c));289ac.update_groups = UpdateGroups::WithSeriesLen;290}291return Ok(ac);292}293294ac.set_original_len(false);295296// All elements are filtered out.297if predicate.set_bits() == 0 {298let name = ac.agg_state().name();299let dtype = ac.agg_state().flat_dtype();300301ac.state = AggState::AggregatedList(Column::new_scalar(302name.clone(),303Scalar::new(304dtype.clone().implode(),305AnyValue::List(Series::new_empty(PlSmallStr::EMPTY, dtype)),306),307ac.groups.len(),308));309ac.with_update_groups(UpdateGroups::WithSeriesLen);310return Ok(ac);311}312313if let AggState::AggregatedScalar(c) = &mut ac.state {314ac.state = 
AggState::NotAggregated(std::mem::take(c));315ac.groups = Cow::Owned(316{317let groups = predicate318.iter()319.enumerate_idx()320.map(|(i, p)| [i, IdxSize::from(p)])321.collect();322GroupsType::new_slice(groups, false, true)323}324.into_sliceable(),325);326ac.update_groups = UpdateGroups::No;327return Ok(ac);328}329330ac.groups();331let predicate = BitMask::from_bitmap(predicate);332POOL.install(|| {333let positions = GroupsType::Idx(match &**ac.groups.as_ref() {334GroupsType::Idx(idxs) => idxs335.into_par_iter()336.map(|(fst, idxs)| {337let out = idxs338.iter()339.copied()340.filter(|i| unsafe { predicate.get_bit_unchecked(*i as usize) })341.collect::<UnitVec<IdxSize>>();342(out.first().copied().unwrap_or(fst), out)343})344.collect(),345GroupsType::Slice {346groups,347overlapping: _,348monotonic: _,349} => groups350.into_par_iter()351.map(|[start, length]| {352let predicate =353unsafe { predicate.sliced_unchecked(*start as usize, *length as usize) };354let num_values = predicate.set_bits();355356if num_values == 0 {357(*start, UnitVec::new())358} else if num_values == 1 {359let item = *start + predicate.leading_zeros() as IdxSize;360let mut out = UnitVec::with_capacity(1);361out.push(item);362(item, out)363} else if num_values == *length as usize {364(*start, (*start..*start + *length).collect())365} else {366let out = unsafe {367TrustMyLength::new(368(0..*length)369.filter(|i| predicate.get_bit_unchecked(*i as usize))370.map(|i| i + *start),371num_values,372)373};374let out = out.collect::<UnitVec<IdxSize>>();375376(out.first().copied().unwrap(), out)377}378})379.collect(),380})381.into_sliceable();382ac.with_groups(positions);383});384385Ok(ac)386}387388pub fn drop_nans<'a>(389inputs: &[Arc<dyn PhysicalExpr>],390df: &DataFrame,391groups: &'a GroupPositions,392state: &ExecutionState,393) -> PolarsResult<AggregationContext<'a>> {394assert_eq!(inputs.len(), 1);395let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;396ac.groups();397let predicate = if 
ac.agg_state().flat_dtype().is_float() {398let values = ac.flat_naive();399let mut values = values.is_nan().unwrap();400values.rechunk_mut();401values.downcast_as_array().values().clone()402} else {403Bitmap::new_with_value(false, 1)404};405let predicate = !&predicate;406drop_items(ac, &predicate)407}408409pub fn drop_nulls<'a>(410inputs: &[Arc<dyn PhysicalExpr>],411df: &DataFrame,412groups: &'a GroupPositions,413state: &ExecutionState,414) -> PolarsResult<AggregationContext<'a>> {415assert_eq!(inputs.len(), 1);416let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;417ac.groups();418let predicate = ac.flat_naive().as_ref().clone();419let predicate = predicate.rechunk_to_arrow(CompatLevel::newest());420let predicate = predicate421.validity()422.cloned()423.unwrap_or(Bitmap::new_with_value(true, 1));424drop_items(ac, &predicate)425}426427#[cfg(feature = "moment")]428pub fn moment_agg<'a, S: Default>(429inputs: &[Arc<dyn PhysicalExpr>],430df: &DataFrame,431groups: &'a GroupPositions,432state: &ExecutionState,433434insert_one: impl Fn(&mut S, f64) + Send + Sync,435new_from_slice: impl Fn(&PrimitiveArray<f64>, usize, usize) -> S + Send + Sync,436finalize: impl Fn(S) -> Option<f64> + Send + Sync,437) -> PolarsResult<AggregationContext<'a>> {438assert_eq!(inputs.len(), 1);439let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;440441if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {442let ca = s.f64()?;443*s = ca444.iter()445.map(|v| {446v.and_then(|v| {447let mut state = S::default();448insert_one(&mut state, v);449finalize(state)450})451})452.collect::<Float64Chunked>()453.with_name(ca.name().clone())454.into_column();455return Ok(ac);456}457458ac.groups();459460let name = ac.get_values().name().clone();461let ca = ac.flat_naive();462let ca = ca.f64()?;463let ca = ca.rechunk();464let arr = ca.downcast_as_array();465466let ca = POOL.install(|| match &**ac.groups.as_ref() {467GroupsType::Idx(idx) => {468if let 
Some(validity) = arr.validity().filter(|v| v.unset_bits() > 0) {469idx.into_par_iter()470.map(|(_, idx)| {471let mut state = S::default();472for &i in idx.iter() {473if unsafe { validity.get_bit_unchecked(i as usize) } {474insert_one(&mut state, arr.values()[i as usize]);475}476}477finalize(state)478})479.collect::<Float64Chunked>()480} else {481idx.into_par_iter()482.map(|(_, idx)| {483let mut state = S::default();484for &i in idx.iter() {485insert_one(&mut state, arr.values()[i as usize]);486}487finalize(state)488})489.collect::<Float64Chunked>()490}491},492GroupsType::Slice {493groups,494overlapping: _,495monotonic: _,496} => groups497.into_par_iter()498.map(|[start, length]| finalize(new_from_slice(arr, *start as usize, *length as usize)))499.collect::<Float64Chunked>(),500});501502ac.state = AggState::AggregatedScalar(ca.with_name(name).into_column());503Ok(ac)504}505506#[cfg(feature = "moment")]507pub fn skew<'a>(508inputs: &[Arc<dyn PhysicalExpr>],509df: &DataFrame,510groups: &'a GroupPositions,511state: &ExecutionState,512bias: bool,513) -> PolarsResult<AggregationContext<'a>> {514use polars_compute::moment::SkewState;515moment_agg::<SkewState>(516inputs,517df,518groups,519state,520SkewState::insert_one,521SkewState::from_array,522|s| s.finalize(bias),523)524}525526#[cfg(feature = "moment")]527pub fn kurtosis<'a>(528inputs: &[Arc<dyn PhysicalExpr>],529df: &DataFrame,530groups: &'a GroupPositions,531state: &ExecutionState,532fisher: bool,533bias: bool,534) -> PolarsResult<AggregationContext<'a>> {535use polars_compute::moment::KurtosisState;536moment_agg::<KurtosisState>(537inputs,538df,539groups,540state,541KurtosisState::insert_one,542KurtosisState::from_array,543|s| s.finalize(fisher, bias),544)545}546547pub fn unique<'a>(548inputs: &[Arc<dyn PhysicalExpr>],549df: &DataFrame,550groups: &'a GroupPositions,551state: &ExecutionState,552stable: bool,553) -> PolarsResult<AggregationContext<'a>> {554_ = stable;555556assert_eq!(inputs.len(), 1);557let mut ac = 
inputs[0].evaluate_on_groups(df, groups, state)?;558ac.groups();559560if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {561*c = c.as_list().into_column();562if c.len() == 1 && ac.groups.len() != 1 {563*c = c.new_from_index(0, ac.groups.len());564}565ac.state = AggState::AggregatedList(std::mem::take(c));566ac.update_groups = UpdateGroups::WithSeriesLen;567return Ok(ac);568}569570let values = ac.flat_naive().to_physical_repr();571let dtype = values.dtype();572let values = if dtype.contains_objects() {573polars_bail!(opq = unique, dtype);574} else if let Some(ca) = values.try_str() {575ca.as_binary().into_column()576} else if dtype.is_nested() {577encode_rows_unordered(&[values])?.into_column()578} else {579values580};581582let values = values.rechunk_to_arrow(CompatLevel::newest());583let values = values.as_ref();584let state = amortized_unique_from_dtype(values.dtype());585586struct CloneWrapper(Box<dyn AmortizedUnique>);587impl Clone for CloneWrapper {588fn clone(&self) -> Self {589Self(self.0.new_empty())590}591}592593POOL.install(|| {594let positions = GroupsType::Idx(match &**ac.groups().as_ref() {595GroupsType::Idx(idx) => idx596.into_par_iter()597.map_with(CloneWrapper(state), |state, (first, idx)| {598let mut idx = idx.clone();599unsafe { state.0.retain_unique(values, &mut idx) };600(idx.first().copied().unwrap_or(first), idx)601})602.collect(),603GroupsType::Slice {604groups,605overlapping: _,606monotonic: _,607} => groups608.into_par_iter()609.map_with(CloneWrapper(state), |state, [start, len]| {610let mut idx = UnitVec::new();611state.0.arg_unique(values, &mut idx, *start, *len);612(idx.first().copied().unwrap_or(*start), idx)613})614.collect(),615})616.into_sliceable();617ac.with_groups(positions);618});619620Ok(ac)621}622623fn fw_bw_fill_null<'a>(624inputs: &[Arc<dyn PhysicalExpr>],625df: &DataFrame,626groups: &'a GroupPositions,627state: &ExecutionState,628f_idx: impl Fn(629std::iter::Copied<std::slice::Iter<'_, 
IdxSize>>,630BitMask<'_>,631usize,632) -> UnitVec<IdxSize>633+ Send634+ Sync,635f_range: impl Fn(std::ops::Range<IdxSize>, BitMask<'_>, usize) -> UnitVec<IdxSize> + Send + Sync,636) -> PolarsResult<AggregationContext<'a>> {637assert_eq!(inputs.len(), 1);638let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;639ac.groups();640641if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &mut ac.state {642return Ok(ac);643}644645let values = ac.flat_naive();646let Some(validity) = values.rechunk_validity() else {647return Ok(ac);648};649650let validity = BitMask::from_bitmap(&validity);651POOL.install(|| {652let positions = GroupsType::Idx(match &**ac.groups().as_ref() {653GroupsType::Idx(idx) => idx654.into_par_iter()655.map(|(first, idx)| {656let idx = f_idx(idx.iter().copied(), validity, idx.len());657(idx.first().copied().unwrap_or(first), idx)658})659.collect(),660GroupsType::Slice {661groups,662overlapping: _,663monotonic: _,664} => groups665.into_par_iter()666.map(|[start, len]| {667let idx = f_range(*start..*start + *len, validity, *len as usize);668(idx.first().copied().unwrap_or(*start), idx)669})670.collect(),671})672.into_sliceable();673ac.with_groups(positions);674});675676Ok(ac)677}678679pub fn forward_fill_null<'a>(680inputs: &[Arc<dyn PhysicalExpr>],681df: &DataFrame,682groups: &'a GroupPositions,683state: &ExecutionState,684limit: Option<IdxSize>,685) -> PolarsResult<AggregationContext<'a>> {686let limit = limit.unwrap_or(IdxSize::MAX);687macro_rules! 
arg_forward_fill {688(689$iter:ident,690$validity:ident,691$length:ident692) => {{693|$iter, $validity, $length| {694let Some(start) = $iter695.clone()696.position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })697else {698return $iter.collect();699};700701let mut idx = UnitVec::with_capacity($length);702let mut iter = $iter;703idx.extend((&mut iter).take(start));704705let mut current_limit = limit;706let mut value = iter.next().unwrap();707idx.push(value);708709idx.extend(iter.map(|i| {710if unsafe { $validity.get_bit_unchecked(i as usize) } {711current_limit = limit;712value = i;713i714} else if current_limit == 0 {715i716} else {717current_limit -= 1;718value719}720}));721idx722}723}};724}725726fw_bw_fill_null(727inputs,728df,729groups,730state,731arg_forward_fill!(iter, validity, length),732arg_forward_fill!(iter, validity, length),733)734}735736pub fn backward_fill_null<'a>(737inputs: &[Arc<dyn PhysicalExpr>],738df: &DataFrame,739groups: &'a GroupPositions,740state: &ExecutionState,741limit: Option<IdxSize>,742) -> PolarsResult<AggregationContext<'a>> {743let limit = limit.unwrap_or(IdxSize::MAX);744macro_rules! arg_backward_fill {745(746$iter:ident,747$validity:ident,748$length:ident749) => {{750|$iter, $validity, $length| {751let Some(start) = $iter752.clone()753.rev()754.position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })755else {756return $iter.collect();757};758759let mut idx = UnitVec::from_iter($iter);760let mut current_limit = limit;761let mut value = idx[$length - start - 1];762for i in idx[..$length - start].iter_mut().rev() {763if unsafe { $validity.get_bit_unchecked(*i as usize) } {764current_limit = limit;765value = *i;766} else if current_limit != 0 {767current_limit -= 1;768*i = value;769}770}771772idx773}774}};775}776777fw_bw_fill_null(778inputs,779df,780groups,781state,782arg_backward_fill!(iter, validity, length),783arg_backward_fill!(iter, validity, length),784)785}786787788