Path: blob/main/crates/polars-compute/src/comparisons/array.rs
6939 views
use arrow::array::{1Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,2FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array,3Utf8ViewArray,4};5use arrow::bitmap::Bitmap;6use arrow::bitmap::utils::count_zeros;7use arrow::datatypes::ArrowDataType;8use arrow::legacy::utils::CustomIterTools;9use arrow::types::{days_ms, f16, i256, months_days_ns};1011use super::TotalEqKernel;12use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel};1314/// Condenses a bitmap of n * width elements into one with n elements.15///16/// For each block of width bits a zero count is done. The block of bits is then17/// replaced with a single bit: the result of true_zero_count(zero_count).18fn agg_array_bitmap<F>(bm: Bitmap, width: usize, true_zero_count: F) -> Bitmap19where20F: Fn(usize) -> bool,21{22if bm.len() == 1 {23bm24} else {25assert!(width > 0 && bm.len().is_multiple_of(width));2627let (slice, offset, _len) = bm.as_slice();28(0..bm.len() / width)29.map(|i| true_zero_count(count_zeros(slice, offset + i * width, width)))30.collect()31}32}3334impl TotalEqKernel for FixedSizeListArray {35type Scalar = Box<dyn Array>;3637fn tot_eq_kernel(&self, other: &Self) -> Bitmap {38// Nested comparison always done with eq_missing, propagating doesn't39// make any sense.4041assert_eq!(self.len(), other.len());42let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_logical_type()43else {44panic!("array comparison called with non-array type");45};46let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_logical_type()47else {48panic!("array comparison called with non-array type");49};50assert_eq!(self_type.dtype(), other_type.dtype());5152if self_width != other_width {53return Bitmap::new_with_value(false, self.len());54}5556if *self_width == 0 {57return Bitmap::new_with_value(true, self.len());58}5960// @TODO: It is probably worth it to dispatch to a special kernel for when there are61// several nested arrays because that can be rather slow with this code.62let inner = array_tot_eq_missing_kernel(self.values().as_ref(), other.values().as_ref());6364agg_array_bitmap(inner, self.size(), |zeroes| zeroes == 0)65}6667fn tot_ne_kernel(&self, other: &Self) -> Bitmap {68assert_eq!(self.len(), other.len());69let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_logical_type()70else {71panic!("array comparison called with non-array type");72};73let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_logical_type()74else {75panic!("array comparison called with non-array type");76};77assert_eq!(self_type.dtype(), other_type.dtype());7879if self_width != other_width {80return Bitmap::new_with_value(true, self.len());81}8283if *self_width == 0 {84return Bitmap::new_with_value(false, self.len());85}8687// @TODO: It is probably worth it to dispatch to a special kernel for when there are88// several nested arrays because that can be rather slow with this code.89let inner = array_tot_ne_missing_kernel(self.values().as_ref(), other.values().as_ref());9091agg_array_bitmap(inner, self.size(), |zeroes| zeroes < self.size())92}9394fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {95let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_logical_type() else {96panic!("array comparison called with non-array type");97};98assert_eq!(self_type.dtype(), other.dtype().to_logical_type());99100let width = *width;101102if width != other.len() {103return Bitmap::new_with_value(false, self.len());104}105106if width == 0 {107return Bitmap::new_with_value(true, self.len());108}109110// @TODO: It is probably worth it to dispatch to a special kernel for when there are111// several nested arrays because that can be rather slow with this code.112array_fsl_tot_eq_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width)113}114115fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {116let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_logical_type() else {117panic!("array comparison called with non-array type");118};119assert_eq!(self_type.dtype(), other.dtype().to_logical_type());120121let width = *width;122123if width != other.len() {124return Bitmap::new_with_value(true, self.len());125}126127if width == 0 {128return Bitmap::new_with_value(false, self.len());129}130131// @TODO: It is probably worth it to dispatch to a special kernel for when there are132// several nested arrays because that can be rather slow with this code.133array_fsl_tot_ne_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width)134}135}136137macro_rules! compare {138($lhs:expr, $rhs:expr, $length:expr, $width:expr, $op:path, $true_op:expr) => {{139let lhs = $lhs;140let rhs = $rhs;141142macro_rules! call_binary {143($T:ty) => {{144let values: &$T = $lhs.as_any().downcast_ref().unwrap();145let scalar: &$T = $rhs.as_any().downcast_ref().unwrap();146147(0..$length)148.map(move |i| {149// @TODO: I feel like there is a better way to do this.150let mut values: $T = values.clone();151<$T>::slice(&mut values, i * $width, $width);152153$true_op($op(&values, scalar))154})155.collect_trusted()156}};157}158159assert_eq!(lhs.dtype(), rhs.dtype());160161use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};162match lhs.dtype().to_physical_type() {163PH::Boolean => call_binary!(BooleanArray),164PH::BinaryView => call_binary!(BinaryViewArray),165PH::Utf8View => call_binary!(Utf8ViewArray),166PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),167PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),168PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),169PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),170PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),171PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),172PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),173PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),174PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),175PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),176PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<f16>),177PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),178PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),179PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),180PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),181PH::Primitive(PR::MonthDayNano) => {182call_binary!(PrimitiveArray<months_days_ns>)183},184185#[cfg(feature = "dtype-array")]186PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),187#[cfg(not(feature = "dtype-array"))]188PH::FixedSizeList => todo!(189"Comparison of FixedSizeListArray is not supported without dtype-array feature"190),191192PH::Null => call_binary!(NullArray),193PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),194PH::Binary => call_binary!(BinaryArray<i32>),195PH::LargeBinary => call_binary!(BinaryArray<i64>),196PH::Utf8 => call_binary!(Utf8Array<i32>),197PH::LargeUtf8 => call_binary!(Utf8Array<i64>),198PH::List => call_binary!(ListArray<i32>),199PH::LargeList => call_binary!(ListArray<i64>),200PH::Struct => call_binary!(StructArray),201PH::Union => todo!("Comparison of UnionArrays is not yet supported"),202PH::Map => todo!("Comparison of MapArrays is not yet supported"),203PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),204PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),205PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),206PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),207PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),208PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),209PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),210PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),211PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),212}213}};214}215216fn array_fsl_tot_eq_missing_kernel(217values: &dyn Array,218scalar: &dyn Array,219length: usize,220width: usize,221) -> Bitmap {222// @NOTE: Zero-Width Array are handled before223debug_assert_eq!(values.len(), length * width);224debug_assert_eq!(scalar.len(), width);225226compare!(227values,228scalar,229length,230width,231TotalEqKernel::tot_eq_missing_kernel,232|bm: Bitmap| bm.unset_bits() == 0233)234}235236fn array_fsl_tot_ne_missing_kernel(237values: &dyn Array,238scalar: &dyn Array,239length: usize,240width: usize,241) -> Bitmap {242// @NOTE: Zero-Width Array are handled before243debug_assert_eq!(values.len(), length * width);244debug_assert_eq!(scalar.len(), width);245246compare!(247values,248scalar,249length,250width,251TotalEqKernel::tot_ne_missing_kernel,252|bm: Bitmap| bm.set_bits() > 0253)254}255256257