Path: blob/main/crates/polars-compute/src/min_max/simd.rs
6939 views
use std::simd::prelude::*;1use std::simd::{LaneCount, SimdElement, SupportedLaneCount};23use arrow::array::PrimitiveArray;4use arrow::bitmap::Bitmap;5use arrow::bitmap::bitmask::BitMask;6use arrow::types::NativeType;7use polars_utils::min_max::MinMax;89use super::MinMaxKernel;1011fn scalar_reduce_min_propagate_nan<T: MinMax + Copy, const N: usize>(arr: &[T; N]) -> T {12let it = arr.iter().copied();13it.reduce(MinMax::min_propagate_nan).unwrap()14}1516fn scalar_reduce_max_propagate_nan<T: MinMax + Copy, const N: usize>(arr: &[T; N]) -> T {17let it = arr.iter().copied();18it.reduce(MinMax::max_propagate_nan).unwrap()19}2021fn fold_agg_kernel<const N: usize, T, F>(22arr: &[T],23validity: Option<&Bitmap>,24scalar_identity: T,25mut simd_f: F,26) -> Option<Simd<T, N>>27where28T: SimdElement + NativeType,29F: FnMut(Simd<T, N>, Simd<T, N>) -> Simd<T, N>,30LaneCount<N>: SupportedLaneCount,31{32if arr.is_empty() {33return None;34}3536let mut arr_chunks = arr.chunks_exact(N);3738let identity = Simd::splat(scalar_identity);39let mut state = identity;40if let Some(valid) = validity {41if valid.unset_bits() == arr.len() {42return None;43}4445let mask = BitMask::from_bitmap(valid);46let mut offset = 0;47for c in arr_chunks.by_ref() {48let m: Mask<_, N> = mask.get_simd(offset);49state = simd_f(state, m.select(Simd::from_slice(c), identity));50offset += N;51}52if !arr.len().is_multiple_of(N) {53let mut rest: [T; N] = identity.to_array();54let arr_rest = arr_chunks.remainder();55rest[..arr_rest.len()].copy_from_slice(arr_rest);56let m: Mask<_, N> = mask.get_simd(offset);57state = simd_f(state, m.select(Simd::from_array(rest), identity));58}59} else {60for c in arr_chunks.by_ref() {61state = simd_f(state, Simd::from_slice(c));62}63if !arr.len().is_multiple_of(N) {64let mut rest: [T; N] = identity.to_array();65let arr_rest = arr_chunks.remainder();66rest[..arr_rest.len()].copy_from_slice(arr_rest);67state = simd_f(state, Simd::from_array(rest));68}69}7071Some(state)72}7374fn fold_agg_min_max_kernel<const N: usize, T, F>(75arr: &[T],76validity: Option<&Bitmap>,77min_scalar_identity: T,78max_scalar_identity: T,79mut simd_f: F,80) -> Option<(Simd<T, N>, Simd<T, N>)>81where82T: SimdElement + NativeType,83F: FnMut((Simd<T, N>, Simd<T, N>), (Simd<T, N>, Simd<T, N>)) -> (Simd<T, N>, Simd<T, N>),84LaneCount<N>: SupportedLaneCount,85{86if arr.is_empty() {87return None;88}8990let mut arr_chunks = arr.chunks_exact(N);9192let min_identity = Simd::splat(min_scalar_identity);93let max_identity = Simd::splat(max_scalar_identity);94let mut state = (min_identity, max_identity);95if let Some(valid) = validity {96if valid.unset_bits() == arr.len() {97return None;98}99100let mask = BitMask::from_bitmap(valid);101let mut offset = 0;102for c in arr_chunks.by_ref() {103let m: Mask<_, N> = mask.get_simd(offset);104let slice = Simd::from_slice(c);105state = simd_f(106state,107(m.select(slice, min_identity), m.select(slice, max_identity)),108);109offset += N;110}111if !arr.len().is_multiple_of(N) {112let mut min_rest: [T; N] = min_identity.to_array();113let mut max_rest: [T; N] = max_identity.to_array();114115let arr_rest = arr_chunks.remainder();116min_rest[..arr_rest.len()].copy_from_slice(arr_rest);117max_rest[..arr_rest.len()].copy_from_slice(arr_rest);118119let m: Mask<_, N> = mask.get_simd(offset);120121let min_rest = Simd::from_array(min_rest);122let max_rest = Simd::from_array(max_rest);123124state = simd_f(125state,126(127m.select(min_rest, min_identity),128m.select(max_rest, max_identity),129),130);131}132} else {133for c in arr_chunks.by_ref() {134let slice = Simd::from_slice(c);135state = simd_f(state, (slice, slice));136}137if !arr.len().is_multiple_of(N) {138let mut min_rest: [T; N] = min_identity.to_array();139let mut max_rest: [T; N] = max_identity.to_array();140141let arr_rest = arr_chunks.remainder();142min_rest[..arr_rest.len()].copy_from_slice(arr_rest);143max_rest[..arr_rest.len()].copy_from_slice(arr_rest);144145let min_rest = Simd::from_array(min_rest);146let max_rest = Simd::from_array(max_rest);147148state = simd_f(state, (min_rest, max_rest));149}150}151152Some(state)153}154155macro_rules! impl_min_max_kernel_int {156($T:ty, $N:literal) => {157impl MinMaxKernel for PrimitiveArray<$T> {158type Scalar<'a> = $T;159160fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {161fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::MAX, |a, b| {162a.simd_min(b)163})164.map(|s| s.reduce_min())165}166167fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {168fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::MIN, |a, b| {169a.simd_max(b)170})171.map(|s| s.reduce_max())172}173174fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {175fold_agg_min_max_kernel::<$N, $T, _>(176self.values(),177self.validity(),178<$T>::MAX,179<$T>::MIN,180|(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),181)182.map(|(min, max)| (min.reduce_min(), max.reduce_max()))183}184185fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {186self.min_ignore_nan_kernel()187}188189fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {190self.max_ignore_nan_kernel()191}192193fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {194self.min_max_ignore_nan_kernel()195}196}197198impl MinMaxKernel for [$T] {199type Scalar<'a> = $T;200201fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {202fold_agg_kernel::<$N, $T, _>(self, None, <$T>::MAX, |a, b| a.simd_min(b))203.map(|s| s.reduce_min())204}205206fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {207fold_agg_kernel::<$N, $T, _>(self, None, <$T>::MIN, |a, b| a.simd_max(b))208.map(|s| s.reduce_max())209}210211fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {212fold_agg_min_max_kernel::<$N, $T, _>(213self,214None,215<$T>::MAX,216<$T>::MIN,217|(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),218)219.map(|(min, max)| (min.reduce_min(), max.reduce_max()))220}221222fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {223self.min_ignore_nan_kernel()224}225226fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {227self.max_ignore_nan_kernel()228}229230fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {231self.min_max_ignore_nan_kernel()232}233}234};235}236237impl_min_max_kernel_int!(u8, 32);238impl_min_max_kernel_int!(u16, 16);239impl_min_max_kernel_int!(u32, 16);240impl_min_max_kernel_int!(u64, 8);241impl_min_max_kernel_int!(i8, 32);242impl_min_max_kernel_int!(i16, 16);243impl_min_max_kernel_int!(i32, 16);244impl_min_max_kernel_int!(i64, 8);245246macro_rules! impl_min_max_kernel_float {247($T:ty, $N:literal) => {248impl MinMaxKernel for PrimitiveArray<$T> {249type Scalar<'a> = $T;250251fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {252fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| {253a.simd_min(b)254})255.map(|s| s.reduce_min())256}257258fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {259fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| {260a.simd_max(b)261})262.map(|s| s.reduce_max())263}264265fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {266fold_agg_min_max_kernel::<$N, $T, _>(267self.values(),268self.validity(),269<$T>::NAN,270<$T>::NAN,271|(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),272)273.map(|(min, max)| (min.reduce_min(), max.reduce_max()))274}275276fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {277fold_agg_kernel::<$N, $T, _>(278self.values(),279self.validity(),280<$T>::INFINITY,281|a, b| (a.simd_lt(b) | a.simd_ne(a)).select(a, b),282)283.map(|s| scalar_reduce_min_propagate_nan(s.as_array()))284}285286fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {287fold_agg_kernel::<$N, $T, _>(288self.values(),289self.validity(),290<$T>::NEG_INFINITY,291|a, b| (a.simd_gt(b) | a.simd_ne(a)).select(a, b),292)293.map(|s| scalar_reduce_max_propagate_nan(s.as_array()))294}295296fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {297fold_agg_min_max_kernel::<$N, $T, _>(298self.values(),299self.validity(),300<$T>::INFINITY,301<$T>::NEG_INFINITY,302|(cmin, cmax), (min, max)| {303(304(cmin.simd_lt(min) | cmin.simd_ne(cmin)).select(cmin, min),305(cmax.simd_gt(max) | cmax.simd_ne(cmax)).select(cmax, max),306)307},308)309.map(|(min, max)| {310(311scalar_reduce_min_propagate_nan(min.as_array()),312scalar_reduce_max_propagate_nan(max.as_array()),313)314})315}316}317318impl MinMaxKernel for [$T] {319type Scalar<'a> = $T;320321fn min_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {322fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_min(b))323.map(|s| s.reduce_min())324}325326fn max_ignore_nan_kernel(&self) -> Option<Self::Scalar<'_>> {327fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_max(b))328.map(|s| s.reduce_max())329}330331fn min_max_ignore_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {332fold_agg_min_max_kernel::<$N, $T, _>(333self,334None,335<$T>::NAN,336<$T>::NAN,337|(cmin, cmax), (min, max)| (cmin.simd_min(min), cmax.simd_max(max)),338)339.map(|(min, max)| (min.reduce_min(), max.reduce_max()))340}341342fn min_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {343fold_agg_kernel::<$N, $T, _>(self, None, <$T>::INFINITY, |a, b| {344(a.simd_lt(b) | a.simd_ne(a)).select(a, b)345})346.map(|s| scalar_reduce_min_propagate_nan(s.as_array()))347}348349fn max_propagate_nan_kernel(&self) -> Option<Self::Scalar<'_>> {350fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NEG_INFINITY, |a, b| {351(a.simd_gt(b) | a.simd_ne(a)).select(a, b)352})353.map(|s| scalar_reduce_max_propagate_nan(s.as_array()))354}355356fn min_max_propagate_nan_kernel(&self) -> Option<(Self::Scalar<'_>, Self::Scalar<'_>)> {357fold_agg_min_max_kernel::<$N, $T, _>(358self,359None,360<$T>::INFINITY,361<$T>::NEG_INFINITY,362|(cmin, cmax), (min, max)| {363(364(cmin.simd_lt(min) | cmin.simd_ne(cmin)).select(cmin, min),365(cmax.simd_gt(max) | cmax.simd_ne(cmax)).select(cmax, max),366)367},368)369.map(|(min, max)| {370(371scalar_reduce_min_propagate_nan(min.as_array()),372scalar_reduce_max_propagate_nan(max.as_array()),373)374})375}376}377};378}379380impl_min_max_kernel_float!(f32, 16);381impl_min_max_kernel_float!(f64, 8);382383384