// Path: blob/main/crates/polars-compute/src/comparisons/array.rs
// 8398 views
use arrow::array::{1Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,2FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array,3Utf8ViewArray,4};5use arrow::bitmap::Bitmap;6use arrow::bitmap::utils::count_zeros;7use arrow::datatypes::ArrowDataType;8use arrow::legacy::utils::CustomIterTools;9use arrow::types::{days_ms, i256, months_days_ns};10use polars_utils::float16::pf16;1112use super::TotalEqKernel;13use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel};1415/// Condenses a bitmap of n * width elements into one with n elements.16///17/// For each block of width bits a zero count is done. The block of bits is then18/// replaced with a single bit: the result of true_zero_count(zero_count).19fn agg_array_bitmap<F>(bm: Bitmap, width: usize, true_zero_count: F) -> Bitmap20where21F: Fn(usize) -> bool,22{23if bm.len() == 1 {24bm25} else {26assert!(width > 0 && bm.len().is_multiple_of(width));2728let (slice, offset, _len) = bm.as_slice();29(0..bm.len() / width)30.map(|i| true_zero_count(count_zeros(slice, offset + i * width, width)))31.collect()32}33}3435impl TotalEqKernel for FixedSizeListArray {36type Scalar = Box<dyn Array>;3738fn tot_eq_kernel(&self, other: &Self) -> Bitmap {39// Nested comparison always done with eq_missing, propagating doesn't40// make any sense.4142assert_eq!(self.len(), other.len());43let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_storage() else {44panic!("array comparison called with non-array type");45};46let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_storage()47else {48panic!("array comparison called with non-array type");49};50assert_eq!(self_type.dtype(), other_type.dtype());5152if self_width != other_width {53return Bitmap::new_with_value(false, self.len());54}5556if *self_width == 0 {57return Bitmap::new_with_value(true, self.len());58}5960// @TODO: It is probably worth it to 
dispatch to a special kernel for when there are61// several nested arrays because that can be rather slow with this code.62let inner = array_tot_eq_missing_kernel(self.values().as_ref(), other.values().as_ref());6364agg_array_bitmap(inner, self.size(), |zeroes| zeroes == 0)65}6667fn tot_ne_kernel(&self, other: &Self) -> Bitmap {68assert_eq!(self.len(), other.len());69let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_storage() else {70panic!("array comparison called with non-array type");71};72let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_storage()73else {74panic!("array comparison called with non-array type");75};76assert_eq!(self_type.dtype(), other_type.dtype());7778if self_width != other_width {79return Bitmap::new_with_value(true, self.len());80}8182if *self_width == 0 {83return Bitmap::new_with_value(false, self.len());84}8586// @TODO: It is probably worth it to dispatch to a special kernel for when there are87// several nested arrays because that can be rather slow with this code.88let inner = array_tot_ne_missing_kernel(self.values().as_ref(), other.values().as_ref());8990agg_array_bitmap(inner, self.size(), |zeroes| zeroes < self.size())91}9293fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {94let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_storage() else {95panic!("array comparison called with non-array type");96};97assert_eq!(self_type.dtype(), other.dtype().to_storage());9899let width = *width;100101if width != other.len() {102return Bitmap::new_with_value(false, self.len());103}104105if width == 0 {106return Bitmap::new_with_value(true, self.len());107}108109// @TODO: It is probably worth it to dispatch to a special kernel for when there are110// several nested arrays because that can be rather slow with this code.111array_fsl_tot_eq_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width)112}113114fn tot_ne_kernel_broadcast(&self, other: 
&Self::Scalar) -> Bitmap {115let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_storage() else {116panic!("array comparison called with non-array type");117};118assert_eq!(self_type.dtype(), other.dtype().to_storage());119120let width = *width;121122if width != other.len() {123return Bitmap::new_with_value(true, self.len());124}125126if width == 0 {127return Bitmap::new_with_value(false, self.len());128}129130// @TODO: It is probably worth it to dispatch to a special kernel for when there are131// several nested arrays because that can be rather slow with this code.132array_fsl_tot_ne_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width)133}134}135136macro_rules! compare {137($lhs:expr, $rhs:expr, $length:expr, $width:expr, $op:path, $true_op:expr) => {{138let lhs = $lhs;139let rhs = $rhs;140141macro_rules! call_binary {142($T:ty) => {{143let values: &$T = $lhs.as_any().downcast_ref().unwrap();144let scalar: &$T = $rhs.as_any().downcast_ref().unwrap();145146(0..$length)147.map(move |i| {148// @TODO: I feel like there is a better way to do this.149let mut values: $T = values.clone();150<$T>::slice(&mut values, i * $width, $width);151152$true_op($op(&values, scalar))153})154.collect_trusted()155}};156}157158assert_eq!(lhs.dtype(), rhs.dtype());159160use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};161match lhs.dtype().to_physical_type() {162PH::Boolean => call_binary!(BooleanArray),163PH::BinaryView => call_binary!(BinaryViewArray),164PH::Utf8View => call_binary!(Utf8ViewArray),165PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),166PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),167PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),168PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),169PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),170PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),171PH::Primitive(PR::UInt16) => 
call_binary!(PrimitiveArray<u16>),172PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),173PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),174PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),175PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<pf16>),176PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),177PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),178PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),179PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),180PH::Primitive(PR::MonthDayNano) => {181call_binary!(PrimitiveArray<months_days_ns>)182},183PH::Primitive(PR::MonthDayMillis) => unimplemented!(),184185#[cfg(feature = "dtype-array")]186PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),187#[cfg(not(feature = "dtype-array"))]188PH::FixedSizeList => todo!(189"Comparison of FixedSizeListArray is not supported without dtype-array feature"190),191192PH::Null => call_binary!(NullArray),193PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),194PH::Binary => call_binary!(BinaryArray<i32>),195PH::LargeBinary => call_binary!(BinaryArray<i64>),196PH::Utf8 => call_binary!(Utf8Array<i32>),197PH::LargeUtf8 => call_binary!(Utf8Array<i64>),198PH::List => call_binary!(ListArray<i32>),199PH::LargeList => call_binary!(ListArray<i64>),200PH::Struct => call_binary!(StructArray),201PH::Union => todo!("Comparison of UnionArrays is not yet supported"),202PH::Map => todo!("Comparison of MapArrays is not yet supported"),203PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),204PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),205PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),206PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),207PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),208PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),209PH::Dictionary(I::UInt16) => 
call_binary!(DictionaryArray<u16>),210PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),211PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),212PH::Dictionary(I::UInt128) => call_binary!(DictionaryArray<u128>),213}214}};215}216217fn array_fsl_tot_eq_missing_kernel(218values: &dyn Array,219scalar: &dyn Array,220length: usize,221width: usize,222) -> Bitmap {223// @NOTE: Zero-Width Array are handled before224debug_assert_eq!(values.len(), length * width);225debug_assert_eq!(scalar.len(), width);226227compare!(228values,229scalar,230length,231width,232TotalEqKernel::tot_eq_missing_kernel,233|bm: Bitmap| bm.unset_bits() == 0234)235}236237fn array_fsl_tot_ne_missing_kernel(238values: &dyn Array,239scalar: &dyn Array,240length: usize,241width: usize,242) -> Bitmap {243// @NOTE: Zero-Width Array are handled before244debug_assert_eq!(values.len(), length * width);245debug_assert_eq!(scalar.len(), width);246247compare!(248values,249scalar,250length,251width,252TotalEqKernel::tot_ne_missing_kernel,253|bm: Bitmap| bm.set_bits() > 0254)255}256257258