Path: blob/main/crates/polars-compute/src/if_then_else/simd.rs
6939 views
#[cfg(target_arch = "x86_64")]1use std::mem::MaybeUninit;2#[cfg(target_arch = "x86_64")]3use std::simd::{Mask, Simd, SimdElement};45use arrow::array::PrimitiveArray;6use arrow::bitmap::Bitmap;7use arrow::datatypes::ArrowDataType;89use super::{10IfThenElseKernel, if_then_else_loop, if_then_else_loop_broadcast_both,11if_then_else_loop_broadcast_false, if_then_else_validity, scalar,12};1314#[cfg(target_arch = "x86_64")]15fn select_simd_64<T: Copy + SimdElement>(16mask: u64,17if_true: Simd<T, 64>,18if_false: Simd<T, 64>,19out: &mut [MaybeUninit<T>; 64],20) {21let mv = Mask::<<T as SimdElement>::Mask, 64>::from_bitmask(mask);22let ret = mv.select(if_true, if_false);23unsafe {24let src = ret.as_array().as_ptr() as *const MaybeUninit<T>;25core::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), 64);26}27}2829#[cfg(target_arch = "x86_64")]30fn if_then_else_simd_64<T: Copy + SimdElement>(31mask: u64,32if_true: &[T; 64],33if_false: &[T; 64],34out: &mut [MaybeUninit<T>; 64],35) {36select_simd_64(37mask,38Simd::from_slice(if_true),39Simd::from_slice(if_false),40out,41)42}4344#[cfg(target_arch = "x86_64")]45fn if_then_else_broadcast_false_simd_64<T: Copy + SimdElement>(46mask: u64,47if_true: &[T; 64],48if_false: T,49out: &mut [MaybeUninit<T>; 64],50) {51select_simd_64(mask, Simd::from_slice(if_true), Simd::splat(if_false), out)52}5354#[cfg(target_arch = "x86_64")]55fn if_then_else_broadcast_both_simd_64<T: Copy + SimdElement>(56mask: u64,57if_true: T,58if_false: T,59out: &mut [MaybeUninit<T>; 64],60) {61select_simd_64(mask, Simd::splat(if_true), Simd::splat(if_false), out)62}6364macro_rules! impl_if_then_else {65($T: ty) => {66impl IfThenElseKernel for PrimitiveArray<$T> {67type Scalar<'a> = $T;6869fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {70let values = if_then_else_loop(71mask,72if_true.values(),73if_false.values(),74scalar::if_then_else_scalar_rest,75// Auto-generated SIMD was slower on ARM.76#[cfg(target_arch = "x86_64")]77if_then_else_simd_64,78#[cfg(not(target_arch = "x86_64"))]79scalar::if_then_else_scalar_64,80);81let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity());82PrimitiveArray::from_vec(values).with_validity(validity)83}8485fn if_then_else_broadcast_true(86mask: &Bitmap,87if_true: Self::Scalar<'_>,88if_false: &Self,89) -> Self {90let values = if_then_else_loop_broadcast_false(91true,92mask,93if_false.values(),94if_true,95// Auto-generated SIMD was slower on ARM.96#[cfg(target_arch = "x86_64")]97if_then_else_broadcast_false_simd_64,98#[cfg(not(target_arch = "x86_64"))]99scalar::if_then_else_broadcast_false_scalar_64,100);101let validity = if_then_else_validity(mask, None, if_false.validity());102PrimitiveArray::from_vec(values).with_validity(validity)103}104105fn if_then_else_broadcast_false(106mask: &Bitmap,107if_true: &Self,108if_false: Self::Scalar<'_>,109) -> Self {110let values = if_then_else_loop_broadcast_false(111false,112mask,113if_true.values(),114if_false,115// Auto-generated SIMD was slower on ARM.116#[cfg(target_arch = "x86_64")]117if_then_else_broadcast_false_simd_64,118#[cfg(not(target_arch = "x86_64"))]119scalar::if_then_else_broadcast_false_scalar_64,120);121let validity = if_then_else_validity(mask, if_true.validity(), None);122PrimitiveArray::from_vec(values).with_validity(validity)123}124125fn if_then_else_broadcast_both(126_dtype: ArrowDataType,127mask: &Bitmap,128if_true: Self::Scalar<'_>,129if_false: Self::Scalar<'_>,130) -> Self {131let values = if_then_else_loop_broadcast_both(132mask,133if_true,134if_false,135// Auto-generated SIMD was slower on ARM.136#[cfg(target_arch = "x86_64")]137if_then_else_broadcast_both_simd_64,138#[cfg(not(target_arch = "x86_64"))]139scalar::if_then_else_broadcast_both_scalar_64,140);141PrimitiveArray::from_vec(values)142}143}144};145}146147impl_if_then_else!(i8);148impl_if_then_else!(i16);149impl_if_then_else!(i32);150impl_if_then_else!(i64);151impl_if_then_else!(u8);152impl_if_then_else!(u16);153impl_if_then_else!(u32);154impl_if_then_else!(u64);155impl_if_then_else!(f32);156impl_if_then_else!(f64);157158159