Path: blob/main/crates/polars-ops/src/chunked_array/list/sets.rs
use std::fmt::{Display, Formatter};
use std::hash::Hash;

use arrow::array::{
    Array, BinaryViewArray, ListArray, MutableArray, MutablePlBinary, MutablePrimitiveArray,
    PrimitiveArray, Utf8ViewArray,
};
use arrow::bitmap::Bitmap;
use arrow::compute::utils::combine_validities_and;
use arrow::offset::OffsetsBuffer;
use arrow::types::NativeType;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_type;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

trait MaterializeValues<K> {
    // Extends `self` with the values of the iterator and returns the current offset (new length).
    fn extend_buf<I: Iterator<Item = K>>(&mut self, values: I) -> usize;
}

impl<T> MaterializeValues<Option<T>> for MutablePrimitiveArray<T>
where
    T: NativeType,
{
    fn extend_buf<I: Iterator<Item = Option<T>>>(&mut self, values: I) -> usize {
        self.extend(values);
        self.len()
    }
}

impl<T> MaterializeValues<TotalOrdWrap<Option<T>>> for MutablePrimitiveArray<T>
where
    T: NativeType,
{
    fn extend_buf<I: Iterator<Item = TotalOrdWrap<Option<T>>>>(&mut self, values: I) -> usize {
        self.extend(values.map(|x| x.0));
        self.len()
    }
}

impl<'a> MaterializeValues<Option<&'a [u8]>> for MutablePlBinary {
    fn extend_buf<I: Iterator<Item = Option<&'a [u8]>>>(&mut self, values: I) -> usize {
        self.extend(values);
        self.len()
    }
}

fn set_operation<K, I, J, R>(
    set: &mut PlIndexSet<K>,
    set2: &mut PlIndexSet<K>,
    a: I,
    b: J,
    out: &mut R,
    set_op: SetOperation,
    broadcast_rhs: bool,
) -> usize
where
    K: Eq + Hash + Copy,
    I: IntoIterator<Item = K>,
    J: IntoIterator<Item = K>,
    R: MaterializeValues<K>,
{
    set.clear();
    let a = a.into_iter();
    let b = b.into_iter();

    match set_op {
        SetOperation::Intersection => {
            set.extend(a);
            // If broadcasting, `set2` should already be filled.
            if !broadcast_rhs {
                set2.clear();
                set2.extend(b);
            }
            out.extend_buf(set.intersection(set2).copied())
        },
        SetOperation::Union => {
            set.extend(a);
            set.extend(b);
            out.extend_buf(set.drain(..))
        },
        SetOperation::Difference => {
            set.extend(a);
            for v in b {
                set.swap_remove(&v);
            }
            out.extend_buf(set.drain(..))
        },
        SetOperation::SymmetricDifference => {
            // If broadcasting, `set2` should already be filled.
            if !broadcast_rhs {
                set2.clear();
                set2.extend(b);
            }
            // We could speed this up by implementing it ourselves, but we would need a cloneable
            // iterator as we need 2 passes.
            set.extend(a);
            out.extend_buf(set.symmetric_difference(set2).copied())
        },
    }
}

fn copied_wrapper_opt<T: Copy + TotalEq + TotalHash>(
    v: Option<&T>,
) -> <Option<T> as ToTotalOrd>::TotalOrdItem {
    v.copied().to_total_ord()
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum SetOperation {
    Intersection,
    Union,
    Difference,
    SymmetricDifference,
}

impl Display for SetOperation {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let s = match self {
            SetOperation::Intersection => "intersection",
            SetOperation::Union => "union",
            SetOperation::Difference => "difference",
            SetOperation::SymmetricDifference => "symmetric_difference",
        };
        write!(f, "{s}")
    }
}

fn primitive<T>(
    a: &PrimitiveArray<T>,
    b: &PrimitiveArray<T>,
    offsets_a: &[i64],
    offsets_b: &[i64],
    set_op: SetOperation,
    validity: Option<Bitmap>,
) -> PolarsResult<ListArray<i64>>
where
    T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd,
    <Option<T> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,
{
    let broadcast_lhs = offsets_a.len() == 2;
    let broadcast_rhs = offsets_b.len() == 2;

    let mut set = Default::default();
    let mut set2: PlIndexSet<<Option<T> as ToTotalOrd>::TotalOrdItem> = Default::default();

    let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max(
        *offsets_a.last().unwrap(),
        *offsets_b.last().unwrap(),
    ) as usize);
    let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
    offsets.push(0i64);

    let offsets_slice = if offsets_a.len() > offsets_b.len() {
        offsets_a
    } else {
        offsets_b
    };
    let first_a = offsets_a[0];
    let second_a = offsets_a[1];
    let first_b = offsets_b[0];
    let second_b = offsets_b[1];
    if broadcast_rhs {
        set2.extend(
            b.into_iter()
                .skip(first_b as usize)
                .take(second_b as usize - first_b as usize)
                .map(copied_wrapper_opt),
        );
    }
    for i in 1..offsets_slice.len() {
        // If we go OOB, we take the first element as we are then broadcasting.
        let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;
        let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;

        let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;
        let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;

        // The branches are the same every loop.
        // We rely on branch prediction here.
        let offset = if broadcast_rhs {
            // Going via a skip iterator instead of a slice doesn't heap allocate nor trigger a bitcount.
            let a_iter = a
                .into_iter()
                .skip(start_a)
                .take(end_a - start_a)
                .map(copied_wrapper_opt);
            let b_iter = b
                .into_iter()
                .skip(first_b as usize)
                .take(second_b as usize - first_b as usize)
                .map(copied_wrapper_opt);
            set_operation(
                &mut set,
                &mut set2,
                a_iter,
                b_iter,
                &mut values_out,
                set_op,
                true,
            )
        } else if broadcast_lhs {
            let a_iter = a
                .into_iter()
                .skip(first_a as usize)
                .take(second_a as usize - first_a as usize)
                .map(copied_wrapper_opt);

            let b_iter = b
                .into_iter()
                .skip(start_b)
                .take(end_b - start_b)
                .map(copied_wrapper_opt);

            set_operation(
                &mut set,
                &mut set2,
                a_iter,
                b_iter,
                &mut values_out,
                set_op,
                false,
            )
        } else {
            // Going via a skip iterator instead of a slice doesn't heap allocate nor trigger a bitcount.
            let a_iter = a
                .into_iter()
                .skip(start_a)
                .take(end_a - start_a)
                .map(copied_wrapper_opt);

            let b_iter = b
                .into_iter()
                .skip(start_b)
                .take(end_b - start_b)
                .map(copied_wrapper_opt);
            set_operation(
                &mut set,
                &mut set2,
                a_iter,
                b_iter,
                &mut values_out,
                set_op,
                false,
            )
        };

        offsets.push(offset as i64);
    }
    let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
    let dtype = ListArray::<i64>::default_datatype(values_out.dtype().clone());

    let values: PrimitiveArray<T> = values_out.into();
    Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
}

fn binary(
    a: &BinaryViewArray,
    b: &BinaryViewArray,
    offsets_a: &[i64],
    offsets_b: &[i64],
    set_op: SetOperation,
    validity: Option<Bitmap>,
    as_utf8: bool,
) -> PolarsResult<ListArray<i64>> {
    let broadcast_lhs = offsets_a.len() == 2;
    let broadcast_rhs = offsets_b.len() == 2;
    let mut set = Default::default();
    let mut set2: PlIndexSet<Option<&[u8]>> = Default::default();

    let mut values_out = MutablePlBinary::with_capacity(std::cmp::max(
        *offsets_a.last().unwrap(),
        *offsets_b.last().unwrap(),
    ) as usize);
    let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
    offsets.push(0i64);

    let offsets_slice = if offsets_a.len() > offsets_b.len() {
        offsets_a
    } else {
        offsets_b
    };
    let first_a = offsets_a[0];
    let second_a = offsets_a[1];
    let first_b = offsets_b[0];
    let second_b = offsets_b[1];

    if broadcast_rhs {
        // set2.extend(b_iter)
        set2.extend(
            b.into_iter()
                .skip(first_b as usize)
                .take(second_b as usize - first_b as usize),
        );
    }

    for i in 1..offsets_slice.len() {
        // If we go OOB, we take the first element as we are then broadcasting.
        let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;
        let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;

        let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;
        let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;

        // The branches are the same every loop.
        // We rely on branch prediction here.
        let offset = if broadcast_rhs {
            // Going via a skip iterator instead of a slice doesn't heap allocate nor trigger a bitcount.
            let a_iter = a.into_iter().skip(start_a).take(end_a - start_a);
            let b_iter = b
                .into_iter()
                .skip(first_b as usize)
                .take(second_b as usize - first_b as usize);
            set_operation(
                &mut set,
                &mut set2,
                a_iter,
                b_iter,
                &mut values_out,
                set_op,
                true,
            )
        } else if broadcast_lhs {
            let a_iter = a
                .into_iter()
                .skip(first_a as usize)
                .take(second_a as usize - first_a as usize);
            let b_iter = b.into_iter().skip(start_b).take(end_b - start_b);
            set_operation(
                &mut set,
                &mut set2,
                a_iter,
                b_iter,
                &mut values_out,
                set_op,
                false,
            )
        } else {
            // Going via a skip iterator instead of a slice doesn't heap allocate nor trigger a bitcount.
            let a_iter = a.into_iter().skip(start_a).take(end_a - start_a);
            let b_iter = b.into_iter().skip(start_b).take(end_b - start_b);
            set_operation(
                &mut set,
                &mut set2,
                a_iter,
                b_iter,
                &mut values_out,
                set_op,
                false,
            )
        };
        offsets.push(offset as i64);
    }
    let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
    let values = values_out.freeze();

    if as_utf8 {
        let values = unsafe { values.to_utf8view_unchecked() };
        let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
        Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
    } else {
        let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
        Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
    }
}

fn array_set_operation(
    a: &ListArray<i64>,
    b: &ListArray<i64>,
    set_op: SetOperation,
) -> PolarsResult<ListArray<i64>> {
    let offsets_a = a.offsets().as_slice();
    let offsets_b = b.offsets().as_slice();

    let values_a = a.values();
    let values_b = b.values();
    assert_eq!(values_a.dtype(), values_b.dtype());

    let dtype = values_b.dtype();
    let validity = combine_validities_and(a.validity(), b.validity());

    match dtype {
        ArrowDataType::Utf8View => {
            let a = values_a
                .as_any()
                .downcast_ref::<Utf8ViewArray>()
                .unwrap()
                .to_binview();
            let b = values_b
                .as_any()
                .downcast_ref::<Utf8ViewArray>()
                .unwrap()
                .to_binview();

            binary(&a, &b, offsets_a, offsets_b, set_op, validity, true)
        },
        ArrowDataType::BinaryView => {
            let a = values_a.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            let b = values_b.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            binary(a, b, offsets_a, offsets_b, set_op, validity, false)
        },
        ArrowDataType::Boolean => {
            polars_bail!(InvalidOperation: "boolean type not yet supported in list 'set' operations")
        },
        _ => {
            with_match_physical_numeric_type!(DataType::from_arrow_dtype(dtype), |$T| {
                let a = values_a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
                let b = values_b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();

                primitive(&a, &b, offsets_a, offsets_b, set_op, validity)
            })
        },
    }
}

pub fn list_set_operation(
    a: &ListChunked,
    b: &ListChunked,
    set_op: SetOperation,
) -> PolarsResult<ListChunked> {
    polars_ensure!(a.len() == b.len() || b.len() == 1 || a.len() == 1, ShapeMismatch: "column lengths don't match");
    polars_ensure!(a.dtype() == b.dtype(), InvalidOperation: "cannot do 'set' operation on dtypes: {} and {}", a.dtype(), b.dtype());
    let mut a = a.clone();
    let mut b = b.clone();
    if a.len() != b.len() {
        a.rechunk_mut();
        b.rechunk_mut();
    }

    // We would go OOB in the kernel otherwise.
    a.prune_empty_chunks();
    b.prune_empty_chunks();

    // We use the unsafe variant because we want to keep the nested logical type.
    unsafe {
        arity::try_binary_unchecked_same_type(
            &a,
            &b,
            |a, b| array_set_operation(a, b, set_op).map(|arr| arr.boxed()),
            false,
            false,
        )
    }
}
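
// ---------------------------------------------------------------------------
// The module below is not part of the original file; it is a minimal, hedged
// usage sketch for `list_set_operation`. It assumes that polars-core's
// `Series::new` can build a List column from a slice of `Series` (NamedFrom)
// and that `Series::list` returns a `&ListChunked`; the module name, test
// name, and values are illustrative only.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn list_intersection_sketch() -> PolarsResult<()> {
        // Two list columns of equal length; each row holds a small i64 list.
        let a = Series::new(
            "a".into(),
            &[
                Series::new("".into(), &[1i64, 2, 3]),
                Series::new("".into(), &[5i64, 6]),
            ],
        );
        let b = Series::new(
            "b".into(),
            &[
                Series::new("".into(), &[2i64, 3, 4]),
                Series::new("".into(), &[6i64, 7]),
            ],
        );

        // Row-wise intersection: row 0 keeps {2, 3}, row 1 keeps {6};
        // output order follows the left operand.
        let out = list_set_operation(a.list()?, b.list()?, SetOperation::Intersection)?;
        assert_eq!(out.len(), 2);
        Ok(())
    }
}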