// Path: blob/main/crates/polars-ops/src/series/ops/is_in.rs
use std::hash::Hash;12use arrow::array::BooleanArray;3use arrow::bitmap::BitmapBuilder;4use polars_core::prelude::arity::{unary_elementwise, unary_elementwise_values};5use polars_core::prelude::*;6use polars_core::{with_match_categorical_physical_type, with_match_physical_numeric_polars_type};7use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};89use self::row_encode::_get_rows_encoded_ca_unordered;1011fn is_in_helper_ca<'a, T>(12ca: &'a ChunkedArray<T>,13other: &'a ChunkedArray<T>,14nulls_equal: bool,15) -> PolarsResult<BooleanChunked>16where17T: PolarsDataType,18T::Physical<'a>: TotalHash + TotalEq + ToTotalOrd + Copy,19<T::Physical<'a> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,20{21let mut set = PlHashSet::with_capacity(other.len());22other.downcast_iter().for_each(|iter| {23iter.iter().for_each(|opt_val| {24if let Some(v) = opt_val {25set.insert(v.to_total_ord());26}27})28});2930if nulls_equal {31if other.has_nulls() {32// If the rhs has nulls, then nulls in the left set evaluates to true.33Ok(unary_elementwise(ca, |val| {34val.is_none_or(|v| set.contains(&v.to_total_ord()))35}))36} else {37// The rhs has no nulls; nulls in the left evaluates to false.38Ok(unary_elementwise(ca, |val| {39val.is_some_and(|v| set.contains(&v.to_total_ord()))40}))41}42} else {43Ok(44unary_elementwise_values(ca, |v| set.contains(&v.to_total_ord()))45.with_name(ca.name().clone()),46)47}48}4950fn is_in_helper_list_ca<'a, T>(51ca_in: &'a ChunkedArray<T>,52other: &'a ListChunked,53nulls_equal: bool,54) -> PolarsResult<BooleanChunked>55where56T: PolarsPhysicalType,57for<'b> T::Physical<'b>: TotalHash + TotalEq + ToTotalOrd + Copy,58for<'b> <T::Physical<'b> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,59{60let offsets = other.offsets()?;61let inner = other.get_inner();62let inner: &ChunkedArray<T> = inner.as_ref().as_ref();63let validity = other.rechunk_validity();6465let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 {66let value = 
ca_in.get(0);6768match value {69None if !nulls_equal => BooleanChunked::full_null(PlSmallStr::EMPTY, other.len()),70value => {71let mut builder = BitmapBuilder::with_capacity(other.len());7273for (start, length) in offsets.offset_and_length_iter() {74let mut is_in = false;75for i in 0..length {76is_in |= value.to_total_ord() == inner.get(start + i).to_total_ord();77}78builder.push(is_in);79}8081let values = builder.freeze();8283let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);84BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])85},86}87} else {88assert_eq!(ca_in.len(), offsets.len_proxy());89{90if nulls_equal {91let mut builder = BitmapBuilder::with_capacity(ca_in.len());9293for (value, (start, length)) in ca_in.iter().zip(offsets.offset_and_length_iter()) {94let mut is_in = false;95for i in 0..length {96is_in |= value.to_total_ord() == inner.get(start + i).to_total_ord();97}98builder.push(is_in);99}100101let values = builder.freeze();102103let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);104BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])105} else {106let mut builder = BitmapBuilder::with_capacity(ca_in.len());107108for (value, (start, length)) in ca_in.iter().zip(offsets.offset_and_length_iter()) {109let mut is_in = false;110if value.is_some() {111for i in 0..length {112is_in |= value.to_total_ord() == inner.get(start + i).to_total_ord();113}114}115builder.push(is_in);116}117118let values = builder.freeze();119120let validity = match (validity, ca_in.rechunk_validity()) {121(None, None) => None,122(Some(v), None) | (None, Some(v)) => Some(v),123(Some(l), Some(r)) => Some(arrow::bitmap::and(&l, &r)),124};125126let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);127BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])128}129}130};131ca.rename(ca_in.name().clone());132Ok(ca)133}134135#[cfg(feature = "dtype-array")]136fn is_in_helper_array_ca<'a, T>(137ca_in: &'a 
ChunkedArray<T>,138other: &'a ArrayChunked,139nulls_equal: bool,140) -> PolarsResult<BooleanChunked>141where142T: PolarsPhysicalType,143for<'b> T::Physical<'b>: TotalHash + TotalEq + ToTotalOrd + Copy,144for<'b> <T::Physical<'b> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,145{146let width = other.width();147let inner = other.get_inner();148let inner: &ChunkedArray<T> = inner.as_ref().as_ref();149let validity = other.rechunk_validity();150151let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 {152let value = ca_in.get(0);153154match value {155None if !nulls_equal => BooleanChunked::full_null(PlSmallStr::EMPTY, other.len()),156value => {157let mut builder = BitmapBuilder::with_capacity(other.len());158159for i in 0..other.len() {160let mut is_in = false;161for j in 0..width {162is_in |= value.to_total_ord() == inner.get(i * width + j).to_total_ord();163}164builder.push(is_in);165}166167let values = builder.freeze();168169let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);170BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])171},172}173} else {174assert_eq!(ca_in.len(), other.len());175{176if nulls_equal {177let mut builder = BitmapBuilder::with_capacity(ca_in.len());178179for (i, value) in ca_in.iter().enumerate() {180let mut is_in = false;181for j in 0..width {182is_in |= value.to_total_ord() == inner.get(i * width + j).to_total_ord();183}184builder.push(is_in);185}186187let values = builder.freeze();188189let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);190BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])191} else {192let mut builder = BitmapBuilder::with_capacity(ca_in.len());193194for (i, value) in ca_in.iter().enumerate() {195let mut is_in = false;196if value.is_some() {197for j in 0..width {198is_in |=199value.to_total_ord() == inner.get(i * width + j).to_total_ord();200}201}202builder.push(is_in);203}204205let values = builder.freeze();206207let validity = match (validity, 
ca_in.rechunk_validity()) {208(None, None) => None,209(Some(v), None) | (None, Some(v)) => Some(v),210(Some(l), Some(r)) => Some(arrow::bitmap::and(&l, &r)),211};212213let result = BooleanArray::new(ArrowDataType::Boolean, values, validity);214BooleanChunked::from_chunk_iter(PlSmallStr::EMPTY, [result])215}216}217};218ca.rename(ca_in.name().clone());219Ok(ca)220}221222fn is_in_numeric<T>(223ca_in: &ChunkedArray<T>,224other: &Series,225nulls_equal: bool,226) -> PolarsResult<BooleanChunked>227where228T: PolarsNumericType,229T::Native: TotalHash + TotalEq + ToTotalOrd,230<T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,231{232match other.dtype() {233DataType::List(..) => {234let other = other.list()?;235if other.len() == 1 {236if other.has_nulls() {237return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));238}239240let other = other.explode(ExplodeOptions {241empty_as_null: false,242keep_nulls: true,243})?;244let other = other.as_ref().as_ref();245is_in_helper_ca(ca_in, other, nulls_equal)246} else {247is_in_helper_list_ca(ca_in, other, nulls_equal)248}249},250#[cfg(feature = "dtype-array")]251DataType::Array(..) 
=> {252let other = other.array()?;253if other.len() == 1 {254if other.has_nulls() {255return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));256}257258let other = other.explode(ExplodeOptions {259empty_as_null: false,260keep_nulls: true,261})?;262let other = other.as_ref().as_ref();263is_in_helper_ca(ca_in, other, nulls_equal)264} else {265is_in_helper_array_ca(ca_in, other, nulls_equal)266}267},268_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),269}270}271272fn is_in_string(273ca_in: &StringChunked,274other: &Series,275nulls_equal: bool,276) -> PolarsResult<BooleanChunked> {277let other = match other.dtype() {278DataType::List(dt) if dt.is_string() || dt.is_enum() || dt.is_categorical() => {279let other = other.list()?;280other281.apply_to_inner(&|mut s| {282if dt.is_enum() || dt.is_categorical() {283s = s.cast(&DataType::String)?;284}285let s = s.str()?;286Ok(s.as_binary().into_series())287})?288.into_series()289},290#[cfg(feature = "dtype-array")]291DataType::Array(dt, _) if dt.is_string() || dt.is_enum() || dt.is_categorical() => {292let other = other.array()?;293other294.apply_to_inner(&|mut s| {295if dt.is_enum() || dt.is_categorical() {296s = s.cast(&DataType::String)?;297}298Ok(s.str()?.as_binary().into_series())299})?300.into_series()301},302_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),303};304is_in_binary(&ca_in.as_binary(), &other, nulls_equal)305}306307fn is_in_binary(308ca_in: &BinaryChunked,309other: &Series,310nulls_equal: bool,311) -> PolarsResult<BooleanChunked> {312match other.dtype() {313DataType::List(dt) if DataType::Binary == **dt => {314let other = other.list()?;315if other.len() == 1 {316if other.has_nulls() {317return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));318}319320let other = other.explode(ExplodeOptions {321empty_as_null: false,322keep_nulls: true,323})?;324let other = other.binary()?;325is_in_helper_ca(ca_in, other, nulls_equal)326} else 
{327is_in_helper_list_ca(ca_in, other, nulls_equal)328}329},330#[cfg(feature = "dtype-array")]331DataType::Array(dt, _) if DataType::Binary == **dt => {332let other = other.array()?;333if other.len() == 1 {334if other.has_nulls() {335return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));336}337338let other = other.explode(ExplodeOptions {339empty_as_null: false,340keep_nulls: true,341})?;342let other = other.binary()?;343is_in_helper_ca(ca_in, other, nulls_equal)344} else {345is_in_helper_array_ca(ca_in, other, nulls_equal)346}347},348_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),349}350}351352fn is_in_boolean(353ca_in: &BooleanChunked,354other: &Series,355nulls_equal: bool,356) -> PolarsResult<BooleanChunked> {357fn is_in_boolean_broadcast(358ca_in: &BooleanChunked,359other: &BooleanChunked,360nulls_equal: bool,361) -> PolarsResult<BooleanChunked> {362let has_true = other.any();363let nc = other.null_count();364365let has_false = if nc == 0 {366!other.all()367} else {368(other.sum().unwrap() as usize + nc) != other.len()369};370let value_map = |v| if v { has_true } else { has_false };371if nulls_equal {372if other.has_nulls() {373// If the rhs has nulls, then nulls in the left set evaluates to true.374Ok(ca_in.apply(|opt_v| Some(opt_v.is_none_or(value_map))))375} else {376// The rhs has no nulls; nulls in the left evaluates to false.377Ok(ca_in.apply(|opt_v| Some(opt_v.is_some_and(value_map))))378}379} else {380Ok(ca_in381.apply_values(value_map)382.with_name(ca_in.name().clone()))383}384}385386match other.dtype() {387DataType::List(dt) if ca_in.dtype() == &**dt => {388let other = other.list()?;389if other.len() == 1 {390if other.has_nulls() {391return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));392}393394let other = other.explode(ExplodeOptions {395empty_as_null: false,396keep_nulls: true,397})?;398let other = other.bool()?;399is_in_boolean_broadcast(ca_in, other, nulls_equal)400} else 
{401is_in_helper_list_ca(ca_in, other, nulls_equal)402}403},404#[cfg(feature = "dtype-array")]405DataType::Array(dt, _) if ca_in.dtype() == &**dt => {406let other = other.array()?;407if other.len() == 1 {408if other.has_nulls() {409return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));410}411412let other = other.explode(ExplodeOptions {413empty_as_null: false,414keep_nulls: true,415})?;416let other = other.bool()?;417is_in_boolean_broadcast(ca_in, other, nulls_equal)418} else {419is_in_helper_array_ca(ca_in, other, nulls_equal)420}421},422_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),423}424}425426#[cfg(feature = "dtype-categorical")]427fn is_in_cat_and_enum<T: PolarsCategoricalType>(428ca_in: &CategoricalChunked<T>,429other: &Series,430nulls_equal: bool,431) -> PolarsResult<BooleanChunked>432where433T::Native: ToTotalOrd<TotalOrdItem = T::Native>,434{435let to_categories = match (ca_in.dtype(), other.dtype().inner_dtype().unwrap()) {436(DataType::Enum(_, mapping) | DataType::Categorical(_, mapping), DataType::String) => {437(&|s: Series| {438let ca = s.str()?;439let ca: ChunkedArray<T::PolarsPhysical> = ca440.iter()441.flat_map(|opt_s| {442if let Some(s) = opt_s {443Some(mapping.get_cat(s).map(T::Native::from_cat))444} else {445Some(None)446}447})448.collect_ca(PlSmallStr::EMPTY);449Ok(ca.into_series())450}) as _451},452(DataType::Categorical(lcats, _), DataType::Categorical(rcats, _)) => {453ensure_same_categories(lcats, rcats)?;454(&|s: Series| Ok(s.cat::<T>()?.physical().clone().into_series())) as _455},456(DataType::Enum(lfcats, _), DataType::Enum(rfcats, _)) => {457ensure_same_frozen_categories(lfcats, rfcats)?;458(&|s: Series| Ok(s.cat::<T>()?.physical().clone().into_series())) as _459},460_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),461};462463let other = match other.dtype() {464DataType::List(_) => other.list()?.apply_to_inner(to_categories)?.into_series(),465#[cfg(feature = 
"dtype-array")]466DataType::Array(_, _) => other.array()?.apply_to_inner(to_categories)?.into_series(),467_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),468};469470is_in_numeric(ca_in.physical(), &other, nulls_equal)471}472473fn is_in_null(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult<BooleanChunked> {474if nulls_equal {475let ca_in = s.null()?;476Ok(match other.dtype() {477DataType::List(_) => {478let other = other.list()?;479if other.len() == 1 {480if other.has_nulls() {481return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));482}483484let other = other.explode(ExplodeOptions {485empty_as_null: false,486keep_nulls: true,487})?;488BooleanChunked::from_iter_values(489ca_in.name().clone(),490std::iter::repeat_n(other.has_nulls(), ca_in.len()),491)492} else {493other.apply_amortized_generic(|opt_s| {494Some(opt_s.map(|s| s.as_ref().has_nulls()) == Some(true))495})496}497},498#[cfg(feature = "dtype-array")]499DataType::Array(_, _) => {500let other = other.array()?;501if other.len() == 1 {502if other.has_nulls() {503return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));504}505506let other = other.explode(ExplodeOptions {507empty_as_null: false,508keep_nulls: true,509})?;510BooleanChunked::from_iter_values(511ca_in.name().clone(),512std::iter::repeat_n(other.has_nulls(), ca_in.len()),513)514} else {515other.apply_amortized_generic(|opt_s| {516Some(opt_s.map(|s| s.as_ref().has_nulls()) == Some(true))517})518}519},520_ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()),521})522} else {523let out = s.cast(&DataType::Boolean)?;524let ca_bool = out.bool()?.clone();525Ok(ca_bool)526}527}528529#[cfg(feature = "dtype-decimal")]530fn is_in_decimal(531ca_in: &DecimalChunked,532other: &Series,533nulls_equal: bool,534) -> PolarsResult<BooleanChunked> {535let Some(DataType::Decimal(other_precision, other_scale)) = other.dtype().inner_dtype() else {536polars_bail!(opq = is_in, ca_in.dtype(), 
other.dtype());537};538let prec = ca_in.precision().max(*other_precision);539let scale = ca_in.scale().max(*other_scale);540541// We convert both sides to a common scale, mapping any out-of-range values to unique integers,542// allowing us to then use is_in on the integer representation.543let sentinel_in = i128::MAX;544let sentinel_other = i128::MAX - 1;545let ca_in_phys = ca_in.into_phys_with_prec_scale_or_sentinel(prec, scale, sentinel_in);546547match other.dtype() {548DataType::List(_) => {549let other = other.list()?;550let other = other.apply_to_inner(&|s| {551let s = s.decimal()?;552let s = s.into_phys_with_prec_scale_or_sentinel(prec, scale, sentinel_other);553Ok(s.to_owned().into_series())554})?;555let other = other.into_series();556is_in_numeric(&ca_in_phys, &other, nulls_equal)557},558#[cfg(feature = "dtype-array")]559DataType::Array(_, _) => {560let other = other.array()?;561let other = other.apply_to_inner(&|s| {562let s = s.decimal()?;563let s = s.into_phys_with_prec_scale_or_sentinel(prec, scale, sentinel_other);564Ok(s.to_owned().into_series())565})?;566let other = other.into_series();567is_in_numeric(&ca_in_phys, &other, nulls_equal)568},569_ => unreachable!(),570}571}572573fn is_in_row_encoded(574s: &Series,575other: &Series,576nulls_equal: bool,577) -> PolarsResult<BooleanChunked> {578let ca_in = _get_rows_encoded_ca_unordered(s.name().clone(), &[s.clone().into_column()])?;579let mut mask = match other.dtype() {580DataType::List(_) => {581let other = other.list()?;582let other = other.apply_to_inner(&|s| {583Ok(584_get_rows_encoded_ca_unordered(s.name().clone(), &[s.into_column()])?585.into_series(),586)587})?;588if other.len() == 1 {589if other.has_nulls() {590return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));591}592593let other = other.explode(ExplodeOptions {594empty_as_null: false,595keep_nulls: true,596})?;597let other = other.binary_offset()?;598is_in_helper_ca(&ca_in, other, nulls_equal)599} else 
{600is_in_helper_list_ca(&ca_in, &other, nulls_equal)601}602},603#[cfg(feature = "dtype-array")]604DataType::Array(_, _) => {605let other = other.array()?;606let other = other.apply_to_inner(&|s| {607Ok(608_get_rows_encoded_ca_unordered(s.name().clone(), &[s.into_column()])?609.into_series(),610)611})?;612if other.len() == 1 {613if other.has_nulls() {614return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len()));615}616617let other = other.explode(ExplodeOptions {618empty_as_null: false,619keep_nulls: true,620})?;621let other = other.binary_offset()?;622is_in_helper_ca(&ca_in, other, nulls_equal)623} else {624is_in_helper_array_ca(&ca_in, &other, nulls_equal)625}626},627_ => unreachable!(),628}?;629630let mut validity = other.rechunk_validity();631if !nulls_equal {632validity = match (validity, s.rechunk_validity()) {633(None, None) => None,634(Some(v), None) | (None, Some(v)) => Some(v),635(Some(l), Some(r)) => Some(arrow::bitmap::and(&l, &r)),636};637}638639assert_eq!(mask.null_count(), 0);640mask.with_validities(&[validity]);641642Ok(mask)643}644645pub fn is_in(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult<BooleanChunked> {646polars_ensure!(647s.len() == other.len() || s.len() == 1 || other.len() == 1,648length_mismatch = "is_in",649s.len(),650other.len()651);652653#[allow(unused_mut)]654let mut other_is_valid_type = matches!(other.dtype(), DataType::List(_));655#[cfg(feature = "dtype-array")]656{657other_is_valid_type |= matches!(other.dtype(), DataType::Array(..))658}659polars_ensure!(other_is_valid_type, opq = is_in, s.dtype(), other.dtype());660661match s.dtype() {662#[cfg(feature = "dtype-categorical")]663dt @ DataType::Categorical(_, _) | dt @ DataType::Enum(_, _) => {664with_match_categorical_physical_type!(dt.cat_physical().unwrap(), |$C| {665is_in_cat_and_enum(s.cat::<$C>().unwrap(), other, nulls_equal)666})667},668DataType::String => {669let ca = s.str().unwrap();670is_in_string(ca, other, 
nulls_equal)671},672DataType::Binary => {673let ca = s.binary().unwrap();674is_in_binary(ca, other, nulls_equal)675},676DataType::Boolean => {677let ca = s.bool().unwrap();678is_in_boolean(ca, other, nulls_equal)679},680DataType::Null => is_in_null(s, other, nulls_equal),681#[cfg(feature = "dtype-decimal")]682DataType::Decimal(_, _) => {683let ca_in = s.decimal()?;684is_in_decimal(ca_in, other, nulls_equal)685},686dt if dt.is_nested() => is_in_row_encoded(s, other, nulls_equal),687dt if dt.to_physical().is_primitive_numeric() => {688let s = s.to_physical_repr();689let other = other.to_physical_repr();690let other = other.as_ref();691with_match_physical_numeric_polars_type!(s.dtype(), |$T| {692let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();693is_in_numeric(ca, other, nulls_equal)694})695},696dt => polars_bail!(opq = is_in, dt),697}698}699700701