# Source: blob/main/py-polars/tests/unit/operations/test_comparison.py
from __future__ import annotations

import math
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from contextlib import AbstractContextManager as ContextManager

    from polars._typing import PolarsDataType


def test_comparison_order_null_broadcasting() -> None:
    """Ordering comparisons against an all-null column yield only nulls."""
    # see more: 8183
    exprs = [
        pl.col("v") < pl.col("null"),
        pl.col("null") < pl.col("v"),
        pl.col("v") <= pl.col("null"),
        pl.col("null") <= pl.col("v"),
        pl.col("v") > pl.col("null"),
        pl.col("null") > pl.col("v"),
        pl.col("v") >= pl.col("null"),
        pl.col("null") >= pl.col("v"),
    ]

    kwargs = {f"out{i}": e for i, e in enumerate(exprs)}

    # single value, hits broadcasting branch
    df = pl.DataFrame({"v": [42], "null": [None]})
    assert all((df.select(**kwargs).null_count() == 1).rows()[0])

    # multiple values, hits default branch
    df = pl.DataFrame({"v": [42, 42], "null": [None, None]})
    assert all((df.select(**kwargs).null_count() == 2).rows()[0])


def test_comparison_nulls_single() -> None:
    """`==` and `!=` between single-row all-null frames give null per column."""
    df1 = pl.DataFrame(
        {
            "a": pl.Series([None], dtype=pl.String),
            "b": pl.Series([None], dtype=pl.Int64),
            "c": pl.Series([None], dtype=pl.Boolean),
        }
    )
    df2 = pl.DataFrame(
        {
            "a": pl.Series([None], dtype=pl.String),
            "b": pl.Series([None], dtype=pl.Int64),
            "c": pl.Series([None], dtype=pl.Boolean),
        }
    )
    assert (df1 == df2).row(0) == (None, None, None)
    assert (df1 != df2).row(0) == (None, None, None)


def test_comparison_series_expr() -> None:
    """All six comparison operators work with a Series on the left of an Expr."""
    df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])})

    assert_frame_equal(
        df.select(
            (df["a"] == pl.col("b")).alias("eq"),  # False, False, True
            (df["a"] != pl.col("b")).alias("ne"),  # True, True, False
            (df["a"] < pl.col("b")).alias("lt"),  # True, False, False
            (df["a"] <= pl.col("b")).alias("le"),  # True, False, True
            (df["a"] > pl.col("b")).alias("gt"),  # False, True, False
            (df["a"] >= pl.col("b")).alias("ge"),  # False, True, True
        ),
        pl.DataFrame(
            {
                "eq": [False, False, True],
                "ne": [True, True, False],
                "lt": [True, False, False],
                "le": [True, False, True],
                "gt": [False, True, False],
                "ge": [False, True, True],
            }
        ),
    )


def test_comparison_expr_expr() -> None:
    """All six comparison operators work between two column expressions."""
    df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])})

    assert_frame_equal(
        df.select(
            (pl.col("a") == pl.col("b")).alias("eq"),  # False, False, True
            (pl.col("a") != pl.col("b")).alias("ne"),  # True, True, False
            (pl.col("a") < pl.col("b")).alias("lt"),  # True, False, False
            (pl.col("a") <= pl.col("b")).alias("le"),  # True, False, True
            (pl.col("a") > pl.col("b")).alias("gt"),  # False, True, False
            (pl.col("a") >= pl.col("b")).alias("ge"),  # False, True, True
        ),
        pl.DataFrame(
            {
                "eq": [False, False, True],
                "ne": [True, True, False],
                "lt": [True, False, False],
                "le": [True, False, True],
                "gt": [False, True, False],
                "ge": [False, True, True],
            }
        ),
    )


def test_comparison_expr_series() -> None:
    """All six comparison operators work with an Expr on the left of a Series."""
    df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])})

    assert_frame_equal(
        df.select(
            (pl.col("a") == df["b"]).alias("eq"),  # False, False, True
            (pl.col("a") != df["b"]).alias("ne"),  # True, True, False
            (pl.col("a") < df["b"]).alias("lt"),  # True, False, False
            (pl.col("a") <= df["b"]).alias("le"),  # True, False, True
            (pl.col("a") > df["b"]).alias("gt"),  # False, True, False
            (pl.col("a") >= df["b"]).alias("ge"),  # False, True, True
        ),
        pl.DataFrame(
            {
                "eq": [False, False, True],
                "ne": [True, True, False],
                "lt": [True, False, False],
                "le": [True, False, True],
                "gt": [False, True, False],
                "ge": [False, True, True],
            }
        ),
    )


def test_offset_handling_arg_where_7863() -> None:
    """`arg_true` reports the right index after comparing an appended column.

    The compared values are ``[0, 0, 1, 0]`` (literal 0, column ``a``,
    literal 0), so the only entry that differs from 0 sits at index 2.
    """
    df_check = pl.DataFrame({"a": [0, 1]})
    mask = df_check.select((pl.lit(0).append(pl.col("a")).append(0)) != 0)
    assert mask.select(pl.col("literal").arg_true()).item() == 2


def test_missing_equality_on_bools() -> None:
    """`ne_missing` on a boolean column treats null as a value distinct from both."""
    df = pl.DataFrame(
        {
            "a": [True, None, False],
        }
    )

    assert df.select(pl.col("a").ne_missing(True))["a"].to_list() == [False, True, True]
    assert df.select(pl.col("a").ne_missing(False))["a"].to_list() == [
        True,
        True,
        False,
    ]


def test_struct_equality_18870() -> None:
    """Struct `eq`/`ne` propagate an outer null; the `*_missing` variants do not."""
    s = pl.Series([{"a": 1}, None])

    # eq
    result = s.eq(s).to_list()
    expected = [True, None]
    assert result == expected

    # ne
    result = s.ne(s).to_list()
    expected = [False, None]
    assert result == expected

    # eq_missing
    result = s.eq_missing(s).to_list()
    expected = [True, True]
    assert result == expected

    # ne_missing
    result = s.ne_missing(s).to_list()
    expected = [False, False]
    assert result == expected


def test_struct_nested_equality() -> None:
    """Struct `eq`/`ne` with a null field compare unequal; a null struct gives null."""
    df = pl.DataFrame(
        {
            "a": [{"foo": 0, "bar": "1"}, {"foo": None, "bar": "1"}, None],
            "b": [{"foo": 0, "bar": "1"}] * 3,
        }
    )

    # eq
    ans = df.select(pl.col("a").eq(pl.col("b")))
    expected = pl.DataFrame({"a": [True, False, None]})
    assert_frame_equal(ans, expected)

    # ne
    ans = df.select(pl.col("a").ne(pl.col("b")))
    expected = pl.DataFrame({"a": [False, True, None]})
    assert_frame_equal(ans, expected)


def isnan(x: Any) -> bool:
    """Return True only when `x` is a float NaN; non-floats are never NaN."""
    return isinstance(x, float) and math.isnan(x)


def reference_ordering_propagating(lhs: Any, rhs: Any) -> str | None:
    """
    Reference ordering of `lhs` relative to `rhs` with null propagation.

    Returns ``"<"``, ``"="`` or ``">"``, or None when either side is None.
    """
    # normal < nan, nan == nan, nulls propagate
    if lhs is None or rhs is None:
        return None

    if isnan(lhs) and isnan(rhs):
        return "="

    if isnan(lhs) or lhs > rhs:
        return ">"

    if isnan(rhs) or lhs < rhs:
        return "<"

    return "="


def reference_ordering_missing(lhs: Any, rhs: Any) -> str:
    """
    Reference ordering of `lhs` relative to `rhs` where None is a real value.

    Returns ``"<"``, ``"="`` or ``">"``; None sorts before everything else.
    """
    # null < normal < nan, nan == nan, null == null
    if lhs is None and rhs is None:
        return "="

    if lhs is None:
        return "<"

    if rhs is None:
        return ">"

    if isnan(lhs) and isnan(rhs):
        return "="

    if isnan(lhs) or lhs > rhs:
        return ">"

    if isnan(rhs) or lhs < rhs:
        return "<"

    return "="


def _expected_ordering_frame(ref: str | None, refmiss: str) -> pl.DataFrame:
    """
    Build the single-row Boolean frame of expected comparison results.

    `ref` is the outcome of `reference_ordering_propagating` and `refmiss`
    the outcome of `reference_ordering_missing` for the same pair of values.
    """
    expected = {
        "eq": [ref and ref == "="],  # "ref and X" propagates ref is None
        "ne": [ref and ref != "="],
        "lt": [ref and ref == "<"],
        "le": [ref and (ref == "<" or ref == "=")],
        "gt": [ref and ref == ">"],
        "ge": [ref and (ref == ">" or ref == "=")],
        "eq_missing": [refmiss == "="],
        "ne_missing": [refmiss != "="],
    }
    # An explicit Boolean schema keeps all-None columns from being typed Null.
    return pl.DataFrame(expected, schema=dict.fromkeys(expected, pl.Boolean))


def _none_rhs_warning_context(
    rhs: Any,
) -> pytest.WarningsRecorder | ContextManager[None]:
    """
    Return a context expecting a `UserWarning` when `rhs` is None.

    For any other `rhs` a no-op context is returned.
    """
    # NOTE(review): presumably the warning comes from comparing an expression
    # with a literal None in `verify_total_ordering_broadcast` - confirm.
    return pytest.warns(UserWarning) if rhs is None else nullcontext()


def verify_total_ordering(
    lhs: Any, rhs: Any, dummy: Any, ldtype: PolarsDataType, rdtype: PolarsDataType
) -> None:
    """Check column-vs-column comparisons of `lhs` and `rhs` against the reference."""
    ref = reference_ordering_propagating(lhs, rhs)
    refmiss = reference_ordering_missing(lhs, rhs)

    # Add dummy variable so we don't broadcast or do full-null optimization.
    assert dummy is not None
    df = pl.DataFrame(
        {"l": [lhs, dummy], "r": [rhs, dummy]}, schema={"l": ldtype, "r": rdtype}
    )

    ans = df.select(
        (pl.col("l") == pl.col("r")).alias("eq"),
        (pl.col("l") != pl.col("r")).alias("ne"),
        (pl.col("l") < pl.col("r")).alias("lt"),
        (pl.col("l") <= pl.col("r")).alias("le"),
        (pl.col("l") > pl.col("r")).alias("gt"),
        (pl.col("l") >= pl.col("r")).alias("ge"),
        pl.col("l").eq_missing(pl.col("r")).alias("eq_missing"),
        pl.col("l").ne_missing(pl.col("r")).alias("ne_missing"),
    )

    # Only the first row holds the (lhs, rhs) pair; the second is the dummy.
    assert_frame_equal(ans[:1], _expected_ordering_frame(ref, refmiss))


def verify_total_ordering_broadcast(
    lhs: Any, rhs: Any, dummy: Any, ldtype: PolarsDataType, rdtype: PolarsDataType
) -> None:
    """
    Check comparisons of column `l` against a single right-hand value.

    The right-hand side is supplied both as `pl.col("r").first()` and as the
    plain Python value `rhs`; both results must match the reference.
    """
    ref = reference_ordering_propagating(lhs, rhs)
    refmiss = reference_ordering_missing(lhs, rhs)

    # Add dummy variable so we don't broadcast inherently.
    assert dummy is not None
    df = pl.DataFrame(
        {"l": [lhs, dummy], "r": [rhs, dummy]}, schema={"l": ldtype, "r": rdtype}
    )

    ans_first = df.select(
        (pl.col("l") == pl.col("r").first()).alias("eq"),
        (pl.col("l") != pl.col("r").first()).alias("ne"),
        (pl.col("l") < pl.col("r").first()).alias("lt"),
        (pl.col("l") <= pl.col("r").first()).alias("le"),
        (pl.col("l") > pl.col("r").first()).alias("gt"),
        (pl.col("l") >= pl.col("r").first()).alias("ge"),
        pl.col("l").eq_missing(pl.col("r").first()).alias("eq_missing"),
        pl.col("l").ne_missing(pl.col("r").first()).alias("ne_missing"),
    )

    ans_scalar = df.select(
        (pl.col("l") == rhs).alias("eq"),
        (pl.col("l") != rhs).alias("ne"),
        (pl.col("l") < rhs).alias("lt"),
        (pl.col("l") <= rhs).alias("le"),
        (pl.col("l") > rhs).alias("gt"),
        (pl.col("l") >= rhs).alias("ge"),
        (pl.col("l").eq_missing(rhs)).alias("eq_missing"),
        (pl.col("l").ne_missing(rhs)).alias("ne_missing"),
    )

    ans_correct = _expected_ordering_frame(ref, refmiss)

    assert_frame_equal(ans_first[:1], ans_correct)
    assert_frame_equal(ans_scalar[:1], ans_correct)


INTERESTING_FLOAT_VALUES = [
    0.0,
    -0.0,
    -1.0,
    1.0,
    -float("nan"),
    float("nan"),
    -float("inf"),
    float("inf"),
    None,
]


@pytest.mark.slow
@pytest.mark.parametrize("lhs", INTERESTING_FLOAT_VALUES)
@pytest.mark.parametrize("rhs", INTERESTING_FLOAT_VALUES)
def test_total_ordering_float_series(lhs: float | None, rhs: float | None) -> None:
    """Float comparisons follow the reference ordering, including mixed widths."""
    verify_total_ordering(lhs, rhs, 0.0, pl.Float32, pl.Float32)
    verify_total_ordering(lhs, rhs, 0.0, pl.Float64, pl.Float32)
    with _none_rhs_warning_context(rhs):
        verify_total_ordering_broadcast(lhs, rhs, 0.0, pl.Float32, pl.Float32)
        verify_total_ordering_broadcast(lhs, rhs, 0.0, pl.Float64, pl.Float32)


INTERESTING_STRING_VALUES = [
    "",
    "foo",
    "bar",
    "fooo",
    "fooooooooooo",
    "foooooooooooo",
    "fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooom",
    "foooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo",
    "fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooop",
    None,
]


@pytest.mark.slow
@pytest.mark.parametrize("lhs", INTERESTING_STRING_VALUES)
@pytest.mark.parametrize("rhs", INTERESTING_STRING_VALUES)
def test_total_ordering_string_series(lhs: str | None, rhs: str | None) -> None:
    """String comparisons follow the reference ordering."""
    verify_total_ordering(lhs, rhs, "", pl.String, pl.String)
    with _none_rhs_warning_context(rhs):
        verify_total_ordering_broadcast(lhs, rhs, "", pl.String, pl.String)


@pytest.mark.slow
@pytest.mark.parametrize("lhs", INTERESTING_STRING_VALUES)
@pytest.mark.parametrize("rhs", INTERESTING_STRING_VALUES)
@pytest.mark.parametrize("fresh_cat", [False, True])
def test_total_ordering_cat_series(
    lhs: str | None, rhs: str | None, fresh_cat: bool
) -> None:
    """Categorical comparisons (also mixed with String) follow the reference."""
    # A separate dtype instance is created for every verification call below.
    if fresh_cat:
        c = [pl.Categorical(pl.Categories.random()) for _ in range(6)]
    else:
        c = [pl.Categorical() for _ in range(6)]
    verify_total_ordering(lhs, rhs, "", c[0], c[0])
    verify_total_ordering(lhs, rhs, "", pl.String, c[1])
    verify_total_ordering(lhs, rhs, "", c[2], pl.String)
    with _none_rhs_warning_context(rhs):
        verify_total_ordering_broadcast(lhs, rhs, "", c[3], c[3])
        verify_total_ordering_broadcast(lhs, rhs, "", pl.String, c[4])
        verify_total_ordering_broadcast(lhs, rhs, "", c[5], pl.String)


@pytest.mark.slow
@pytest.mark.parametrize("str_lhs", INTERESTING_STRING_VALUES)
@pytest.mark.parametrize("str_rhs", INTERESTING_STRING_VALUES)
def test_total_ordering_binary_series(str_lhs: str | None, str_rhs: str | None) -> None:
    """Binary comparisons follow the reference ordering (UTF-8 encoded inputs)."""
    lhs = None if str_lhs is None else str_lhs.encode("utf-8")
    rhs = None if str_rhs is None else str_rhs.encode("utf-8")
    verify_total_ordering(lhs, rhs, b"", pl.Binary, pl.Binary)
    with _none_rhs_warning_context(rhs):
        verify_total_ordering_broadcast(lhs, rhs, b"", pl.Binary, pl.Binary)


@pytest.mark.parametrize("lhs", [None, False, True])
@pytest.mark.parametrize("rhs", [None, False, True])
def test_total_ordering_bool_series(lhs: bool | None, rhs: bool | None) -> None:
    """Boolean comparisons follow the reference ordering."""
    verify_total_ordering(lhs, rhs, False, pl.Boolean, pl.Boolean)
    with _none_rhs_warning_context(rhs):
        verify_total_ordering_broadcast(lhs, rhs, False, pl.Boolean, pl.Boolean)


def test_cat_compare_with_bool() -> None:
    """Comparing a Categorical column with a bool raises `ComputeError`."""
    data = pl.DataFrame([pl.Series("col1", ["a", "b"], dtype=pl.Categorical)])

    with pytest.raises(ComputeError, match="cannot compare categorical with bool"):
        data.filter(pl.col("col1") == True)  # noqa: E712


def test_schema_ne_missing_9256() -> None:
    """`ne_missing(...).or_(...)` evaluates to an all-true column here."""
    df = pl.DataFrame({"a": [0, 1, None], "b": [True, False, True]})

    assert df.select(pl.col("a").ne_missing(0).or_(pl.col("b")))["a"].all()


def test_nested_binary_literal_super_type_12227() -> None:
    """Nested literal arithmetic on a comparison result evaluates to 0.1."""
    # The `.alias` is important here to trigger the bug.
    result = pl.select(x=1).select((pl.lit(0) + ((pl.col("x") > 0) * 0.1)).alias("x"))
    assert result.item() == 0.1

    result = pl.select((pl.lit(0) + (pl.lit(0) == pl.lit(0)) * pl.lit(0.1)) + pl.lit(0))
    assert result.item() == 0.1


def test_struct_broadcasting_comparison() -> None:
    """A struct column can be compared against its own last value."""
    df = pl.DataFrame({"foo": [{"a": 1}, {"a": 2}, {"a": 1}]})
    assert df.select(eq=pl.col.foo == pl.col.foo.last()).to_dict(as_series=False) == {
        "eq": [True, False, True]
    }


@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 1)])
def test_compare_list_broadcast_empty_first_chunk_20165(dtype: pl.DataType) -> None:
    """List/Array comparison against a length-1, two-chunk Series works."""
    # Filtering a two-chunk Series keeps both chunks (asserted below) while
    # dropping the row that came from the first one.
    s = pl.concat(2 * [pl.Series([[1]], dtype=dtype)]).filter([False, True])

    assert s.len() == 1
    assert s.n_chunks() == 2

    assert_series_equal(
        pl.select(pl.lit(pl.Series([[1], [2]]), dtype=dtype) == pl.lit(s)).to_series(),
        pl.Series([True, False]),
    )