Path: blob/main/crates/polars-ops/src/chunked_array/list/sets.rs
8411 views
use std::fmt::{Display, Formatter};1use std::hash::Hash;23use arrow::array::{4Array, BinaryViewArray, ListArray, MutableArray, MutablePlBinary, MutablePrimitiveArray,5PrimitiveArray, Utf8ViewArray,6};7use arrow::bitmap::Bitmap;8use arrow::compute::utils::combine_validities_and;9use arrow::offset::OffsetsBuffer;10use arrow::types::NativeType;11use polars_core::prelude::*;12use polars_core::with_match_physical_numeric_type;13use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap};14#[cfg(feature = "serde")]15use serde::{Deserialize, Serialize};1617trait MaterializeValues<K> {18// extends the iterator to the values and returns the current offset19fn extend_buf<I: Iterator<Item = K>>(&mut self, values: I) -> usize;20}2122impl<T> MaterializeValues<Option<T>> for MutablePrimitiveArray<T>23where24T: NativeType,25{26fn extend_buf<I: Iterator<Item = Option<T>>>(&mut self, values: I) -> usize {27self.extend(values);28self.len()29}30}3132impl<T> MaterializeValues<TotalOrdWrap<Option<T>>> for MutablePrimitiveArray<T>33where34T: NativeType,35{36fn extend_buf<I: Iterator<Item = TotalOrdWrap<Option<T>>>>(&mut self, values: I) -> usize {37self.extend(values.map(|x| x.0));38self.len()39}40}4142impl<'a> MaterializeValues<Option<&'a [u8]>> for MutablePlBinary {43fn extend_buf<I: Iterator<Item = Option<&'a [u8]>>>(&mut self, values: I) -> usize {44self.extend(values);45self.len()46}47}4849#[allow(clippy::too_many_arguments)]50fn set_operation<I, J, K, R>(51set: &mut PlIndexSet<K>,52set2: &mut PlIndexSet<K>,53a: &mut I,54b: &mut J,55out: &mut R,56set_op: SetOperation,57broadcast_rhs: bool,58) -> usize59where60K: Eq + Hash + Copy,61I: Iterator<Item = K>,62J: Iterator<Item = K>,63R: MaterializeValues<K>,64{65set.clear();6667match set_op {68SetOperation::Intersection => {69set.extend(a);70// If broadcast `set2` should already be filled.71if !broadcast_rhs {72set2.clear();73set2.extend(b);74}75out.extend_buf(set.intersection(set2).copied())76},77SetOperation::Union => {78set.extend(a);79set.extend(b);80out.extend_buf(set.drain(..))81},82SetOperation::Difference => {83set.extend(a);84for v in b {85set.swap_remove(&v);86}87out.extend_buf(set.drain(..))88},89SetOperation::SymmetricDifference => {90// If broadcast `set2` should already be filled.91if !broadcast_rhs {92set2.clear();93set2.extend(b);94}95// We could speed this up, but implementing ourselves, but we need to have a cloneable96// iterator as we need 2 passes97set.extend(a);98out.extend_buf(set.symmetric_difference(set2).copied())99},100}101}102103fn copied_wrapper_opt<T: Copy + TotalEq + TotalHash>(104v: Option<&T>,105) -> <Option<T> as ToTotalOrd>::TotalOrdItem {106v.copied().to_total_ord()107}108109#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]110#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]111#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]112pub enum SetOperation {113Intersection,114Union,115Difference,116SymmetricDifference,117}118119impl Display for SetOperation {120fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {121let s = match self {122SetOperation::Intersection => "intersection",123SetOperation::Union => "union",124SetOperation::Difference => "difference",125SetOperation::SymmetricDifference => "symmetric_difference",126};127write!(f, "{s}")128}129}130131fn primitive<T>(132a: &PrimitiveArray<T>,133b: &PrimitiveArray<T>,134offsets_a: &[i64],135offsets_b: &[i64],136set_op: SetOperation,137validity: Option<Bitmap>,138) -> PolarsResult<ListArray<i64>>139where140T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd,141<Option<T> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,142{143let broadcast_lhs = offsets_a.len() == 2;144let broadcast_rhs = offsets_b.len() == 2;145146let mut set = Default::default();147let mut set2: PlIndexSet<<Option<T> as ToTotalOrd>::TotalOrdItem> = Default::default();148149let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max(150*offsets_a.last().unwrap(),151*offsets_b.last().unwrap(),152) as usize);153let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));154offsets.push(0i64);155156let offsets_slice = if offsets_a.len() > offsets_b.len() {157offsets_a158} else {159offsets_b160};161let first_a = offsets_a[0];162let second_a = offsets_a[1];163let first_b = offsets_b[0];164let second_b = offsets_b[1];165if broadcast_rhs {166set2.extend(167b.into_iter()168.skip(first_b as usize)169.take(second_b as usize - first_b as usize)170.map(copied_wrapper_opt),171);172}173174let mut iter_a = a.into_iter().skip(first_a as usize);175let mut iter_b = b.into_iter().skip(first_b as usize);176177for i in 1..offsets_slice.len() {178// If we go OOB we take the first element as we are then broadcasting.179let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;180let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;181182let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;183let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;184185let mut iter_a_broadcast = iter_a.clone();186let mut iter_b_broadcast = iter_b.clone();187188// The branches are the same every loop.189// We rely on branch prediction here.190let mut iter_a = if broadcast_lhs {191iter_a_broadcast192.by_ref()193.take(second_a as usize - first_a as usize)194.map(copied_wrapper_opt)195} else {196iter_a197.by_ref()198.take(end_a - start_a)199.map(copied_wrapper_opt)200};201let mut iter_b = if broadcast_rhs {202iter_b_broadcast203.by_ref()204.take(second_b as usize - first_b as usize)205.map(copied_wrapper_opt)206} else {207iter_b208.by_ref()209.take(end_b - start_b)210.map(copied_wrapper_opt)211};212213let offset = set_operation(214&mut set,215&mut set2,216&mut iter_a,217&mut iter_b,218&mut values_out,219set_op,220broadcast_rhs,221);222223assert!(iter_a.next().is_none());224if !broadcast_rhs || matches!(set_op, SetOperation::Union | SetOperation::Difference) {225assert!(iter_b.next().is_none());226};227228offsets.push(offset as i64);229}230let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };231let dtype = ListArray::<i64>::default_datatype(values_out.dtype().clone());232233let values: PrimitiveArray<T> = values_out.into();234Ok(ListArray::new(dtype, offsets, values.boxed(), validity))235}236237fn binary(238a: &BinaryViewArray,239b: &BinaryViewArray,240offsets_a: &[i64],241offsets_b: &[i64],242set_op: SetOperation,243validity: Option<Bitmap>,244as_utf8: bool,245) -> PolarsResult<ListArray<i64>> {246let broadcast_lhs = offsets_a.len() == 2;247let broadcast_rhs = offsets_b.len() == 2;248let mut set: PlIndexSet<Option<&[u8]>> = Default::default();249let mut set2: PlIndexSet<Option<&[u8]>> = Default::default();250251let mut values_out = MutablePlBinary::with_capacity(std::cmp::max(252*offsets_a.last().unwrap(),253*offsets_b.last().unwrap(),254) as usize);255let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));256offsets.push(0i64);257258let offsets_slice = if offsets_a.len() > offsets_b.len() {259offsets_a260} else {261offsets_b262};263let first_a = offsets_a[0];264let second_a = offsets_a[1];265let first_b = offsets_b[0];266let second_b = offsets_b[1];267268if broadcast_rhs {269// set2.extend(b_iter)270set2.extend(271b.into_iter()272.skip(first_b as usize)273.take(second_b as usize - first_b as usize),274);275}276277let mut iter_a = a.into_iter().skip(first_a as usize);278let mut iter_b = b.into_iter().skip(first_b as usize);279280for i in 1..offsets_slice.len() {281// If we go OOB we take the first element as we are then broadcasting.282let start_a = *offsets_a.get(i - 1).unwrap_or(&first_a) as usize;283let end_a = *offsets_a.get(i).unwrap_or(&second_a) as usize;284285let start_b = *offsets_b.get(i - 1).unwrap_or(&first_b) as usize;286let end_b = *offsets_b.get(i).unwrap_or(&second_b) as usize;287288let mut iter_a_broadcast = iter_a.clone();289let mut iter_b_broadcast = iter_b.clone();290291// The branches are the same every loop.292// We rely on branch prediction here.293let mut iter_a = if broadcast_lhs {294iter_a_broadcast295.by_ref()296.take(second_a as usize - first_a as usize)297} else {298iter_a.by_ref().take(end_a - start_a)299};300let mut iter_b = if broadcast_rhs {301iter_b_broadcast302.by_ref()303.take(second_b as usize - first_b as usize)304} else {305iter_b.by_ref().take(end_b - start_b)306};307308let offset = set_operation(309&mut set,310&mut set2,311&mut iter_a,312&mut iter_b,313&mut values_out,314set_op,315broadcast_rhs,316);317318assert!(iter_a.next().is_none());319if !broadcast_rhs || matches!(set_op, SetOperation::Union | SetOperation::Difference) {320assert!(iter_b.next().is_none());321};322323offsets.push(offset as i64);324}325let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };326let values = values_out.freeze();327328if as_utf8 {329let values = unsafe { values.to_utf8view_unchecked() };330let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());331Ok(ListArray::new(dtype, offsets, values.boxed(), validity))332} else {333let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());334Ok(ListArray::new(dtype, offsets, values.boxed(), validity))335}336}337338fn array_set_operation(339a: &ListArray<i64>,340b: &ListArray<i64>,341set_op: SetOperation,342) -> PolarsResult<ListArray<i64>> {343let offsets_a = a.offsets().as_slice();344let offsets_b = b.offsets().as_slice();345346let values_a = a.values();347let values_b = b.values();348assert_eq!(values_a.dtype(), values_b.dtype());349350let dtype = values_b.dtype();351let validity = combine_validities_and(a.validity(), b.validity());352353match dtype {354ArrowDataType::Utf8View => {355let a = values_a356.as_any()357.downcast_ref::<Utf8ViewArray>()358.unwrap()359.to_binview();360let b = values_b361.as_any()362.downcast_ref::<Utf8ViewArray>()363.unwrap()364.to_binview();365366binary(&a, &b, offsets_a, offsets_b, set_op, validity, true)367},368ArrowDataType::BinaryView => {369let a = values_a.as_any().downcast_ref::<BinaryViewArray>().unwrap();370let b = values_b.as_any().downcast_ref::<BinaryViewArray>().unwrap();371binary(a, b, offsets_a, offsets_b, set_op, validity, false)372},373ArrowDataType::Boolean => {374polars_bail!(InvalidOperation: "boolean type not yet supported in list 'set' operations")375},376_ => {377with_match_physical_numeric_type!(DataType::from_arrow_dtype(dtype), |$T| {378let a = values_a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();379let b = values_b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();380381primitive(&a, &b, offsets_a, offsets_b, set_op, validity)382})383},384}385}386387pub fn list_set_operation(388a: &ListChunked,389b: &ListChunked,390set_op: SetOperation,391) -> PolarsResult<ListChunked> {392polars_ensure!(a.len() == b.len() || b.len() == 1 || a.len() == 1, ShapeMismatch: "column lengths don't match");393polars_ensure!(a.dtype() == b.dtype(), InvalidOperation: "cannot do 'set' operation on dtypes: {} and {}", a.dtype(), b.dtype());394let mut a = a.clone();395let mut b = b.clone();396if a.len() != b.len() {397a.rechunk_mut();398b.rechunk_mut();399}400401// We will OOB in the kernel otherwise.402a.prune_empty_chunks();403b.prune_empty_chunks();404405// we use the unsafe variant because we want to keep the nested logical types type.406unsafe {407arity::try_binary_unchecked_same_type(408&a,409&b,410|a, b| array_set_operation(a, b, set_op).map(|arr| arr.boxed()),411false,412false,413)414}415}416417418