Path: blob/main/crates/polars-core/src/series/comparison.rs
6940 views
//! Comparison operations on Series.12use polars_error::feature_gated;34use crate::prelude::*;5use crate::series::arithmetic::coerce_lhs_rhs;6use crate::series::nulls::replace_non_null;78macro_rules! impl_eq_compare {9($self:expr, $rhs:expr, $method:ident) => {{10use DataType::*;11let (lhs, rhs) = ($self, $rhs);12validate_types(lhs.dtype(), rhs.dtype())?;1314polars_ensure!(15lhs.len() == rhs.len() ||1617// Broadcast18lhs.len() == 1 ||19rhs.len() == 1,20ShapeMismatch: "could not compare between two series of different length ({} != {})",21lhs.len(),22rhs.len()23);2425#[cfg(feature = "dtype-categorical")]26match (lhs.dtype(), rhs.dtype()) {27(Categorical(lcats, _), Categorical(rcats, _)) => {28ensure_same_categories(lcats, rcats)?;29return with_match_categorical_physical_type!(lcats.physical(), |$C| {30lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())31})32},33(Enum(lfcats, _), Enum(rfcats, _)) => {34ensure_same_frozen_categories(lfcats, rfcats)?;35return with_match_categorical_physical_type!(lfcats.physical(), |$C| {36lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())37})38},39(Categorical(_, _) | Enum(_, _), String) => {40return with_match_categorical_physical_type!(lhs.dtype().cat_physical().unwrap(), |$C| {41Ok(lhs.cat::<$C>().unwrap().$method(rhs.str().unwrap()))42})43},44(String, Categorical(_, _) | Enum(_, _)) => {45return with_match_categorical_physical_type!(rhs.dtype().cat_physical().unwrap(), |$C| {46Ok(rhs.cat::<$C>().unwrap().$method(lhs.str().unwrap()))47})48},49_ => (),50};5152let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs)53.map_err(|_| polars_err!(54SchemaMismatch: "could not evaluate comparison between series '{}' of dtype: {} and series '{}' of dtype: {}",55lhs.name(), lhs.dtype(), rhs.name(), rhs.dtype()56))?;57let lhs = lhs.to_physical_repr();58let rhs = rhs.to_physical_repr();59let mut out = match lhs.dtype() {60Null => lhs.null().unwrap().$method(rhs.null().unwrap()),61Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()),62String => lhs.str().unwrap().$method(rhs.str().unwrap()),63Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()),64BinaryOffset => lhs.binary_offset().unwrap().$method(rhs.binary_offset().unwrap()),65UInt8 => feature_gated!("dtype-u8", lhs.u8().unwrap().$method(rhs.u8().unwrap())),66UInt16 => feature_gated!("dtype-u16", lhs.u16().unwrap().$method(rhs.u16().unwrap())),67UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()),68UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()),69Int8 => feature_gated!("dtype-i8", lhs.i8().unwrap().$method(rhs.i8().unwrap())),70Int16 => feature_gated!("dtype-i16", lhs.i16().unwrap().$method(rhs.i16().unwrap())),71Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()),72Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()),73Int128 => feature_gated!("dtype-i128", lhs.i128().unwrap().$method(rhs.i128().unwrap())),74Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()),75Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()),76List(_) => lhs.list().unwrap().$method(rhs.list().unwrap()),77#[cfg(feature = "dtype-array")]78Array(_, _) => lhs.array().unwrap().$method(rhs.array().unwrap()),79#[cfg(feature = "dtype-struct")]80Struct(_) => lhs.struct_().unwrap().$method(rhs.struct_().unwrap()),8182dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()),83};84out.rename(lhs.name().clone());85PolarsResult::Ok(out)86}};87}8889macro_rules! bail_invalid_ineq {90($lhs:expr, $rhs:expr, $op:literal) => {91polars_bail!(92InvalidOperation: "cannot perform '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}",93$op,94$lhs.name(), $lhs.dtype(),95$rhs.name(), $rhs.dtype(),96)97};98}99100macro_rules! impl_ineq_compare {101($self:expr, $rhs:expr, $method:ident, $op:literal, $rev_method:ident) => {{102use DataType::*;103let (lhs, rhs) = ($self, $rhs);104validate_types(lhs.dtype(), rhs.dtype())?;105106polars_ensure!(107lhs.len() == rhs.len() ||108109// Broadcast110lhs.len() == 1 ||111rhs.len() == 1,112ShapeMismatch:113"could not perform '{}' comparison between series '{}' of length: {} and series '{}' of length: {}, because they have different lengths",114$op,115lhs.name(), lhs.len(),116rhs.name(), rhs.len()117);118119#[cfg(feature = "dtype-categorical")]120match (lhs.dtype(), rhs.dtype()) {121(Categorical(lcats, _), Categorical(rcats, _)) => {122ensure_same_categories(lcats, rcats)?;123return with_match_categorical_physical_type!(lcats.physical(), |$C| {124lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())125})126},127(Enum(lfcats, _), Enum(rfcats, _)) => {128ensure_same_frozen_categories(lfcats, rfcats)?;129return with_match_categorical_physical_type!(lfcats.physical(), |$C| {130lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())131})132},133(Categorical(_, _) | Enum(_, _), String) => {134return with_match_categorical_physical_type!(lhs.dtype().cat_physical().unwrap(), |$C| {135lhs.cat::<$C>().unwrap().$method(rhs.str().unwrap())136})137},138(String, Categorical(_, _) | Enum(_, _)) => {139return with_match_categorical_physical_type!(rhs.dtype().cat_physical().unwrap(), |$C| {140// We use the reverse method as string <-> enum comparisons are only implemented one-way.141rhs.cat::<$C>().unwrap().$rev_method(lhs.str().unwrap())142})143},144_ => (),145};146147let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs).map_err(|_|148polars_err!(149SchemaMismatch: "could not evaluate '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}",150$op,151lhs.name(), lhs.dtype(),152rhs.name(), rhs.dtype()153)154)?;155let lhs = lhs.to_physical_repr();156let rhs = rhs.to_physical_repr();157let mut out = match lhs.dtype() {158Null => lhs.null().unwrap().$method(rhs.null().unwrap()),159Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()),160String => lhs.str().unwrap().$method(rhs.str().unwrap()),161Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()),162BinaryOffset => lhs.binary_offset().unwrap().$method(rhs.binary_offset().unwrap()),163UInt8 => feature_gated!("dtype-u8", lhs.u8().unwrap().$method(rhs.u8().unwrap())),164UInt16 => feature_gated!("dtype-u16", lhs.u16().unwrap().$method(rhs.u16().unwrap())),165UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()),166UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()),167Int8 => feature_gated!("dtype-i8", lhs.i8().unwrap().$method(rhs.i8().unwrap())),168Int16 => feature_gated!("dtype-i16", lhs.i16().unwrap().$method(rhs.i16().unwrap())),169Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()),170Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()),171Int128 => feature_gated!("dtype-i128", lhs.i128().unwrap().$method(rhs.i128().unwrap())),172Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()),173Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()),174List(_) => bail_invalid_ineq!(lhs, rhs, $op),175#[cfg(feature = "dtype-array")]176Array(_, _) => bail_invalid_ineq!(lhs, rhs, $op),177#[cfg(feature = "dtype-struct")]178Struct(_) => bail_invalid_ineq!(lhs, rhs, $op),179180dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()),181};182out.rename(lhs.name().clone());183PolarsResult::Ok(out)184}};185}186187fn validate_types(left: &DataType, right: &DataType) -> PolarsResult<()> {188use DataType::*;189190match (left, right) {191(String, dt) | (dt, String) if dt.is_primitive_numeric() => {192polars_bail!(ComputeError: "cannot compare string with numeric type ({})", dt)193},194#[cfg(feature = "dtype-categorical")]195(Categorical(_, _) | Enum(_, _), dt) | (dt, Categorical(_, _) | Enum(_, _))196if !(dt.is_categorical() | dt.is_string() | dt.is_enum()) =>197{198polars_bail!(ComputeError: "cannot compare categorical with {}", dt);199},200_ => (),201};202Ok(())203}204205impl ChunkCompareEq<&Series> for Series {206type Item = PolarsResult<BooleanChunked>;207208/// Create a boolean mask by checking for equality.209fn equal(&self, rhs: &Series) -> Self::Item {210impl_eq_compare!(self, rhs, equal)211}212213/// Create a boolean mask by checking for equality.214fn equal_missing(&self, rhs: &Series) -> Self::Item {215impl_eq_compare!(self, rhs, equal_missing)216}217218/// Create a boolean mask by checking for inequality.219fn not_equal(&self, rhs: &Series) -> Self::Item {220impl_eq_compare!(self, rhs, not_equal)221}222223/// Create a boolean mask by checking for inequality.224fn not_equal_missing(&self, rhs: &Series) -> Self::Item {225impl_eq_compare!(self, rhs, not_equal_missing)226}227}228229impl ChunkCompareIneq<&Series> for Series {230type Item = PolarsResult<BooleanChunked>;231232/// Create a boolean mask by checking if self > rhs.233fn gt(&self, rhs: &Series) -> Self::Item {234impl_ineq_compare!(self, rhs, gt, ">", lt)235}236237/// Create a boolean mask by checking if self >= rhs.238fn gt_eq(&self, rhs: &Series) -> Self::Item {239impl_ineq_compare!(self, rhs, gt_eq, ">=", lt_eq)240}241242/// Create a boolean mask by checking if self < rhs.243fn lt(&self, rhs: &Series) -> Self::Item {244impl_ineq_compare!(self, rhs, lt, "<", gt)245}246247/// Create a boolean mask by checking if self <= rhs.248fn lt_eq(&self, rhs: &Series) -> Self::Item {249impl_ineq_compare!(self, rhs, lt_eq, "<=", gt_eq)250}251}252253impl<Rhs> ChunkCompareEq<Rhs> for Series254where255Rhs: NumericNative,256{257type Item = PolarsResult<BooleanChunked>;258259fn equal(&self, rhs: Rhs) -> Self::Item {260validate_types(self.dtype(), &DataType::Int8)?;261let s = self.to_physical_repr();262Ok(apply_method_physical_numeric!(&s, equal, rhs))263}264265fn equal_missing(&self, rhs: Rhs) -> Self::Item {266validate_types(self.dtype(), &DataType::Int8)?;267let s = self.to_physical_repr();268Ok(apply_method_physical_numeric!(&s, equal_missing, rhs))269}270271fn not_equal(&self, rhs: Rhs) -> Self::Item {272validate_types(self.dtype(), &DataType::Int8)?;273let s = self.to_physical_repr();274Ok(apply_method_physical_numeric!(&s, not_equal, rhs))275}276277fn not_equal_missing(&self, rhs: Rhs) -> Self::Item {278validate_types(self.dtype(), &DataType::Int8)?;279let s = self.to_physical_repr();280Ok(apply_method_physical_numeric!(&s, not_equal_missing, rhs))281}282}283284impl<Rhs> ChunkCompareIneq<Rhs> for Series285where286Rhs: NumericNative,287{288type Item = PolarsResult<BooleanChunked>;289290fn gt(&self, rhs: Rhs) -> Self::Item {291validate_types(self.dtype(), &DataType::Int8)?;292let s = self.to_physical_repr();293Ok(apply_method_physical_numeric!(&s, gt, rhs))294}295296fn gt_eq(&self, rhs: Rhs) -> Self::Item {297validate_types(self.dtype(), &DataType::Int8)?;298let s = self.to_physical_repr();299Ok(apply_method_physical_numeric!(&s, gt_eq, rhs))300}301302fn lt(&self, rhs: Rhs) -> Self::Item {303validate_types(self.dtype(), &DataType::Int8)?;304let s = self.to_physical_repr();305Ok(apply_method_physical_numeric!(&s, lt, rhs))306}307308fn lt_eq(&self, rhs: Rhs) -> Self::Item {309validate_types(self.dtype(), &DataType::Int8)?;310let s = self.to_physical_repr();311Ok(apply_method_physical_numeric!(&s, lt_eq, rhs))312}313}314315impl ChunkCompareEq<&str> for Series {316type Item = PolarsResult<BooleanChunked>;317318fn equal(&self, rhs: &str) -> PolarsResult<BooleanChunked> {319validate_types(self.dtype(), &DataType::String)?;320match self.dtype() {321DataType::String => Ok(self.str().unwrap().equal(rhs)),322#[cfg(feature = "dtype-categorical")]323DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(324with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {325self.cat::<$C>().unwrap().equal(rhs)326}),327),328_ => Ok(BooleanChunked::full(self.name().clone(), false, self.len())),329}330}331332fn equal_missing(&self, rhs: &str) -> Self::Item {333validate_types(self.dtype(), &DataType::String)?;334match self.dtype() {335DataType::String => Ok(self.str().unwrap().equal_missing(rhs)),336#[cfg(feature = "dtype-categorical")]337DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(338with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {339self.cat::<$C>().unwrap().equal_missing(rhs)340}),341),342_ => Ok(replace_non_null(343self.name().clone(),344self.0.chunks(),345false,346)),347}348}349350fn not_equal(&self, rhs: &str) -> PolarsResult<BooleanChunked> {351validate_types(self.dtype(), &DataType::String)?;352match self.dtype() {353DataType::String => Ok(self.str().unwrap().not_equal(rhs)),354#[cfg(feature = "dtype-categorical")]355DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(356with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {357self.cat::<$C>().unwrap().not_equal(rhs)358}),359),360_ => Ok(BooleanChunked::full(self.name().clone(), true, self.len())),361}362}363364fn not_equal_missing(&self, rhs: &str) -> Self::Item {365validate_types(self.dtype(), &DataType::String)?;366match self.dtype() {367DataType::String => Ok(self.str().unwrap().not_equal_missing(rhs)),368#[cfg(feature = "dtype-categorical")]369DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(370with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {371self.cat::<$C>().unwrap().not_equal_missing(rhs)372}),373),374_ => Ok(replace_non_null(self.name().clone(), self.0.chunks(), true)),375}376}377}378379impl ChunkCompareIneq<&str> for Series {380type Item = PolarsResult<BooleanChunked>;381382fn gt(&self, rhs: &str) -> Self::Item {383validate_types(self.dtype(), &DataType::String)?;384match self.dtype() {385DataType::String => Ok(self.str().unwrap().gt(rhs)),386#[cfg(feature = "dtype-categorical")]387DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(388with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {389self.cat::<$C>().unwrap().gt(rhs)390}),391),392_ => polars_bail!(393ComputeError: "cannot compare str value to series of type {}", self.dtype(),394),395}396}397398fn gt_eq(&self, rhs: &str) -> Self::Item {399validate_types(self.dtype(), &DataType::String)?;400match self.dtype() {401DataType::String => Ok(self.str().unwrap().gt_eq(rhs)),402#[cfg(feature = "dtype-categorical")]403DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(404with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {405self.cat::<$C>().unwrap().gt_eq(rhs)406}),407),408_ => polars_bail!(409ComputeError: "cannot compare str value to series of type {}", self.dtype(),410),411}412}413414fn lt(&self, rhs: &str) -> Self::Item {415validate_types(self.dtype(), &DataType::String)?;416match self.dtype() {417DataType::String => Ok(self.str().unwrap().lt(rhs)),418#[cfg(feature = "dtype-categorical")]419DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(420with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {421self.cat::<$C>().unwrap().lt(rhs)422}),423),424_ => polars_bail!(425ComputeError: "cannot compare str value to series of type {}", self.dtype(),426),427}428}429430fn lt_eq(&self, rhs: &str) -> Self::Item {431validate_types(self.dtype(), &DataType::String)?;432match self.dtype() {433DataType::String => Ok(self.str().unwrap().lt_eq(rhs)),434#[cfg(feature = "dtype-categorical")]435DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(436with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {437self.cat::<$C>().unwrap().lt_eq(rhs)438}),439),440_ => polars_bail!(441ComputeError: "cannot compare str value to series of type {}", self.dtype(),442),443}444}445}446447448