Path: blob/main/crates/polars-compute/src/if_then_else/view.rs
6939 views
use std::mem::MaybeUninit;1use std::ops::Deref;2use std::sync::Arc;34use arrow::array::{Array, BinaryViewArray, MutablePlBinary, Utf8ViewArray, View};5use arrow::bitmap::Bitmap;6use arrow::buffer::Buffer;7use arrow::datatypes::ArrowDataType;8use polars_utils::aliases::{InitHashMaps, PlHashSet};910use super::IfThenElseKernel;11use crate::if_then_else::scalar::if_then_else_broadcast_both_scalar_64;1213// Makes a buffer and a set of views into that buffer from a set of strings.14// Does not allocate a buffer if not necessary.15fn make_buffer_and_views<const N: usize>(16strings: [&[u8]; N],17buffer_idx: u32,18) -> ([View; N], Option<Buffer<u8>>) {19let mut buf_data = Vec::new();20let views = strings.map(|s| {21let offset = buf_data.len().try_into().unwrap();22if s.len() > 12 {23buf_data.extend(s);24}25View::new_from_bytes(s, buffer_idx, offset)26});27let buf = (!buf_data.is_empty()).then(|| buf_data.into());28(views, buf)29}3031fn has_duplicate_buffers(bufs: &[Buffer<u8>]) -> bool {32let mut has_duplicate_buffers = false;33let mut bufset = PlHashSet::new();34for buf in bufs {35if !bufset.insert(buf.as_ptr()) {36has_duplicate_buffers = true;37break;38}39}40has_duplicate_buffers41}4243impl IfThenElseKernel for BinaryViewArray {44type Scalar<'a> = &'a [u8];4546fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {47let combined_buffers: Arc<_>;48let false_buffer_idx_offset: u32;49let mut has_duplicate_bufs = false;50if Arc::ptr_eq(if_true.data_buffers(), if_false.data_buffers()) {51// Share exact same buffers, no need to combine.52combined_buffers = if_true.data_buffers().clone();53false_buffer_idx_offset = 0;54} else {55// Put false buffers after true buffers.56let true_buffers = if_true.data_buffers().iter().cloned();57let false_buffers = if_false.data_buffers().iter().cloned();5859combined_buffers = true_buffers.chain(false_buffers).collect();60has_duplicate_bufs = has_duplicate_buffers(&combined_buffers);61false_buffer_idx_offset = if_true.data_buffers().len() as u32;62}6364let views = super::if_then_else_loop(65mask,66if_true.views(),67if_false.views(),68|m, t, f, o| if_then_else_view_rest(m, t, f, o, false_buffer_idx_offset),69|m, t, f, o| if_then_else_view_64(m, t, f, o, false_buffer_idx_offset),70);7172let validity = super::if_then_else_validity(mask, if_true.validity(), if_false.validity());7374let mut builder = MutablePlBinary::with_capacity(views.len());7576if has_duplicate_bufs {77unsafe {78builder.extend_non_null_views_unchecked_dedupe(79views.into_iter(),80combined_buffers.deref(),81)82};83} else {84unsafe {85builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())86};87}88builder89.freeze_with_dtype(if_true.dtype().clone())90.with_validity(validity)91}9293fn if_then_else_broadcast_true(94mask: &Bitmap,95if_true: Self::Scalar<'_>,96if_false: &Self,97) -> Self {98// It's cheaper if we put the false buffers first, that way we don't need to modify any views in the loop.99let false_buffers = if_false.data_buffers().iter().cloned();100let true_buffer_idx_offset: u32 = if_false.data_buffers().len() as u32;101let ([true_view], true_buffer) = make_buffer_and_views([if_true], true_buffer_idx_offset);102let combined_buffers: Arc<_> = false_buffers.chain(true_buffer).collect();103104let views = super::if_then_else_loop_broadcast_false(105true, // Invert the mask so we effectively broadcast true.106mask,107if_false.views(),108true_view,109if_then_else_broadcast_false_view_64,110);111112let validity = super::if_then_else_validity(mask, None, if_false.validity());113114let mut builder = MutablePlBinary::with_capacity(views.len());115116unsafe {117if has_duplicate_buffers(&combined_buffers) {118builder.extend_non_null_views_unchecked_dedupe(119views.into_iter(),120combined_buffers.deref(),121)122} else {123builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())124}125}126builder127.freeze_with_dtype(if_false.dtype().clone())128.with_validity(validity)129}130131fn if_then_else_broadcast_false(132mask: &Bitmap,133if_true: &Self,134if_false: Self::Scalar<'_>,135) -> Self {136// It's cheaper if we put the true buffers first, that way we don't need to modify any views in the loop.137let true_buffers = if_true.data_buffers().iter().cloned();138let false_buffer_idx_offset: u32 = if_true.data_buffers().len() as u32;139let ([false_view], false_buffer) =140make_buffer_and_views([if_false], false_buffer_idx_offset);141let combined_buffers: Arc<_> = true_buffers.chain(false_buffer).collect();142143let views = super::if_then_else_loop_broadcast_false(144false,145mask,146if_true.views(),147false_view,148if_then_else_broadcast_false_view_64,149);150151let validity = super::if_then_else_validity(mask, if_true.validity(), None);152153let mut builder = MutablePlBinary::with_capacity(views.len());154unsafe {155if has_duplicate_buffers(&combined_buffers) {156builder.extend_non_null_views_unchecked_dedupe(157views.into_iter(),158combined_buffers.deref(),159)160} else {161builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref())162}163};164builder165.freeze_with_dtype(if_true.dtype().clone())166.with_validity(validity)167}168169fn if_then_else_broadcast_both(170dtype: ArrowDataType,171mask: &Bitmap,172if_true: Self::Scalar<'_>,173if_false: Self::Scalar<'_>,174) -> Self {175let ([true_view, false_view], buffer) = make_buffer_and_views([if_true, if_false], 0);176let buffers: Arc<_> = buffer.into_iter().collect();177let views = super::if_then_else_loop_broadcast_both(178mask,179true_view,180false_view,181if_then_else_broadcast_both_scalar_64,182);183184let mut builder = MutablePlBinary::with_capacity(views.len());185unsafe {186if has_duplicate_buffers(&buffers) {187builder.extend_non_null_views_unchecked_dedupe(views.into_iter(), buffers.deref())188} else {189builder.extend_non_null_views_unchecked(views.into_iter(), buffers.deref())190}191};192builder.freeze_with_dtype(dtype)193}194}195196impl IfThenElseKernel for Utf8ViewArray {197type Scalar<'a> = &'a str;198199fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self {200let ret =201IfThenElseKernel::if_then_else(mask, &if_true.to_binview(), &if_false.to_binview());202unsafe { ret.to_utf8view_unchecked() }203}204205fn if_then_else_broadcast_true(206mask: &Bitmap,207if_true: Self::Scalar<'_>,208if_false: &Self,209) -> Self {210let ret = IfThenElseKernel::if_then_else_broadcast_true(211mask,212if_true.as_bytes(),213&if_false.to_binview(),214);215unsafe { ret.to_utf8view_unchecked() }216}217218fn if_then_else_broadcast_false(219mask: &Bitmap,220if_true: &Self,221if_false: Self::Scalar<'_>,222) -> Self {223let ret = IfThenElseKernel::if_then_else_broadcast_false(224mask,225&if_true.to_binview(),226if_false.as_bytes(),227);228unsafe { ret.to_utf8view_unchecked() }229}230231fn if_then_else_broadcast_both(232dtype: ArrowDataType,233mask: &Bitmap,234if_true: Self::Scalar<'_>,235if_false: Self::Scalar<'_>,236) -> Self {237let ret: BinaryViewArray = IfThenElseKernel::if_then_else_broadcast_both(238dtype,239mask,240if_true.as_bytes(),241if_false.as_bytes(),242);243unsafe { ret.to_utf8view_unchecked() }244}245}246247pub fn if_then_else_view_rest(248mask: u64,249if_true: &[View],250if_false: &[View],251out: &mut [MaybeUninit<View>],252false_buffer_idx_offset: u32,253) {254assert!(if_true.len() <= out.len()); // Removes bounds checks in inner loop.255let true_it = if_true.iter();256let false_it = if_false.iter();257for (i, (t, f)) in true_it.zip(false_it).enumerate() {258// Written like this, this loop *should* be branchless.259// Unfortunately we're still dependent on the compiler.260let m = (mask >> i) & 1 != 0;261let src = if m { t } else { f };262let mut v = *src;263let offset = if m | (v.length <= 12) {264// Yes, | instead of || is intentional.2650266} else {267false_buffer_idx_offset268};269v.buffer_idx += offset;270out[i] = MaybeUninit::new(v);271}272}273274pub fn if_then_else_view_64(275mask: u64,276if_true: &[View; 64],277if_false: &[View; 64],278out: &mut [MaybeUninit<View>; 64],279false_buffer_idx_offset: u32,280) {281if_then_else_view_rest(mask, if_true, if_false, out, false_buffer_idx_offset)282}283284// Using the scalar variant of this works, but was slower, we want to select a source pointer and285// then copy it. Using this version for the integers results in branches.286pub fn if_then_else_broadcast_false_view_64(287mask: u64,288if_true: &[View; 64],289if_false: View,290out: &mut [MaybeUninit<View>; 64],291) {292assert!(if_true.len() == out.len()); // Removes bounds checks in inner loop.293for (i, t) in if_true.iter().enumerate() {294let src = if (mask >> i) & 1 != 0 { t } else { &if_false };295out[i] = MaybeUninit::new(*src);296}297}298299300