Path: blob/main/crates/polars-compute/src/comparisons/list.rs
8431 views
use arrow::array::{1Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,2ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray,3};4use arrow::bitmap::Bitmap;5use arrow::legacy::utils::CustomIterTools;6use arrow::types::{Offset, days_ms, i256, months_days_ns};7use polars_utils::float16::pf16;89use super::TotalEqKernel;1011macro_rules! compare {12(13$lhs:expr, $rhs:expr,14$op:path, $true_op:expr,15$ineq_len_rv:literal, $invalid_rv:literal16) => {{17let lhs = $lhs;18let rhs = $rhs;1920assert_eq!(lhs.len(), rhs.len());21assert_eq!(lhs.dtype(), rhs.dtype());2223macro_rules! call_binary {24($T:ty) => {{25let lhs_values: &$T = $lhs.values().as_any().downcast_ref().unwrap();26let rhs_values: &$T = $rhs.values().as_any().downcast_ref().unwrap();2728(0..$lhs.len())29.map(|i| {30let lval = $lhs.validity().is_none_or(|v| v.get(i).unwrap());31let rval = $rhs.validity().is_none_or(|v| v.get(i).unwrap());3233if !lval || !rval {34return $invalid_rv;35}3637// SAFETY: ListArray's invariant offsets.len_proxy() == len38let (lstart, lend) = unsafe { $lhs.offsets().start_end_unchecked(i) };39let (rstart, rend) = unsafe { $rhs.offsets().start_end_unchecked(i) };4041if lend - lstart != rend - rstart {42return $ineq_len_rv;43}4445let mut lhs_values = lhs_values.clone();46lhs_values.slice(lstart, lend - lstart);47let mut rhs_values = rhs_values.clone();48rhs_values.slice(rstart, rend - rstart);4950$true_op($op(&lhs_values, &rhs_values))51})52.collect_trusted()53}};54}5556use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};57match lhs.values().dtype().to_physical_type() {58PH::Boolean => call_binary!(BooleanArray),59PH::BinaryView => call_binary!(BinaryViewArray),60PH::Utf8View => call_binary!(Utf8ViewArray),61PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),62PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),63PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),64PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),65PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),66PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),67PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),68PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),69PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),70PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),71PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<pf16>),72PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),73PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),74PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),75PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),76PH::Primitive(PR::MonthDayNano) => {77call_binary!(PrimitiveArray<months_days_ns>)78},79PH::Primitive(PR::MonthDayMillis) => unimplemented!(),8081#[cfg(feature = "dtype-array")]82PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),83#[cfg(not(feature = "dtype-array"))]84PH::FixedSizeList => todo!(85"Comparison of FixedSizeListArray is not supported without dtype-array feature"86),8788PH::Null => call_binary!(NullArray),89PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),90PH::Binary => call_binary!(BinaryArray<i32>),91PH::LargeBinary => call_binary!(BinaryArray<i64>),92PH::Utf8 => call_binary!(Utf8Array<i32>),93PH::LargeUtf8 => call_binary!(Utf8Array<i64>),94PH::List => call_binary!(ListArray<i32>),95PH::LargeList => call_binary!(ListArray<i64>),96PH::Struct => call_binary!(StructArray),97PH::Union => todo!("Comparison of UnionArrays is not yet supported"),98PH::Map => todo!("Comparison of MapArrays is not yet supported"),99PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),100PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),101PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),102PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),103PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),104PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),105PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),106PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),107PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),108PH::Dictionary(I::UInt128) => call_binary!(DictionaryArray<u128>),109}110}};111}112113macro_rules! compare_broadcast {114(115$lhs:expr, $rhs:expr,116$offsets:expr, $validity:expr,117$op:path, $true_op:expr,118$ineq_len_rv:literal, $invalid_rv:literal119) => {{120let lhs = $lhs;121let rhs = $rhs;122123macro_rules! call_binary {124($T:ty) => {{125let values: &$T = $lhs.as_any().downcast_ref().unwrap();126let scalar: &$T = $rhs.as_any().downcast_ref().unwrap();127128let length = $offsets.len_proxy();129130(0..length)131.map(move |i| {132let v = $validity.is_none_or(|v| v.get(i).unwrap());133134if !v {135return $invalid_rv;136}137138let (start, end) = unsafe { $offsets.start_end_unchecked(i) };139140if end - start != scalar.len() {141return $ineq_len_rv;142}143144// @TODO: I feel like there is a better way to do this.145let mut values: $T = values.clone();146<$T>::slice(&mut values, start, end - start);147148$true_op($op(&values, scalar))149})150.collect_trusted()151}};152}153154assert_eq!(lhs.dtype(), rhs.dtype());155156use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};157match lhs.dtype().to_physical_type() {158PH::Boolean => call_binary!(BooleanArray),159PH::BinaryView => call_binary!(BinaryViewArray),160PH::Utf8View => call_binary!(Utf8ViewArray),161PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),162PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),163PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),164PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),165PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),166PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),167PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),168PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),169PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),170PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),171PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<pf16>),172PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),173PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),174PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),175PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),176PH::Primitive(PR::MonthDayNano) => {177call_binary!(PrimitiveArray<months_days_ns>)178},179PH::Primitive(PR::MonthDayMillis) => unimplemented!(),180181#[cfg(feature = "dtype-array")]182PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),183#[cfg(not(feature = "dtype-array"))]184PH::FixedSizeList => todo!(185"Comparison of FixedSizeListArray is not supported without dtype-array feature"186),187188PH::Null => call_binary!(NullArray),189PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),190PH::Binary => call_binary!(BinaryArray<i32>),191PH::LargeBinary => call_binary!(BinaryArray<i64>),192PH::Utf8 => call_binary!(Utf8Array<i32>),193PH::LargeUtf8 => call_binary!(Utf8Array<i64>),194PH::List => call_binary!(ListArray<i32>),195PH::LargeList => call_binary!(ListArray<i64>),196PH::Struct => call_binary!(StructArray),197PH::Union => todo!("Comparison of UnionArrays is not yet supported"),198PH::Map => todo!("Comparison of MapArrays is not yet supported"),199PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),200PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),201PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),202PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),203PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),204PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),205PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),206PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),207PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),208PH::Dictionary(I::UInt128) => call_binary!(DictionaryArray<u128>),209}210}};211}212213impl<O: Offset> TotalEqKernel for ListArray<O> {214type Scalar = Box<dyn Array>;215216fn tot_eq_kernel(&self, other: &Self) -> Bitmap {217compare!(218self,219other,220TotalEqKernel::tot_eq_missing_kernel,221|bm: Bitmap| bm.unset_bits() == 0,222false,223true224)225}226227fn tot_ne_kernel(&self, other: &Self) -> Bitmap {228compare!(229self,230other,231TotalEqKernel::tot_ne_missing_kernel,232|bm: Bitmap| bm.set_bits() > 0,233true,234false235)236}237238fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {239compare_broadcast!(240self.values().as_ref(),241other.as_ref(),242self.offsets(),243self.validity(),244TotalEqKernel::tot_eq_missing_kernel,245|bm: Bitmap| bm.unset_bits() == 0,246false,247true248)249}250251fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {252compare_broadcast!(253self.values().as_ref(),254other.as_ref(),255self.offsets(),256self.validity(),257TotalEqKernel::tot_ne_missing_kernel,258|bm: Bitmap| bm.set_bits() > 0,259true,260false261)262}263}264265266