Path: blob/main/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs
8420 views
#![allow(unsafe_op_in_unsafe_fn)]1use arrow::array::Array;2use arrow::legacy::kernels::take_agg::{3take_agg_no_null_primitive_iter_unchecked, take_agg_primitive_iter_unchecked,4};5use polars_compute::rolling;6use polars_compute::rolling::no_nulls::{MaxWindow, MinWindow};7use polars_core::frame::group_by::aggregations::{8_agg_helper_idx, _agg_helper_slice, _rolling_apply_agg_window_no_nulls,9_rolling_apply_agg_window_nulls, _slice_from_offsets, _use_rolling_kernels,10};11use polars_core::prelude::*;12use polars_utils::min_max::MinMax;1314pub fn ca_nan_agg<T, Agg>(ca: &ChunkedArray<T>, min_or_max_fn: Agg) -> Option<T::Native>15where16T: PolarsFloatType,17Agg: Fn(T::Native, T::Native) -> T::Native + Copy,18{19ca.downcast_iter()20.filter_map(|arr| {21if arr.null_count() == 0 {22arr.values().iter().copied().reduce(min_or_max_fn)23} else {24arr.iter()25.unwrap_optional()26.filter_map(|opt| opt.copied())27.reduce(min_or_max_fn)28}29})30.reduce(min_or_max_fn)31}3233pub fn nan_min_s(s: &Series, name: PlSmallStr) -> Series {34match s.dtype() {35#[cfg(feature = "dtype-f16")]36DataType::Float16 => {37let ca = s.f16().unwrap();38Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])39},40DataType::Float32 => {41let ca = s.f32().unwrap();42Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])43},44DataType::Float64 => {45let ca = s.f64().unwrap();46Series::new(name, [ca_nan_agg(ca, MinMax::min_propagate_nan)])47},48_ => panic!("expected float"),49}50}5152pub fn nan_max_s(s: &Series, name: PlSmallStr) -> Series {53match s.dtype() {54#[cfg(feature = "dtype-f16")]55DataType::Float16 => {56let ca = s.f16().unwrap();57Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])58},59DataType::Float32 => {60let ca = s.f32().unwrap();61Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])62},63DataType::Float64 => {64let ca = s.f64().unwrap();65Series::new(name, [ca_nan_agg(ca, MinMax::max_propagate_nan)])66},67_ => panic!("expected float"),68}69}7071unsafe fn group_nan_max<T: PolarsFloatType>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series {72match groups {73GroupsType::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {74debug_assert!(idx.len() <= ca.len());75if idx.is_empty() {76None77} else if idx.len() == 1 {78ca.get(first as usize)79} else {80match (ca.has_nulls(), ca.chunks().len()) {81(false, 1) => take_agg_no_null_primitive_iter_unchecked(82ca.downcast_iter().next().unwrap(),83idx.iter().map(|i| *i as usize),84)85.reduce(MinMax::max_propagate_nan),86(_, 1) => take_agg_primitive_iter_unchecked(87ca.downcast_iter().next().unwrap(),88idx.iter().map(|i| *i as usize),89)90.reduce(MinMax::max_propagate_nan),91_ => {92let take = { ca.take_unchecked(idx) };93ca_nan_agg(&take, MinMax::max_propagate_nan)94},95}96}97}),98GroupsType::Slice {99groups: groups_slice,100overlapping,101monotonic,102} => {103if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, ca.chunks()) {104let arr = ca.downcast_iter().next().unwrap();105let values = arr.values().as_slice();106let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));107let arr = match arr.validity() {108None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _, _>(109values,110offset_iter,111None,112),113Some(validity) => _rolling_apply_agg_window_nulls::<114rolling::nulls::MaxWindow<_>,115_,116_,117_,118>(values, validity, offset_iter, None),119};120ChunkedArray::<T>::from(arr).into_series()121} else {122_agg_helper_slice::<T, _>(groups_slice, |[first, len]| {123debug_assert!(len <= ca.len() as IdxSize);124match len {1250 => None,1261 => ca.get(first as usize),127_ => {128let arr_group = _slice_from_offsets(ca, first, len);129ca_nan_agg(&arr_group, MinMax::max_propagate_nan)130},131}132})133}134},135}136}137138unsafe fn group_nan_min<T: PolarsFloatType>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series {139match groups {140GroupsType::Idx(groups) => _agg_helper_idx::<T, _>(groups, |(first, idx)| {141debug_assert!(idx.len() <= ca.len());142if idx.is_empty() {143None144} else if idx.len() == 1 {145ca.get(first as usize)146} else {147match (ca.has_nulls(), ca.chunks().len()) {148(false, 1) => take_agg_no_null_primitive_iter_unchecked(149ca.downcast_iter().next().unwrap(),150idx.iter().map(|i| *i as usize),151)152.reduce(MinMax::min_propagate_nan),153(_, 1) => take_agg_primitive_iter_unchecked(154ca.downcast_iter().next().unwrap(),155idx.iter().map(|i| *i as usize),156)157.reduce(MinMax::min_propagate_nan),158_ => {159let take = { ca.take_unchecked(idx) };160ca_nan_agg(&take, MinMax::min_propagate_nan)161},162}163}164}),165GroupsType::Slice {166groups: groups_slice,167overlapping,168monotonic,169} => {170if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, ca.chunks()) {171let arr = ca.downcast_iter().next().unwrap();172let values = arr.values().as_slice();173let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));174let arr = match arr.validity() {175None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _, _>(176values,177offset_iter,178None,179),180Some(validity) => _rolling_apply_agg_window_nulls::<181rolling::nulls::MinWindow<_>,182_,183_,184_,185>(values, validity, offset_iter, None),186};187ChunkedArray::<T>::from(arr).into_series()188} else {189_agg_helper_slice::<T, _>(groups_slice, |[first, len]| {190debug_assert!(len <= ca.len() as IdxSize);191match len {1920 => None,1931 => ca.get(first as usize),194_ => {195let arr_group = _slice_from_offsets(ca, first, len);196ca_nan_agg(&arr_group, MinMax::min_propagate_nan)197},198}199})200}201},202}203}204205/// # Safety206/// `groups` must be in bounds.207pub unsafe fn group_agg_nan_min_s(s: &Series, groups: &GroupsType) -> Series {208match s.dtype() {209#[cfg(feature = "dtype-f16")]210DataType::Float16 => {211let ca = s.f16().unwrap();212group_nan_min(ca, groups)213},214DataType::Float32 => {215let ca = s.f32().unwrap();216group_nan_min(ca, groups)217},218DataType::Float64 => {219let ca = s.f64().unwrap();220group_nan_min(ca, groups)221},222_ => panic!("expected float"),223}224}225226/// # Safety227/// `groups` must be in bounds.228pub unsafe fn group_agg_nan_max_s(s: &Series, groups: &GroupsType) -> Series {229match s.dtype() {230#[cfg(feature = "dtype-f16")]231DataType::Float16 => {232let ca = s.f16().unwrap();233group_nan_max(ca, groups)234},235DataType::Float32 => {236let ca = s.f32().unwrap();237group_nan_max(ca, groups)238},239DataType::Float64 => {240let ca = s.f64().unwrap();241group_nan_max(ca, groups)242},243_ => panic!("expected float"),244}245}246247248