Path: blob/main/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use arrow::array::Array;2use arrow::legacy::kernels::take_agg::{3take_agg_no_null_primitive_iter_unchecked, take_agg_primitive_iter_unchecked,4};5use polars_compute::rolling;6use polars_compute::rolling::no_nulls::{MaxWindow, MinWindow};7use polars_core::frame::group_by::aggregations::{8_agg_helper_idx, _agg_helper_slice, _rolling_apply_agg_window_no_nulls,9_rolling_apply_agg_window_nulls, _slice_from_offsets, _use_rolling_kernels,10};11use polars_core::prelude::*;12use polars_utils::min_max::MinMax;1314pub fn ca_nan_agg<T, Agg>(ca: &ChunkedArray<T>, min_or_max_fn: Agg) -> Option<T::Native>15where16T: PolarsFloatType,17Agg: Fn(T::Native, T::Native) -> T::Native + Copy,18{19ca.downcast_iter()20.filter_map(|arr| {21if arr.null_count() == 0 {22arr.values().iter().copied().reduce(min_or_max_fn)23} else {24arr.iter()25.unwrap_optional()26.filter_map(|opt| opt.copied())27.reduce(min_or_max_fn)28}29})30.reduce(min_or_max_fn)31}3233pub fn nan_min_s(s: &Series, name: PlSmallStr) -> Series {34match s.dtype() {35DataType::Float32 => {36let ca = s.f32().unwrap();37Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])38},39DataType::Float64 => {40let ca = s.f64().unwrap();41Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])42},43_ => panic!("expected float"),44}45}4647pub fn nan_max_s(s: &Series, name: PlSmallStr) -> Series {48match s.dtype() {49DataType::Float32 => {50let ca = s.f32().unwrap();51Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])52},53DataType::Float64 => {54let ca = s.f64().unwrap();55Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])56},57_ => panic!("expected float"),58}59}6061unsafe fn group_nan_max<T: PolarsFloatType>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series {62match groups {63GroupsType::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {64debug_assert!(idx.len() <= ca.len());65if idx.is_empty() {66None67} else if idx.len() == 1 {68ca.get(first as usize)69} else {70match (ca.has_nulls(), ca.chunks().len()) {71(false, 1) => take_agg_no_null_primitive_iter_unchecked(72ca.downcast_iter().next().unwrap(),73idx.iter().map(|i| *i as usize),74MinMax::max_propagate_nan,75),76(_, 1) => take_agg_primitive_iter_unchecked(77ca.downcast_iter().next().unwrap(),78idx.iter().map(|i| *i as usize),79MinMax::max_propagate_nan,80),81_ => {82let take = { ca.take_unchecked(idx) };83ca_nan_agg(&take, MinMax::max_propagate_nan)84},85}86}87}),88GroupsType::Slice {89groups: groups_slice,90..91} => {92if _use_rolling_kernels(groups_slice, ca.chunks()) {93let arr = ca.downcast_iter().next().unwrap();94let values = arr.values().as_slice();95let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));96let arr = match arr.validity() {97None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _>(98values,99offset_iter,100None,101),102Some(validity) => _rolling_apply_agg_window_nulls::<103rolling::nulls::MaxWindow<_>,104_,105_,106>(values, validity, offset_iter, None),107};108ChunkedArray::<T>::from(arr).into_series()109} else {110_agg_helper_slice::<T, _>(groups_slice, |[first, len]| {111debug_assert!(len <= ca.len() as IdxSize);112match len {1130 => None,1141 => ca.get(first as usize),115_ => {116let arr_group = _slice_from_offsets(ca, first, len);117ca_nan_agg(&arr_group, MinMax::max_propagate_nan)118},119}120})121}122},123}124}125126unsafe fn group_nan_min<T: PolarsFloatType>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series {127match groups {128GroupsType::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {129debug_assert!(idx.len() <= ca.len());130if idx.is_empty() {131None132} else if idx.len() == 1 {133ca.get(first as usize)134} else {135match (ca.has_nulls(), ca.chunks().len()) {136(false, 1) => take_agg_no_null_primitive_iter_unchecked(137ca.downcast_iter().next().unwrap(),138idx.iter().map(|i| *i as usize),139MinMax::min_propagate_nan,140),141(_, 1) => take_agg_primitive_iter_unchecked(142ca.downcast_iter().next().unwrap(),143idx.iter().map(|i| *i as usize),144MinMax::min_propagate_nan,145),146_ => {147let take = { ca.take_unchecked(idx) };148ca_nan_agg(&take, MinMax::min_propagate_nan)149},150}151}152}),153GroupsType::Slice {154groups: groups_slice,155..156} => {157if _use_rolling_kernels(groups_slice, ca.chunks()) {158let arr = ca.downcast_iter().next().unwrap();159let values = arr.values().as_slice();160let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));161let arr = match arr.validity() {162None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _>(163values,164offset_iter,165None,166),167Some(validity) => _rolling_apply_agg_window_nulls::<168rolling::nulls::MinWindow<_>,169_,170_,171>(values, validity, offset_iter, None),172};173ChunkedArray::<T>::from(arr).into_series()174} else {175_agg_helper_slice::<T, _>(groups_slice, |[first, len]| {176debug_assert!(len <= ca.len() as IdxSize);177match len {1780 => None,1791 => ca.get(first as usize),180_ => {181let arr_group = _slice_from_offsets(ca, first, len);182ca_nan_agg(&arr_group, MinMax::min_propagate_nan)183},184}185})186}187},188}189}190191/// # Safety192/// `groups` must be in bounds.193pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsType) -> Series {194match s.dtype() {195DataType::Float32 => {196let ca = s.f32().unwrap();197group_nan_min(ca, groups)198},199DataType::Float64 => {200let ca = s.f64().unwrap();201group_nan_min(ca, groups)202},203_ => panic!("expected float"),204}205}206207/// # Safety208/// `groups` must be in bounds.209pub unsafe fn group_agg_nan_max_s(s: &Series, groups: &GroupsType) -> Series {210match s.dtype() {211DataType::Float32 => {212let ca = s.f32().unwrap();213group_nan_max(ca, groups)214},215DataType::Float64 => {216let ca = s.f64().unwrap();217group_nan_max(ca, groups)218},219_ => panic!("expected float"),220}221}222223224