Path: blob/main/crates/polars-compute/src/if_then_else/mod.rs
use std::mem::MaybeUninit;

use arrow::array::{Array, PrimitiveArray};
use arrow::bitmap::utils::SlicesIterator;
use arrow::bitmap::{self, Bitmap};
use arrow::datatypes::ArrowDataType;

use crate::NotSimdPrimitive;

mod array;
mod boolean;
mod list;
mod scalar;
#[cfg(feature = "simd")]
mod simd;
mod view;

pub trait IfThenElseKernel: Sized + Array {
    type Scalar<'a>;

    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self;
    fn if_then_else_broadcast_true(
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: &Self,
    ) -> Self;
    fn if_then_else_broadcast_false(
        mask: &Bitmap,
        if_true: &Self,
        if_false: Self::Scalar<'_>,
    ) -> Self;
    fn if_then_else_broadcast_both(
        dtype: ArrowDataType,
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: Self::Scalar<'_>,
    ) -> Self;
}

impl<T: NotSimdPrimitive> IfThenElseKernel for PrimitiveArray<T> {
    type Scalar<'a> = T;

    fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {
        let values = if_then_else_loop(
            mask,
            if_true.values(),
            if_false.values(),
            scalar::if_then_else_scalar_rest,
            scalar::if_then_else_scalar_64,
        );
        let validity = if_then_else_validity(mask, if_true.validity(), if_false.validity());
        PrimitiveArray::from_vec(values).with_validity(validity)
    }

    fn if_then_else_broadcast_true(
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: &Self,
    ) -> Self {
        let values = if_then_else_loop_broadcast_false(
            true,
            mask,
            if_false.values(),
            if_true,
            scalar::if_then_else_broadcast_false_scalar_64,
        );
        let validity = if_then_else_validity(mask, None, if_false.validity());
        PrimitiveArray::from_vec(values).with_validity(validity)
    }

    fn if_then_else_broadcast_false(
        mask: &Bitmap,
        if_true: &Self,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        let values = if_then_else_loop_broadcast_false(
            false,
            mask,
            if_true.values(),
            if_false,
            scalar::if_then_else_broadcast_false_scalar_64,
        );
        let validity = if_then_else_validity(mask, if_true.validity(), None);
        PrimitiveArray::from_vec(values).with_validity(validity)
    }

    fn if_then_else_broadcast_both(
        _dtype: ArrowDataType,
        mask: &Bitmap,
        if_true: Self::Scalar<'_>,
        if_false: Self::Scalar<'_>,
    ) -> Self {
        let values = if_then_else_loop_broadcast_both(
            mask,
            if_true,
            if_false,
            scalar::if_then_else_broadcast_both_scalar_64,
        );
        PrimitiveArray::from_vec(values)
    }
}
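/// Combines the validity bitmaps of the two branches according to `mask`:
/// where a mask bit is set the output is valid iff `if_true` is valid, and
/// elsewhere iff `if_false` is valid, i.e. `(m & t) | (!m & f)` evaluated a
/// whole word at a time. A missing bitmap counts as all-valid, which is why
/// the mixed cases reduce to `mask | f` and `!mask | t`.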
pub fn if_then_else_validity(
    mask: &Bitmap,
    if_true: Option<&Bitmap>,
    if_false: Option<&Bitmap>,
) -> Option<Bitmap> {
    match (if_true, if_false) {
        (None, None) => None,
        (None, Some(f)) => Some(mask | f),
        (Some(t), None) => Some(bitmap::binary(mask, t, |m, t| !m | t)),
        (Some(t), Some(f)) => Some(bitmap::ternary(mask, t, f, |m, t, f| (m & t) | (!m & f))),
    }
}

fn if_then_else_extend<B, ET: Fn(&mut B, usize, usize), EF: Fn(&mut B, usize, usize)>(
    builder: &mut B,
    mask: &Bitmap,
    extend_true: ET,
    extend_false: EF,
) {
    let mut last_true_end = 0;
    for (start, len) in SlicesIterator::new(mask) {
        if start != last_true_end {
            extend_false(builder, last_true_end, start - last_true_end);
        };
        extend_true(builder, start, len);
        last_true_end = start + len;
    }
    if last_true_end != mask.len() {
        extend_false(builder, last_true_end, mask.len() - last_true_end)
    }
}

fn if_then_else_loop<T, F, F64>(
    mask: &Bitmap,
    if_true: &[T],
    if_false: &[T],
    process_var: F,
    process_chunk: F64,
) -> Vec<T>
where
    T: Copy,
    F: Fn(u64, &[T], &[T], &mut [MaybeUninit<T>]),
    F64: Fn(u64, &[T; 64], &[T; 64], &mut [MaybeUninit<T>; 64]),
{
    assert_eq!(mask.len(), if_true.len());
    assert_eq!(mask.len(), if_false.len());

    let mut ret = Vec::with_capacity(mask.len());
    let out = &mut ret.spare_capacity_mut()[..mask.len()];

    // Handle prefix.
    let aligned = mask.aligned::<u64>();
    let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen());
    let (start_false, rest_false) = if_false.split_at(aligned.prefix_bitlen());
    let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
    if aligned.prefix_bitlen() > 0 {
        process_var(aligned.prefix(), start_true, start_false, start_out);
    }

    // Handle bulk.
    let mut true_chunks = rest_true.chunks_exact(64);
    let mut false_chunks = rest_false.chunks_exact(64);
    let mut out_chunks = rest_out.chunks_exact_mut(64);
    let combined = true_chunks
        .by_ref()
        .zip(false_chunks.by_ref())
        .zip(out_chunks.by_ref());
    for (i, ((tc, fc), oc)) in combined.enumerate() {
        let m = unsafe { *aligned.bulk().get_unchecked(i) };
        process_chunk(
            m,
            tc.try_into().unwrap(),
            fc.try_into().unwrap(),
            oc.try_into().unwrap(),
        );
    }

    // Handle suffix.
    if aligned.suffix_bitlen() > 0 {
        process_var(
            aligned.suffix(),
            true_chunks.remainder(),
            false_chunks.remainder(),
            out_chunks.into_remainder(),
        );
    }

    unsafe {
        ret.set_len(mask.len());
    }
    ret
}
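/// Same prefix/bulk/suffix structure as `if_then_else_loop`, but with the
/// false branch broadcast from a single scalar. Passing `invert_mask = true`
/// XOR-flips every mask word, which lets the true-broadcast kernels reuse
/// this code path with the roles of the two branches swapped.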
fn if_then_else_loop_broadcast_false<T, F64>(
    invert_mask: bool, // Allows code reuse for both false and true broadcasts.
    mask: &Bitmap,
    if_true: &[T],
    if_false: T,
    process_chunk: F64,
) -> Vec<T>
where
    T: Copy,
    F64: Fn(u64, &[T; 64], T, &mut [MaybeUninit<T>; 64]),
{
    assert_eq!(mask.len(), if_true.len());

    let mut ret = Vec::with_capacity(mask.len());
    let out = &mut ret.spare_capacity_mut()[..mask.len()];

    // XOR with all 1's inverts the mask.
    let xor_inverter = if invert_mask { u64::MAX } else { 0 };

    // Handle prefix.
    let aligned = mask.aligned::<u64>();
    let (start_true, rest_true) = if_true.split_at(aligned.prefix_bitlen());
    let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
    if aligned.prefix_bitlen() > 0 {
        scalar::if_then_else_broadcast_false_scalar_rest(
            aligned.prefix() ^ xor_inverter,
            start_true,
            if_false,
            start_out,
        );
    }

    // Handle bulk.
    let mut true_chunks = rest_true.chunks_exact(64);
    let mut out_chunks = rest_out.chunks_exact_mut(64);
    let combined = true_chunks.by_ref().zip(out_chunks.by_ref());
    for (i, (tc, oc)) in combined.enumerate() {
        let m = unsafe { *aligned.bulk().get_unchecked(i) } ^ xor_inverter;
        process_chunk(m, tc.try_into().unwrap(), if_false, oc.try_into().unwrap());
    }

    // Handle suffix.
    if aligned.suffix_bitlen() > 0 {
        scalar::if_then_else_broadcast_false_scalar_rest(
            aligned.suffix() ^ xor_inverter,
            true_chunks.remainder(),
            if_false,
            out_chunks.into_remainder(),
        );
    }

    unsafe {
        ret.set_len(mask.len());
    }
    ret
}

fn if_then_else_loop_broadcast_both<T, F64>(
    mask: &Bitmap,
    if_true: T,
    if_false: T,
    generate_chunk: F64,
) -> Vec<T>
where
    T: Copy,
    F64: Fn(u64, T, T, &mut [MaybeUninit<T>; 64]),
{
    let mut ret = Vec::with_capacity(mask.len());
    let out = &mut ret.spare_capacity_mut()[..mask.len()];

    // Handle prefix.
    let aligned = mask.aligned::<u64>();
    let (start_out, rest_out) = out.split_at_mut(aligned.prefix_bitlen());
    scalar::if_then_else_broadcast_both_scalar_rest(aligned.prefix(), if_true, if_false, start_out);

    // Handle bulk.
    let mut out_chunks = rest_out.chunks_exact_mut(64);
    for (i, oc) in out_chunks.by_ref().enumerate() {
        let m = unsafe { *aligned.bulk().get_unchecked(i) };
        generate_chunk(m, if_true, if_false, oc.try_into().unwrap());
    }

    // Handle suffix.
    if aligned.suffix_bitlen() > 0 {
        scalar::if_then_else_broadcast_both_scalar_rest(
            aligned.suffix(),
            if_true,
            if_false,
            out_chunks.into_remainder(),
        );
    }

    unsafe {
        ret.set_len(mask.len());
    }
    ret
}
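// Illustrative usage sketch: a minimal check of the non-broadcast kernel on a
// primitive array. It assumes `Bitmap` implements `FromIterator<bool>` (as in
// arrow2) and that `PrimitiveArray<u32>` has an `IfThenElseKernel` impl with
// or without the `simd` feature (from this module or from `simd`).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn if_then_else_picks_per_mask_bit() {
        // True mask bits take values from `if_true`, false bits from `if_false`.
        let mask: Bitmap = [true, false, true, false].into_iter().collect();
        let if_true = PrimitiveArray::from_vec(vec![1u32, 2, 3, 4]);
        let if_false = PrimitiveArray::from_vec(vec![10u32, 20, 30, 40]);

        let out = IfThenElseKernel::if_then_else(&mask, &if_true, &if_false);
        assert_eq!(out.values().as_slice(), &[1, 20, 3, 40][..]);
        // Neither input had a validity bitmap, so the result has none either.
        assert!(out.validity().is_none());
    }
}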