Path: blob/main/crates/polars-compute/src/comparisons/list.rs
6939 views
use arrow::array::{1Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,2ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray,3};4use arrow::bitmap::Bitmap;5use arrow::legacy::utils::CustomIterTools;6use arrow::types::{Offset, days_ms, f16, i256, months_days_ns};78use super::TotalEqKernel;910macro_rules! compare {11(12$lhs:expr, $rhs:expr,13$op:path, $true_op:expr,14$ineq_len_rv:literal, $invalid_rv:literal15) => {{16let lhs = $lhs;17let rhs = $rhs;1819assert_eq!(lhs.len(), rhs.len());20assert_eq!(lhs.dtype(), rhs.dtype());2122macro_rules! call_binary {23($T:ty) => {{24let lhs_values: &$T = $lhs.values().as_any().downcast_ref().unwrap();25let rhs_values: &$T = $rhs.values().as_any().downcast_ref().unwrap();2627(0..$lhs.len())28.map(|i| {29let lval = $lhs.validity().is_none_or(|v| v.get(i).unwrap());30let rval = $rhs.validity().is_none_or(|v| v.get(i).unwrap());3132if !lval || !rval {33return $invalid_rv;34}3536// SAFETY: ListArray's invariant offsets.len_proxy() == len37let (lstart, lend) = unsafe { $lhs.offsets().start_end_unchecked(i) };38let (rstart, rend) = unsafe { $rhs.offsets().start_end_unchecked(i) };3940if lend - lstart != rend - rstart {41return $ineq_len_rv;42}4344let mut lhs_values = lhs_values.clone();45lhs_values.slice(lstart, lend - lstart);46let mut rhs_values = rhs_values.clone();47rhs_values.slice(rstart, rend - rstart);4849$true_op($op(&lhs_values, &rhs_values))50})51.collect_trusted()52}};53}5455use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};56match lhs.values().dtype().to_physical_type() {57PH::Boolean => call_binary!(BooleanArray),58PH::BinaryView => call_binary!(BinaryViewArray),59PH::Utf8View => call_binary!(Utf8ViewArray),60PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),61PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),62PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),63PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),64PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),65PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),66PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),67PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),68PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),69PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),70PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<f16>),71PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),72PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),73PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),74PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),75PH::Primitive(PR::MonthDayNano) => {76call_binary!(PrimitiveArray<months_days_ns>)77},7879#[cfg(feature = "dtype-array")]80PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),81#[cfg(not(feature = "dtype-array"))]82PH::FixedSizeList => todo!(83"Comparison of FixedSizeListArray is not supported without dtype-array feature"84),8586PH::Null => call_binary!(NullArray),87PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),88PH::Binary => call_binary!(BinaryArray<i32>),89PH::LargeBinary => call_binary!(BinaryArray<i64>),90PH::Utf8 => call_binary!(Utf8Array<i32>),91PH::LargeUtf8 => call_binary!(Utf8Array<i64>),92PH::List => call_binary!(ListArray<i32>),93PH::LargeList => call_binary!(ListArray<i64>),94PH::Struct => call_binary!(StructArray),95PH::Union => todo!("Comparison of UnionArrays is not yet supported"),96PH::Map => todo!("Comparison of MapArrays is not yet supported"),97PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),98PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),99PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),100PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),101PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),102PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),103PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),104PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),105PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),106}107}};108}109110macro_rules! compare_broadcast {111(112$lhs:expr, $rhs:expr,113$offsets:expr, $validity:expr,114$op:path, $true_op:expr,115$ineq_len_rv:literal, $invalid_rv:literal116) => {{117let lhs = $lhs;118let rhs = $rhs;119120macro_rules! call_binary {121($T:ty) => {{122let values: &$T = $lhs.as_any().downcast_ref().unwrap();123let scalar: &$T = $rhs.as_any().downcast_ref().unwrap();124125let length = $offsets.len_proxy();126127(0..length)128.map(move |i| {129let v = $validity.is_none_or(|v| v.get(i).unwrap());130131if !v {132return $invalid_rv;133}134135let (start, end) = unsafe { $offsets.start_end_unchecked(i) };136137if end - start != scalar.len() {138return $ineq_len_rv;139}140141// @TODO: I feel like there is a better way to do this.142let mut values: $T = values.clone();143<$T>::slice(&mut values, start, end - start);144145$true_op($op(&values, scalar))146})147.collect_trusted()148}};149}150151assert_eq!(lhs.dtype(), rhs.dtype());152153use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};154match lhs.dtype().to_physical_type() {155PH::Boolean => call_binary!(BooleanArray),156PH::BinaryView => call_binary!(BinaryViewArray),157PH::Utf8View => call_binary!(Utf8ViewArray),158PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),159PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),160PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),161PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),162PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),163PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),164PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),165PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),166PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),167PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),168PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<f16>),169PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),170PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),171PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),172PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),173PH::Primitive(PR::MonthDayNano) => {174call_binary!(PrimitiveArray<months_days_ns>)175},176177#[cfg(feature = "dtype-array")]178PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),179#[cfg(not(feature = "dtype-array"))]180PH::FixedSizeList => todo!(181"Comparison of FixedSizeListArray is not supported without dtype-array feature"182),183184PH::Null => call_binary!(NullArray),185PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),186PH::Binary => call_binary!(BinaryArray<i32>),187PH::LargeBinary => call_binary!(BinaryArray<i64>),188PH::Utf8 => call_binary!(Utf8Array<i32>),189PH::LargeUtf8 => call_binary!(Utf8Array<i64>),190PH::List => call_binary!(ListArray<i32>),191PH::LargeList => call_binary!(ListArray<i64>),192PH::Struct => call_binary!(StructArray),193PH::Union => todo!("Comparison of UnionArrays is not yet supported"),194PH::Map => todo!("Comparison of MapArrays is not yet supported"),195PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),196PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),197PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),198PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),199PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),200PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),201PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),202PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),203PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),204}205}};206}207208impl<O: Offset> TotalEqKernel for ListArray<O> {209type Scalar = Box<dyn Array>;210211fn tot_eq_kernel(&self, other: &Self) -> Bitmap {212compare!(213self,214other,215TotalEqKernel::tot_eq_missing_kernel,216|bm: Bitmap| bm.unset_bits() == 0,217false,218true219)220}221222fn tot_ne_kernel(&self, other: &Self) -> Bitmap {223compare!(224self,225other,226TotalEqKernel::tot_ne_missing_kernel,227|bm: Bitmap| bm.set_bits() > 0,228true,229false230)231}232233fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {234compare_broadcast!(235self.values().as_ref(),236other.as_ref(),237self.offsets(),238self.validity(),239TotalEqKernel::tot_eq_missing_kernel,240|bm: Bitmap| bm.unset_bits() == 0,241false,242true243)244}245246fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {247compare_broadcast!(248self.values().as_ref(),249other.as_ref(),250self.offsets(),251self.validity(),252TotalEqKernel::tot_ne_missing_kernel,253|bm: Bitmap| bm.set_bits() > 0,254true,255false256)257}258}259260261