// crates/polars-ops/src/series/ops/cum_agg.rs
// Cumulative aggregations (sum, product, min, max, count) over `Series`.
use std::ops::{AddAssign, Mul};12use arity::unary_elementwise_values;3use arrow::array::{Array, BooleanArray};4use arrow::bitmap::{Bitmap, BitmapBuilder};5use num_traits::{Bounded, One, Zero};6use polars_core::prelude::*;7use polars_core::series::IsSorted;8use polars_core::utils::{CustomIterTools, NoNull};9use polars_core::with_match_physical_numeric_polars_type;10use polars_utils::float::IsFloat;11use polars_utils::min_max::MinMax;1213fn det_max<T>(state: &mut T, v: Option<T>) -> Option<Option<T>>14where15T: Copy + MinMax,16{17match v {18Some(v) => {19*state = MinMax::max_ignore_nan(*state, v);20Some(Some(*state))21},22None => Some(None),23}24}2526fn det_min<T>(state: &mut T, v: Option<T>) -> Option<Option<T>>27where28T: Copy + MinMax,29{30match v {31Some(v) => {32*state = MinMax::min_ignore_nan(*state, v);33Some(Some(*state))34},35None => Some(None),36}37}3839fn det_sum<T>(state: &mut T, v: Option<T>) -> Option<Option<T>>40where41T: Copy + AddAssign,42{43match v {44Some(v) => {45*state += v;46Some(Some(*state))47},48None => Some(None),49}50}5152fn det_prod<T>(state: &mut T, v: Option<T>) -> Option<Option<T>>53where54T: Copy + Mul<Output = T>,55{56match v {57Some(v) => {58*state = *state * v;59Some(Some(*state))60},61None => Some(None),62}63}6465fn cum_scan_numeric<T, F>(66ca: &ChunkedArray<T>,67reverse: bool,68init: T::Native,69update: F,70) -> ChunkedArray<T>71where72T: PolarsNumericType,73ChunkedArray<T>: FromIterator<Option<T::Native>>,74F: Fn(&mut T::Native, Option<T::Native>) -> Option<Option<T::Native>>,75{76let out: ChunkedArray<T> = match reverse {77false => ca.iter().scan(init, update).collect_trusted(),78true => ca.iter().rev().scan(init, update).collect_reversed(),79};80out.with_name(ca.name().clone())81}8283fn cum_max_numeric<T>(84ca: &ChunkedArray<T>,85reverse: bool,86init: Option<T::Native>,87) -> ChunkedArray<T>88where89T: PolarsNumericType,90T::Native: MinMax + Bounded,91ChunkedArray<T>: FromIterator<Option<T::Native>>,92{93let init = 
init.unwrap_or(if T::Native::is_float() {94T::Native::nan_value()95} else {96Bounded::min_value()97});98cum_scan_numeric(ca, reverse, init, det_max)99}100101fn cum_min_numeric<T>(102ca: &ChunkedArray<T>,103reverse: bool,104init: Option<T::Native>,105) -> ChunkedArray<T>106where107T: PolarsNumericType,108T::Native: MinMax + Bounded,109ChunkedArray<T>: FromIterator<Option<T::Native>>,110{111let init = init.unwrap_or(if T::Native::is_float() {112T::Native::nan_value()113} else {114Bounded::max_value()115});116cum_scan_numeric(ca, reverse, init, det_min)117}118119fn cum_max_bool(ca: &BooleanChunked, reverse: bool, init: Option<bool>) -> BooleanChunked {120if ca.len() == ca.null_count() {121return ca.clone();122}123124if init == Some(true) {125return unsafe {126BooleanChunked::from_chunks(127ca.name().clone(),128ca.downcast_iter()129.map(|arr| {130arr.with_values(Bitmap::new_with_value(true, arr.len()))131.to_boxed()132})133.collect(),134)135};136}137138let mut out;139if !reverse {140// TODO: efficient bitscan.141let Some(first_true_idx) = ca.iter().position(|x| x == Some(true)) else {142return ca.clone();143};144out = BitmapBuilder::with_capacity(ca.len());145out.extend_constant(first_true_idx, false);146out.extend_constant(ca.len() - first_true_idx, true);147} else {148// TODO: efficient bitscan.149let Some(last_true_idx) = ca.iter().rposition(|x| x == Some(true)) else {150return ca.clone();151};152out = BitmapBuilder::with_capacity(ca.len());153out.extend_constant(last_true_idx + 1, true);154out.extend_constant(ca.len() - 1 - last_true_idx, false);155}156157let arr: BooleanArray = out.freeze().into();158BooleanChunked::with_chunk_like(ca, arr.with_validity(ca.rechunk_validity()))159}160161fn cum_min_bool(ca: &BooleanChunked, reverse: bool, init: Option<bool>) -> BooleanChunked {162if ca.len() == ca.null_count() {163return ca.clone();164}165166if init == Some(false) {167return unsafe 
{168BooleanChunked::from_chunks(169ca.name().clone(),170ca.downcast_iter()171.map(|arr| {172arr.with_values(Bitmap::new_with_value(false, arr.len()))173.to_boxed()174})175.collect(),176)177};178}179180let mut out;181if !reverse {182// TODO: efficient bitscan.183let Some(first_false_idx) = ca.iter().position(|x| x == Some(false)) else {184return ca.clone();185};186out = BitmapBuilder::with_capacity(ca.len());187out.extend_constant(first_false_idx, true);188out.extend_constant(ca.len() - first_false_idx, false);189} else {190// TODO: efficient bitscan.191let Some(last_false_idx) = ca.iter().rposition(|x| x == Some(false)) else {192return ca.clone();193};194out = BitmapBuilder::with_capacity(ca.len());195out.extend_constant(last_false_idx + 1, false);196out.extend_constant(ca.len() - 1 - last_false_idx, true);197}198199let arr: BooleanArray = out.freeze().into();200BooleanChunked::with_chunk_like(ca, arr.with_validity(ca.rechunk_validity()))201}202203fn cum_sum_numeric<T>(204ca: &ChunkedArray<T>,205reverse: bool,206init: Option<T::Native>,207) -> ChunkedArray<T>208where209T: PolarsNumericType,210ChunkedArray<T>: FromIterator<Option<T::Native>>,211{212let init = init.unwrap_or(T::Native::zero());213cum_scan_numeric(ca, reverse, init, det_sum)214}215216#[cfg(feature = "dtype-decimal")]217fn cum_sum_decimal(218ca: &Int128Chunked,219reverse: bool,220init: Option<i128>,221) -> PolarsResult<Int128Chunked> {222use polars_compute::decimal::{DEC128_MAX_PREC, dec128_add};223224let mut value = init.unwrap_or(0);225let update = |opt_v| {226if let Some(v) = opt_v {227value = dec128_add(value, v, DEC128_MAX_PREC).ok_or_else(228|| polars_err!(ComputeError: "overflow in decimal addition in cum_sum"),229)?;230Ok(Some(value))231} else {232Ok(None)233}234};235if reverse {236ca.iter().rev().map(update).try_collect_ca_trusted_like(ca)237} else {238ca.iter().map(update).try_collect_ca_trusted_like(ca)239}240}241242fn cum_prod_numeric<T>(243ca: &ChunkedArray<T>,244reverse: bool,245init: 
Option<T::Native>,246) -> ChunkedArray<T>247where248T: PolarsNumericType,249ChunkedArray<T>: FromIterator<Option<T::Native>>,250{251let init = init.unwrap_or(T::Native::one());252cum_scan_numeric(ca, reverse, init, det_prod)253}254255pub fn cum_prod_with_init(256s: &Series,257reverse: bool,258init: &AnyValue<'static>,259) -> PolarsResult<Series> {260use DataType::*;261let out = match s.dtype() {262Boolean | Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 => {263let s = s.cast(&Int64)?;264cum_prod_numeric(s.i64()?, reverse, init.extract()).into_series()265},266Int64 => cum_prod_numeric(s.i64()?, reverse, init.extract()).into_series(),267UInt64 => cum_prod_numeric(s.u64()?, reverse, init.extract()).into_series(),268#[cfg(feature = "dtype-i128")]269Int128 => cum_prod_numeric(s.i128()?, reverse, init.extract()).into_series(),270#[cfg(feature = "dtype-u128")]271UInt128 => cum_prod_numeric(s.u128()?, reverse, init.extract()).into_series(),272#[cfg(feature = "dtype-f16")]273Float16 => cum_prod_numeric(s.f16()?, reverse, init.extract()).into_series(),274Float32 => cum_prod_numeric(s.f32()?, reverse, init.extract()).into_series(),275Float64 => cum_prod_numeric(s.f64()?, reverse, init.extract()).into_series(),276dt => polars_bail!(opq = cum_prod, dt),277};278Ok(out)279}280281/// Get an array with the cumulative product computed at every element.282///283/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16, Int32, UInt32}` the `Series` is284/// first cast to `Int64` to prevent overflow issues.285pub fn cum_prod(s: &Series, reverse: bool) -> PolarsResult<Series> {286cum_prod_with_init(s, reverse, &AnyValue::Null)287}288289pub fn cum_sum_with_init(290s: &Series,291reverse: bool,292init: &AnyValue<'static>,293) -> PolarsResult<Series> {294use DataType::*;295let out = match s.dtype() {296Boolean => {297let s = s.cast(&UInt32)?;298cum_sum_numeric(s.u32()?, reverse, init.extract()).into_series()299},300Int8 | UInt8 | Int16 | UInt16 => {301let s = 
s.cast(&Int64)?;302cum_sum_numeric(s.i64()?, reverse, init.extract()).into_series()303},304Int32 => cum_sum_numeric(s.i32()?, reverse, init.extract()).into_series(),305UInt32 => cum_sum_numeric(s.u32()?, reverse, init.extract()).into_series(),306Int64 => cum_sum_numeric(s.i64()?, reverse, init.extract()).into_series(),307UInt64 => cum_sum_numeric(s.u64()?, reverse, init.extract()).into_series(),308#[cfg(feature = "dtype-u128")]309UInt128 => cum_sum_numeric(s.u128()?, reverse, init.extract()).into_series(),310#[cfg(feature = "dtype-i128")]311Int128 => cum_sum_numeric(s.i128()?, reverse, init.extract()).into_series(),312#[cfg(feature = "dtype-f16")]313Float16 => cum_sum_numeric(s.f16()?, reverse, init.extract()).into_series(),314Float32 => cum_sum_numeric(s.f32()?, reverse, init.extract()).into_series(),315Float64 => cum_sum_numeric(s.f64()?, reverse, init.extract()).into_series(),316#[cfg(feature = "dtype-decimal")]317Decimal(_precision, scale) => {318use polars_compute::decimal::DEC128_MAX_PREC;319let ca = s.decimal().unwrap().physical();320cum_sum_decimal(ca, reverse, init.clone().to_physical().extract())?321.into_decimal_unchecked(DEC128_MAX_PREC, *scale)322.into_series()323},324#[cfg(feature = "dtype-duration")]325Duration(tu) => {326let s = s.to_physical_repr();327let ca = s.i64()?;328cum_sum_numeric(ca, reverse, init.extract()).cast(&Duration(*tu))?329},330dt => polars_bail!(opq = cum_sum, dt),331};332Ok(out)333}334335/// Get an array with the cumulative sum computed at every element336///337/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is338/// first cast to `Int64` to prevent overflow issues.339pub fn cum_sum(s: &Series, reverse: bool) -> PolarsResult<Series> {340cum_sum_with_init(s, reverse, &AnyValue::Null)341}342343pub fn cum_min_with_init(344s: &Series,345reverse: bool,346init: &AnyValue<'static>,347) -> PolarsResult<Series> {348match s.dtype() {349DataType::Boolean => {350Ok(cum_min_bool(s.bool()?, reverse, 
init.extract_bool()).into_series())351},352#[cfg(feature = "dtype-decimal")]353DataType::Decimal(precision, scale) => {354let ca = s.decimal().unwrap().physical();355let out = cum_min_numeric(ca, reverse, init.clone().to_physical().extract())356.into_decimal_unchecked(*precision, *scale)357.into_series();358Ok(out)359},360dt if dt.to_physical().is_primitive_numeric() => {361let s = s.to_physical_repr();362with_match_physical_numeric_polars_type!(s.dtype(), |$T| {363let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();364let out = cum_min_numeric(ca, reverse, init.extract()).into_series();365if dt.is_logical() {366out.cast(dt)367} else {368Ok(out)369}370})371},372dt => polars_bail!(opq = cum_min, dt),373}374}375376/// Get an array with the cumulative min computed at every element.377pub fn cum_min(s: &Series, reverse: bool) -> PolarsResult<Series> {378cum_min_with_init(s, reverse, &AnyValue::Null)379}380381pub fn cum_max_with_init(382s: &Series,383reverse: bool,384init: &AnyValue<'static>,385) -> PolarsResult<Series> {386match s.dtype() {387DataType::Boolean => {388Ok(cum_max_bool(s.bool()?, reverse, init.extract_bool()).into_series())389},390#[cfg(feature = "dtype-decimal")]391DataType::Decimal(precision, scale) => {392let ca = s.decimal().unwrap().physical();393let out = cum_max_numeric(ca, reverse, init.clone().to_physical().extract())394.into_decimal_unchecked(*precision, *scale)395.into_series();396Ok(out)397},398dt if dt.to_physical().is_primitive_numeric() => {399let s = s.to_physical_repr();400with_match_physical_numeric_polars_type!(s.dtype(), |$T| {401let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();402let out = cum_max_numeric(ca, reverse, init.extract()).into_series();403if dt.is_logical() {404out.cast(dt)405} else {406Ok(out)407}408})409},410dt => polars_bail!(opq = cum_max, dt),411}412}413414/// Get an array with the cumulative max computed at every element.415pub fn cum_max(s: &Series, reverse: bool) -> PolarsResult<Series> 
{416cum_max_with_init(s, reverse, &AnyValue::Null)417}418419pub fn cum_count(s: &Series, reverse: bool) -> PolarsResult<Series> {420cum_count_with_init(s, reverse, 0)421}422423pub fn cum_count_with_init(s: &Series, reverse: bool, init: IdxSize) -> PolarsResult<Series> {424let mut out = if s.null_count() == 0 {425// Fast paths for no nulls426cum_count_no_nulls(s.name().clone(), s.len(), reverse, init)427} else {428let ca = s.is_not_null();429let out: IdxCa = if reverse {430let mut count = init + (s.len() - s.null_count()) as IdxSize;431let mut prev = false;432unary_elementwise_values(&ca, |v: bool| {433if prev {434count -= 1;435}436prev = v;437count438})439} else {440let mut count = init;441unary_elementwise_values(&ca, |v: bool| {442if v {443count += 1;444}445count446})447};448449out.into()450};451452out.set_sorted_flag([IsSorted::Ascending, IsSorted::Descending][reverse as usize]);453454Ok(out)455}456457fn cum_count_no_nulls(name: PlSmallStr, len: usize, reverse: bool, init: IdxSize) -> Series {458let start = 1 as IdxSize;459let end = len as IdxSize + 1;460let ca: NoNull<IdxCa> = if reverse {461(start..end).rev().map(|v| v + init).collect()462} else {463(start..end).map(|v| v + init).collect()464};465let mut ca = ca.into_inner();466ca.rename(name);467ca.into_series()468}469470471