Path: blob/main/crates/polars-ops/src/series/ops/is_in.rs
6939 views
use std::hash::Hash;12use arrow::array::BooleanArray;3use arrow::bitmap::BitmapBuilder;4use polars_core::prelude::arity::{unary_elementwise, unary_elementwise_values};5use polars_core::prelude::*;6use polars_core::{with_match_categorical_physical_type, with_match_physical_numeric_polars_type};7use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};89use self::row_encode::_get_rows_encoded_ca_unordered;1011fn is_in_helper_ca<'a, T>(12ca: &'a ChunkedArray<T>,13other: &'a ChunkedArray<T>,14nulls_equal: bool,15) -> PolarsResult<BooleanChunked>16where17T: PolarsDataType,18T::Physical<'a>: TotalHash + TotalEq + ToTotalOrd + Copy,19<T::Physical<'a> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,20{21let mut set = PlHashSet::with_capacity(other.len());22other.downcast_iter().for_each(|iter| {23iter.iter().for_each(|opt_val| {24if let Some(v) = opt_val {25set.insert(v.to_total_ord());26}27})28});2930if nulls_equal {31if other.has_nulls() {32// If the rhs has nulls, then nulls in the left set evaluates to true.33Ok(unary_elementwise(ca, |val| {34val.is_none_or(|v| set.contains(&v.to_total_ord()))35}))36} else {37// The rhs has no nulls; nulls in the left evaluates to false.38Ok(unary_elementwise(ca, |val| {39val.is_some_and(|v| set.contains(&v.to_total_ord()))40}))41}42} else {43Ok(44unary_elementwise_values(ca, |v| set.contains(&v.to_total_ord()))45.with_name(ca.name().clone()),46)47}48}4950fn is_in_helper_list_ca<'a, T>(51ca_in: &'a ChunkedArray<T>,52other: &'a ListChunked,53nulls_equal: bool,54) -> PolarsResult<BooleanChunked>55where56T: PolarsPhysicalType,57for<'b> T::Physical<'b>: TotalHash + TotalEq + ToTotalOrd + Copy,58for<'b> <T::Physical<'b> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,59{60let offsets = other.offsets()?;61let inner = other.get_inner();62let inner: &ChunkedArray<T> = inner.as_ref().as_ref();63let validity = other.rechunk_validity();6465let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 {66let value = ca_in.get(0);6768match value {69None if !nulls_equal => BooleanChunked::full_null(PlSmallStr::EMPTY, other.len()),70value => {71let mut builder = BitmapBuilder::with_capacity(other.len());7273for (start, length) in offsets.offset_and_length_iter() {74let mut is_in = false;75for i in 0..length {76is_in |= value.to_total_ord() == inner.get(start + i).to_total_ord();77}78builder.push(is_in);79}8081let values = builder.freeze();8283let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);84BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])85},86}87} else {88assert_eq!(ca_in.len(), offsets.len_proxy());89{90if nulls_equal {91let mut builder = BitmapBuilder::with_capacity(ca_in.len());9293for (value, (start, length)) in ca_in.iter().zip(offsets.offset_and_length_iter()) {94let mut is_in = false;95for i in 0..length {96is_in |= value.to_total_ord() == inner.get(start + i).to_total_ord();97}98builder.push(is_in);99}100101let values = builder.freeze();102103let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);104BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])105} else {106let mut builder = BitmapBuilder::with_capacity(ca_in.len());107108for (value, (start, length)) in ca_in.iter().zip(offsets.offset_and_length_iter()) {109let mut is_in = false;110if value.is_some() {111for i in 0..length {112is_in |= value.to_total_ord() == inner.get(start + i).to_total_ord();113}114}115builder.push(is_in);116}117118let values = builder.freeze();119120let validity = match (validity, ca_in.rechunk_validity()) {121(None, None) => None,122(Some(v), None) | (None, Some(v)) => Some(v),123(Some(l), Some(r)) => Some(arrow::bitmap::and(&l, &r)),124};125126let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);127BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])128}129}130};131ca.rename(ca_in.name().clone());132Ok(ca)133}134135#[cfg(feature = "dtype-array")]136fn is_in_helper_array_ca<'a, T>(137ca_in: &'a ChunkedArray<T>,138other: &'a ArrayChunked,139nulls_equal: bool,140) -> PolarsResult<BooleanChunked>141where142T: PolarsPhysicalType,143for<'b> T::Physical<'b>: TotalHash + TotalEq + ToTotalOrd + Copy,144for<'b> <T::Physical<'b> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,145{146let width = other.width();147let inner = other.get_inner();148let inner: &ChunkedArray<T> = inner.as_ref().as_ref();149let validity = other.rechunk_validity();150151let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 {152let value = ca_in.get(0);153154match value {155None if !nulls_equal => BooleanChunked::full_null(PlSmallStr::EMPTY, other.len()),156value => {157let mut builder = BitmapBuilder::with_capacity(other.len());158159for i in 0..other.len() {160let mut is_in = false;161for j in 0..width {162is_in |= value.to_total_ord() == inner.get(i * width + j).to_total_ord();163}164builder.push(is_in);165}166167let values = builder.freeze();168169let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);170BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])171},172}173} else {174assert_eq!(ca_in.len(), other.len());175{176if nulls_equal {177let mut builder = BitmapBuilder::with_capacity(ca_in.len());178179for (i, value) in ca_in.iter().enumerate() {180let mut is_in = false;181for j in 0..width {182is_in |= value.to_total_ord() == inner.get(i * width + j).to_total_ord();183}184builder.push(is_in);185}186187let values = builder.freeze();188189let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);190BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])191} else {192let mut builder = BitmapBuilder::with_capacity(ca_in.len());193194for (i, value) in ca_in.iter().enumerate() {195let mut is_in = false;196if value.is_some() {197for j in 0..width {198is_in |=199value.to_total_ord() == inner.get(i * width + j).to_total_ord();200}201}202builder.push(is_in);203}204205let values = builder.freeze();206207let validity = match (validity, ca_in.rechunk_validity()) {208(None, None) => None,209(Some(v), None) | (None, Some(v)) => Some(v),210(Some(l), Some(r)) => Some(arrow::bitmap::and(&l, &r)),211};212213let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);214BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])215}216}217};218ca.rename(ca_in.name().clone());219Ok(ca)220}221222fn is_in_numeric<T>(223ca_in: &ChunkedArray<T>,224other: &Series,225nulls_equal: bool,226) -> PolarsResult<BooleanChunked>227where228T: PolarsNumericType,229T::Native: TotalHash + TotalEq + ToTotalOrd,230<T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,231{232match other.dtype() {233DataType::List(..) => {234let other = other.list()?;235if other.len() == 1 {236if other.has_nulls() {237return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));238}239240let other = other.explode(true)?;241let other = other.as_ref().as_ref();242is_in_helper_ca(ca_in, other, nulls_equal)243} else {244is_in_helper_list_ca(ca_in, other, nulls_equal)245}246},247#[cfg(feature = "dtype-array")]248DataType::Array(..) => {249let other = other.array()?;250if other.len() == 1 {251if other.has_nulls() {252return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));253}254255let other = other.explode(true)?;256let other = other.as_ref().as_ref();257is_in_helper_ca(ca_in, other, nulls_equal)258} else {259is_in_helper_array_ca(ca_in, other, nulls_equal)260}261},262_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),263}264}265266fn is_in_string(267ca_in: &StringChunked,268other: &Series,269nulls_equal: bool,270) -> PolarsResult<BooleanChunked> {271let other = match other.dtype() {272DataType::List(dt) if dt.is_string() || dt.is_enum() || dt.is_categorical() => {273let other = other.list()?;274other275.apply_to_inner(&|mut s| {276if dt.is_enum() || dt.is_categorical() {277s = s.cast(&DataType::String)?;278}279let s = s.str()?;280Ok(s.as_binary().into_series())281})?282.into_series()283},284#[cfg(feature = "dtype-array")]285DataType::Array(dt, _) if dt.is_string() || dt.is_enum() || dt.is_categorical() => {286let other = other.array()?;287other288.apply_to_inner(&|mut s| {289if dt.is_enum() || dt.is_categorical() {290s = s.cast(&DataType::String)?;291}292Ok(s.str()?.as_binary().into_series())293})?294.into_series()295},296_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),297};298is_in_binary(&ca_in.as_binary(), &other, nulls_equal)299}300301fn is_in_binary(302ca_in: &BinaryChunked,303other: &Series,304nulls_equal: bool,305) -> PolarsResult<BooleanChunked> {306match other.dtype() {307DataType::List(dt) if DataType::Binary == **dt => {308let other = other.list()?;309if other.len() == 1 {310if other.has_nulls() {311return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));312}313314let other = other.explode(true)?;315let other = other.binary()?;316is_in_helper_ca(ca_in, other, nulls_equal)317} else {318is_in_helper_list_ca(ca_in, other, nulls_equal)319}320},321#[cfg(feature = "dtype-array")]322DataType::Array(dt, _) if DataType::Binary == **dt => {323let other = other.array()?;324if other.len() == 1 {325if other.has_nulls() {326return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));327}328329let other = other.explode(true)?;330let other = other.binary()?;331is_in_helper_ca(ca_in, other, nulls_equal)332} else {333is_in_helper_array_ca(ca_in, other, nulls_equal)334}335},336_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),337}338}339340fn is_in_boolean(341ca_in: &BooleanChunked,342other: &Series,343nulls_equal: bool,344) -> PolarsResult<BooleanChunked> {345fn is_in_boolean_broadcast(346ca_in: &BooleanChunked,347other: &BooleanChunked,348nulls_equal: bool,349) -> PolarsResult<BooleanChunked> {350let has_true = other.any();351let nc = other.null_count();352353let has_false = if nc == 0 {354!other.all()355} else {356(other.sum().unwrap() as usize + nc) != other.len()357};358let value_map = |v| if v { has_true } else { has_false };359if nulls_equal {360if other.has_nulls() {361// If the rhs has nulls, then nulls in the left set evaluates to true.362Ok(ca_in.apply(|opt_v| Some(opt_v.is_none_or(value_map))))363} else {364// The rhs has no nulls; nulls in the left evaluates to false.365Ok(ca_in.apply(|opt_v| Some(opt_v.is_some_and(value_map))))366}367} else {368Ok(ca_in369.apply_values(value_map)370.with_name(ca_in.name().clone()))371}372}373374match other.dtype() {375DataType::List(dt) if ca_in.dtype() == &**dt => {376let other = other.list()?;377if other.len() == 1 {378if other.has_nulls() {379return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));380}381382let other = other.explode(true)?;383let other = other.bool()?;384is_in_boolean_broadcast(ca_in, other, nulls_equal)385} else {386is_in_helper_list_ca(ca_in, other, nulls_equal)387}388},389#[cfg(feature = "dtype-array")]390DataType::Array(dt, _) if ca_in.dtype() == &**dt => {391let other = other.array()?;392if other.len() == 1 {393if other.has_nulls() {394return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));395}396397let other = other.explode(true)?;398let other = other.bool()?;399is_in_boolean_broadcast(ca_in, other, nulls_equal)400} else {401is_in_helper_array_ca(ca_in, other, nulls_equal)402}403},404_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),405}406}407408#[cfg(feature = "dtype-categorical")]409fn is_in_cat_and_enum<T: PolarsCategoricalType>(410ca_in: &CategoricalChunked<T>,411other: &Series,412nulls_equal: bool,413) -> PolarsResult<BooleanChunked>414where415T::Native: ToTotalOrd<TotalOrdItem = T::Native>,416{417let to_categories = match (ca_in.dtype(), other.dtype().inner_dtype().unwrap()) {418(DataType::Enum(_, mapping) | DataType::Categorical(_, mapping), DataType::String) => {419(&|s: Series| {420let ca = s.str()?;421let ca: ChunkedArray<T::PolarsPhysical> = ca422.iter()423.flat_map(|opt_s| {424if let Some(s) = opt_s {425Some(mapping.get_cat(s).map(T::Native::from_cat))426} else {427Some(None)428}429})430.collect_ca(PlSmallStr::EMPTY);431Ok(ca.into_series())432}) as _433},434(DataType::Categorical(lcats, _), DataType::Categorical(rcats, _)) => {435ensure_same_categories(lcats, rcats)?;436(&|s: Series| Ok(s.cat::<T>()?.physical().clone().into_series())) as _437},438(DataType::Enum(lfcats, _), DataType::Enum(rfcats, _)) => {439ensure_same_frozen_categories(lfcats, rfcats)?;440(&|s: Series| Ok(s.cat::<T>()?.physical().clone().into_series())) as _441},442_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),443};444445let other = match other.dtype() {446DataType::List(_) => other.list()?.apply_to_inner(to_categories)?.into_series(),447#[cfg(feature = "dtype-array")]448DataType::Array(_, _) => other.array()?.apply_to_inner(to_categories)?.into_series(),449_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),450};451452is_in_numeric(ca_in.physical(), &other, nulls_equal)453}454455fn is_in_null(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult<BooleanChunked> {456if nulls_equal {457let ca_in = s.null()?;458Ok(match other.dtype() {459DataType::List(_) => {460let other = other.list()?;461if other.len() == 1 {462if other.has_nulls() {463return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));464}465466let other = other.explode(true)?;467BooleanChunked::from_iter_values(468ca_in.name().clone(),469std::iter::repeat_n(other.has_nulls(), ca_in.len()),470)471} else {472other.apply_amortized_generic(|opt_s| {473Some(opt_s.map(|s| s.as_ref().has_nulls()) == Some(true))474})475}476},477#[cfg(feature = "dtype-array")]478DataType::Array(_, _) => {479let other = other.array()?;480if other.len() == 1 {481if other.has_nulls() {482return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));483}484485let other = other.explode(true)?;486BooleanChunked::from_iter_values(487ca_in.name().clone(),488std::iter::repeat_n(other.has_nulls(), ca_in.len()),489)490} else {491other.apply_amortized_generic(|opt_s| {492Some(opt_s.map(|s| s.as_ref().has_nulls()) == Some(true))493})494}495},496_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),497})498} else {499let out = s.cast(&DataType::Boolean)?;500let ca_bool = out.bool()?.clone();501Ok(ca_bool)502}503}504505#[cfg(feature = "dtype-decimal")]506fn is_in_decimal(507ca_in: &DecimalChunked,508other: &Series,509nulls_equal: bool,510) -> PolarsResult<BooleanChunked> {511let Some(DataType::Decimal(_, other_scale)) = other.dtype().inner_dtype() else {512polars_bail!(opq = is_in, ca_in.dtype(), other.dtype());513};514let other_scale = other_scale.unwrap();515let scale = ca_in.scale().max(other_scale);516let ca_in = ca_in.to_scale(scale)?;517518match other.dtype() {519DataType::List(_) => {520let other = other.list()?;521let other = other.apply_to_inner(&|s| {522let s = s.decimal()?;523let s = s.to_scale(scale)?;524let s = s.physical();525Ok(s.to_owned().into_series())526})?;527let other = other.into_series();528is_in_numeric(ca_in.physical(), &other, nulls_equal)529},530#[cfg(feature = "dtype-array")]531DataType::Array(_, _) => {532let other = other.array()?;533let other = other.apply_to_inner(&|s| {534let s = s.decimal()?;535let s = s.to_scale(scale)?;536let s = s.physical();537Ok(s.to_owned().into_series())538})?;539let other = other.into_series();540is_in_numeric(ca_in.physical(), &other, nulls_equal)541},542_ => unreachable!(),543}544}545546fn is_in_row_encoded(547s: &Series,548other: &Series,549nulls_equal: bool,550) -> PolarsResult<BooleanChunked> {551let ca_in = _get_rows_encoded_ca_unordered(s.name().clone(), &[s.clone().into_column()])?;552let mut mask = match other.dtype() {553DataType::List(_) => {554let other = other.list()?;555let other = other.apply_to_inner(&|s| {556Ok(557_get_rows_encoded_ca_unordered(s.name().clone(), &[s.into_column()])?558.into_series(),559)560})?;561if other.len() == 1 {562if other.has_nulls() {563return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));564}565566let other = other.explode(true)?;567let other = other.binary_offset()?;568is_in_helper_ca(&ca_in, other, nulls_equal)569} else {570is_in_helper_list_ca(&ca_in, &other, nulls_equal)571}572},573#[cfg(feature = "dtype-array")]574DataType::Array(_, _) => {575let other = other.array()?;576let other = other.apply_to_inner(&|s| {577Ok(578_get_rows_encoded_ca_unordered(s.name().clone(), &[s.into_column()])?579.into_series(),580)581})?;582if other.len() == 1 {583if other.has_nulls() {584return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));585}586587let other = other.explode(true)?;588let other = other.binary_offset()?;589is_in_helper_ca(&ca_in, other, nulls_equal)590} else {591is_in_helper_array_ca(&ca_in, &other, nulls_equal)592}593},594_ => unreachable!(),595}?;596597let mut validity = other.rechunk_validity();598if !nulls_equal {599validity = match (validity, s.rechunk_validity()) {600(None, None) => None,601(Some(v), None) | (None, Some(v)) => Some(v),602(Some(l), Some(r)) => Some(arrow::bitmap::and(&l, &r)),603};604}605606assert_eq!(mask.null_count(), 0);607mask.with_validities(&[validity]);608609Ok(mask)610}611612pub fn is_in(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult<BooleanChunked> {613polars_ensure!(614s.len() == other.len() || s.len() == 1 || other.len() == 1,615length_mismatch = "is_in",616s.len(),617other.len()618);619620#[allow(unused_mut)]621let mut other_is_valid_type = matches!(other.dtype(), DataType::List(_));622#[cfg(feature = "dtype-array")]623{624other_is_valid_type |= matches!(other.dtype(), DataType::Array(..))625}626polars_ensure!(other_is_valid_type, opq = is_in, s.dtype(), other.dtype());627628match s.dtype() {629#[cfg(feature = "dtype-categorical")]630dt @ DataType::Categorical(_, _) | dt @ DataType::Enum(_, _) => {631with_match_categorical_physical_type!(dt.cat_physical().unwrap(), |$C| {632is_in_cat_and_enum(s.cat::<$C>().unwrap(), other, nulls_equal)633})634},635DataType::String => {636let ca = s.str().unwrap();637is_in_string(ca, other, nulls_equal)638},639DataType::Binary => {640let ca = s.binary().unwrap();641is_in_binary(ca, other, nulls_equal)642},643DataType::Boolean => {644let ca = s.bool().unwrap();645is_in_boolean(ca, other, nulls_equal)646},647DataType::Null => is_in_null(s, other, nulls_equal),648#[cfg(feature = "dtype-decimal")]649DataType::Decimal(_, _) => {650let ca_in = s.decimal()?;651is_in_decimal(ca_in, other, nulls_equal)652},653dt if dt.is_nested() => is_in_row_encoded(s, other, nulls_equal),654dt if dt.to_physical().is_primitive_numeric() => {655let s = s.to_physical_repr();656let other = other.to_physical_repr();657let other = other.as_ref();658with_match_physical_numeric_polars_type!(s.dtype(), |$T| {659let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();660is_in_numeric(ca, other, nulls_equal)661})662},663dt => polars_bail!(opq = is_in, dt),664}665}666667668