// Path: crates/polars-compute/src/comparisons/simd.rs
use std::ptr;
use std::simd::prelude::{Simd, SimdPartialEq, SimdPartialOrd};

use arrow::array::PrimitiveArray;
use arrow::bitmap::Bitmap;
use arrow::types::NativeType;
use bytemuck::Pod;

use super::{TotalEqKernel, TotalOrdKernel};

/// Applies `f` to corresponding `N`-element chunks of `lhs` and `rhs` and
/// packs the `N`-bit masks it returns into a `Bitmap` of length `lhs.len()`.
///
/// `M` is the integer type that holds one chunk's mask; the `assert_eq!`
/// below pins `N == 8 * size_of::<M>()` so every chunk fills `M` exactly.
/// A trailing partial chunk (when `n % N != 0`) is zero-padded before `f`
/// is called; `Bitmap::from_u8_vec(v, n)` truncates to `n` bits, so the
/// mask bits produced by the padding lanes are discarded.
///
/// # Panics
/// Panics if `lhs` and `rhs` differ in length.
///
/// Only the raw values buffers are compared here; validity (null) bitmaps
/// are not consulted by this helper.
fn apply_binary_kernel<const N: usize, M: Pod, T, F>(
    lhs: &PrimitiveArray<T>,
    rhs: &PrimitiveArray<T>,
    mut f: F,
) -> Bitmap
where
    T: NativeType,
    F: FnMut(&[T; N], &[T; N]) -> M,
{
    // One mask bit per SIMD lane: M must have exactly N bits.
    assert_eq!(N, size_of::<M>() * 8);
    assert!(lhs.len() == rhs.len());
    let n = lhs.len();

    let lhs_buf = lhs.values().as_slice();
    let rhs_buf = rhs.values().as_slice();
    let lhs_chunks = lhs_buf.chunks_exact(N);
    let rhs_chunks = rhs_buf.chunks_exact(N);
    let lhs_rest = lhs_chunks.remainder();
    let rhs_rest = rhs_chunks.remainder();

    // Write each mask straight into the Vec's spare capacity through a raw
    // pointer, then commit the length once at the end with set_len.
    let num_masks = n.div_ceil(N);
    let mut v: Vec<u8> = Vec::with_capacity(num_masks * size_of::<M>());
    let mut p = v.as_mut_ptr() as *mut M;
    for (l, r) in lhs_chunks.zip(rhs_chunks) {
        unsafe {
            // SAFETY: chunks_exact yields slices of length exactly N, so the
            // &[T] -> &[T; N] conversion cannot fail.
            let mask = f(
                l.try_into().unwrap_unchecked(),
                r.try_into().unwrap_unchecked(),
            );
            // SAFETY: this loop writes floor(n / N) masks, which is within
            // the num_masks * size_of::<M>() bytes reserved above.
            // write_unaligned because p is only guaranteed u8-aligned.
            p.write_unaligned(mask);
            p = p.wrapping_add(1);
        }
    }

    if !n.is_multiple_of(N) {
        // Zero-pad the trailing partial chunk up to a full N-lane block.
        let mut l: [T; N] = [T::zeroed(); N];
        let mut r: [T; N] = [T::zeroed(); N];
        unsafe {
            // SAFETY: the remainders hold the final n % N (< N) elements, so
            // the copies stay in bounds of both source and destination; this
            // final write is the num_masks'th mask and still within capacity.
            ptr::copy_nonoverlapping(lhs_rest.as_ptr(), l.as_mut_ptr(), n % N);
            ptr::copy_nonoverlapping(rhs_rest.as_ptr(), r.as_mut_ptr(), n % N);
            p.write_unaligned(f(&l, &r));
        }
    }

    unsafe {
        // SAFETY: all num_masks * size_of::<M>() bytes were initialized by
        // the write_unaligned calls above (num_masks = ceil(n / N) masks).
        v.set_len(num_masks * size_of::<M>());
    }

    Bitmap::from_u8_vec(v, n)
}

/// Unary counterpart of [`apply_binary_kernel`]: applies `f` to `N`-element
/// chunks of a single array and packs the returned `N`-bit masks into a
/// `Bitmap`. Used by the `*_broadcast` kernels, where the scalar side is
/// pre-splatted into a SIMD register by the caller's closure.
///
/// Same invariants as the binary version: `N == 8 * size_of::<M>()`, the
/// trailing partial chunk is zero-padded, and the excess bits are dropped
/// by `Bitmap::from_u8_vec(v, n)`.
fn apply_unary_kernel<const N: usize, M: Pod, T, F>(arg: &PrimitiveArray<T>, mut f: F) -> Bitmap
where
    T: NativeType,
    F: FnMut(&[T; N]) -> M,
{
    // One mask bit per SIMD lane: M must have exactly N bits.
    assert_eq!(N, size_of::<M>() * 8);
    let n = arg.len();

    let arg_buf = arg.values().as_slice();
    let arg_chunks = arg_buf.chunks_exact(N);
    let arg_rest = arg_chunks.remainder();

    // Same spare-capacity + set_len pattern as apply_binary_kernel.
    let num_masks = n.div_ceil(N);
    let mut v: Vec<u8> = Vec::with_capacity(num_masks * size_of::<M>());
    let mut p = v.as_mut_ptr() as *mut M;
    for a in arg_chunks {
        unsafe {
            // SAFETY: chunks_exact yields slices of length exactly N, so the
            // &[T] -> &[T; N] conversion cannot fail; the write stays within
            // the reserved capacity (see apply_binary_kernel).
            let mask = f(a.try_into().unwrap_unchecked());
            p.write_unaligned(mask);
            p = p.wrapping_add(1);
        }
    }

    if !n.is_multiple_of(N) {
        // Zero-pad the trailing partial chunk up to a full N-lane block.
        let mut a: [T; N] = [T::zeroed(); N];
        unsafe {
            // SAFETY: the remainder holds the final n % N (< N) elements;
            // this is the num_masks'th mask write, still within capacity.
            ptr::copy_nonoverlapping(arg_rest.as_ptr(), a.as_mut_ptr(), n % N);
            p.write_unaligned(f(&a));
        }
    }

    unsafe {
        // SAFETY: all num_masks * size_of::<M>() bytes were initialized by
        // the write_unaligned calls above.
        v.set_len(num_masks * size_of::<M>());
    }

    Bitmap::from_u8_vec(v, n)
}

/// Implements `TotalEqKernel` and `TotalOrdKernel` for an integer
/// `PrimitiveArray<$T>`. `$width` is the SIMD lane count and `$mask` the
/// integer type receiving the `$width`-bit comparison mask (they must agree:
/// `$width == 8 * size_of::<$mask>()`, enforced by the apply helpers).
///
/// For integers the total order coincides with the natural order, so the
/// portable-SIMD comparisons are used directly. Methods not written here
/// (e.g. the non-broadcast gt/ge kernels) come from the traits' defaults.
macro_rules! impl_int_total_ord_kernel {
    ($T: ty, $width: literal, $mask: ty) => {
        impl TotalEqKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    Simd::from(*l).simd_eq(Simd::from(*r)).to_bitmask() as $mask
                })
            }

            fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    Simd::from(*l).simd_ne(Simd::from(*r)).to_bitmask() as $mask
                })
            }

            fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                // Splat the scalar once, outside the per-chunk closure.
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_eq(r).to_bitmask() as $mask
                })
            }

            fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_ne(r).to_bitmask() as $mask
                })
            }
        }

        impl TotalOrdKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    Simd::from(*l).simd_lt(Simd::from(*r)).to_bitmask() as $mask
                })
            }

            fn tot_le_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    Simd::from(*l).simd_le(Simd::from(*r)).to_bitmask() as $mask
                })
            }

            fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_lt(r).to_bitmask() as $mask
                })
            }

            fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_le(r).to_bitmask() as $mask
                })
            }

            fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_gt(r).to_bitmask() as $mask
                })
            }

            fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let r = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    Simd::from(*l).simd_ge(r).to_bitmask() as $mask
                })
            }
        }
    };
}

/// Implements the kernels for float `PrimitiveArray`s under a total
/// ordering of the values. As the expressions below encode: all NaNs
/// compare equal to each other (`(lhs_is_nan & rhs_is_nan) | eq`), and for
/// the ordering kernels a NaN behaves as the greatest value — `x < NaN` is
/// true for non-NaN `x`, while `NaN < x` is always false.
///
/// `x.simd_ne(x)` is the lane-wise NaN test: under IEEE-754 semantics only
/// NaN compares unequal to itself.
macro_rules! impl_float_total_ord_kernel {
    ($T: ty, $width: literal, $mask: ty) => {
        impl TotalEqKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    let ls = Simd::from(*l);
                    let rs = Simd::from(*r);
                    let lhs_is_nan = ls.simd_ne(ls);
                    let rhs_is_nan = rs.simd_ne(rs);
                    // Equal if both NaN, or IEEE-equal otherwise.
                    ((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs)).to_bitmask() as $mask
                })
            }

            fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    let ls = Simd::from(*l);
                    let rs = Simd::from(*r);
                    let lhs_is_nan = ls.simd_ne(ls);
                    let rhs_is_nan = rs.simd_ne(rs);
                    // Exact negation of tot_eq_kernel.
                    (!((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs))).to_bitmask() as $mask
                })
            }

            fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                // Splat the scalar once; the NaN test on it is loop-invariant
                // but cheap and kept inside for symmetry with the array case.
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let lhs_is_nan = ls.simd_ne(ls);
                    let rhs_is_nan = rs.simd_ne(rs);
                    ((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs)).to_bitmask() as $mask
                })
            }

            fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let lhs_is_nan = ls.simd_ne(ls);
                    let rhs_is_nan = rs.simd_ne(rs);
                    (!((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs))).to_bitmask() as $mask
                })
            }
        }

        impl TotalOrdKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    let ls = Simd::from(*l);
                    let rs = Simd::from(*r);
                    let lhs_is_nan = ls.simd_ne(ls);
                    // l < r: false when l is NaN; !(l >= r) is also true when
                    // r is NaN (IEEE comparisons with NaN are false), which
                    // matches NaN sorting after every other value.
                    (!(lhs_is_nan | ls.simd_ge(rs))).to_bitmask() as $mask
                })
            }

            fn tot_le_kernel(&self, other: &Self) -> Bitmap {
                apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| {
                    let ls = Simd::from(*l);
                    let rs = Simd::from(*r);
                    let rhs_is_nan = rs.simd_ne(rs);
                    // l <= r: always true when r is NaN (the greatest value);
                    // otherwise IEEE <= (false when l is NaN, as required).
                    (rhs_is_nan | ls.simd_le(rs)).to_bitmask() as $mask
                })
            }

            fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let lhs_is_nan = ls.simd_ne(ls);
                    (!(lhs_is_nan | ls.simd_ge(rs))).to_bitmask() as $mask
                })
            }

            fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let rhs_is_nan = rs.simd_ne(rs);
                    (rhs_is_nan | ls.simd_le(rs)).to_bitmask() as $mask
                })
            }

            fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let rhs_is_nan = rs.simd_ne(rs);
                    // l > r: false when r is NaN; !(r >= l) is also true when
                    // l is NaN, i.e. NaN > any non-NaN value.
                    (!(rhs_is_nan | rs.simd_ge(ls))).to_bitmask() as $mask
                })
            }

            fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
                let rs = Simd::splat(*other);
                apply_unary_kernel::<$width, $mask, _, _>(self, |l| {
                    let ls = Simd::from(*l);
                    let lhs_is_nan = ls.simd_ne(ls);
                    // l >= r: always true when l is NaN; otherwise IEEE r <= l
                    // (false when r is NaN, as required).
                    (lhs_is_nan | rs.simd_le(ls)).to_bitmask() as $mask
                })
            }
        }
    };
}

// Lane count and mask type per element type; the mask integer always has
// exactly `width` bits, as the apply helpers assert.
impl_int_total_ord_kernel!(u8, 32, u32);
impl_int_total_ord_kernel!(u16, 16, u16);
impl_int_total_ord_kernel!(u32, 8, u8);
impl_int_total_ord_kernel!(u64, 8, u8);
impl_int_total_ord_kernel!(i8, 32, u32);
impl_int_total_ord_kernel!(i16, 16, u16);
impl_int_total_ord_kernel!(i32, 8, u8);
impl_int_total_ord_kernel!(i64, 8, u8);
impl_float_total_ord_kernel!(f32, 8, u8);
impl_float_total_ord_kernel!(f64, 8, u8);