# Path: blob/main/py-polars/tests/unit/operations/test_comparison.py
# 8421 views
from __future__ import annotations12import math3from contextlib import nullcontext4from typing import TYPE_CHECKING, Any56import pytest78import polars as pl9from polars.exceptions import ComputeError10from polars.testing import assert_frame_equal, assert_series_equal1112if TYPE_CHECKING:13from contextlib import AbstractContextManager as ContextManager1415from polars._typing import PolarsDataType161718def test_comparison_order_null_broadcasting() -> None:19# see more: 818320exprs = [21pl.col("v") < pl.col("null"),22pl.col("null") < pl.col("v"),23pl.col("v") <= pl.col("null"),24pl.col("null") <= pl.col("v"),25pl.col("v") > pl.col("null"),26pl.col("null") > pl.col("v"),27pl.col("v") >= pl.col("null"),28pl.col("null") >= pl.col("v"),29]3031kwargs = {f"out{i}": e for i, e in zip(range(len(exprs)), exprs, strict=True)}3233# single value, hits broadcasting branch34df = pl.DataFrame({"v": [42], "null": [None]})35assert all((df.select(**kwargs).null_count() == 1).rows()[0])3637# multiple values, hits default branch38df = pl.DataFrame({"v": [42, 42], "null": [None, None]})39assert all((df.select(**kwargs).null_count() == 2).rows()[0])404142def test_comparison_nulls_single() -> None:43df1 = pl.DataFrame(44{45"a": pl.Series([None], dtype=pl.String),46"b": pl.Series([None], dtype=pl.Int64),47"c": pl.Series([None], dtype=pl.Boolean),48}49)50df2 = pl.DataFrame(51{52"a": pl.Series([None], dtype=pl.String),53"b": pl.Series([None], dtype=pl.Int64),54"c": pl.Series([None], dtype=pl.Boolean),55}56)57assert (df1 == df2).row(0) == (None, None, None)58assert (df1 != df2).row(0) == (None, None, None)596061def test_comparison_series_expr() -> None:62df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])})6364assert_frame_equal(65df.select(66(df["a"] == pl.col("b")).alias("eq"), # False, False, True67(df["a"] != pl.col("b")).alias("ne"), # True, True, False68(df["a"] < pl.col("b")).alias("lt"), # True, False, False69(df["a"] <= pl.col("b")).alias("le"), # True, False, 
True70(df["a"] > pl.col("b")).alias("gt"), # False, True, False71(df["a"] >= pl.col("b")).alias("ge"), # False, True, True72),73pl.DataFrame(74{75"eq": [False, False, True],76"ne": [True, True, False],77"lt": [True, False, False],78"le": [True, False, True],79"gt": [False, True, False],80"ge": [False, True, True],81}82),83)848586def test_comparison_expr_expr() -> None:87df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])})8889assert_frame_equal(90df.select(91(pl.col("a") == pl.col("b")).alias("eq"), # False, False, True92(pl.col("a") != pl.col("b")).alias("ne"), # True, True, False93(pl.col("a") < pl.col("b")).alias("lt"), # True, False, False94(pl.col("a") <= pl.col("b")).alias("le"), # True, False, True95(pl.col("a") > pl.col("b")).alias("gt"), # False, True, False96(pl.col("a") >= pl.col("b")).alias("ge"), # False, True, True97),98pl.DataFrame(99{100"eq": [False, False, True],101"ne": [True, True, False],102"lt": [True, False, False],103"le": [True, False, True],104"gt": [False, True, False],105"ge": [False, True, True],106}107),108)109110111def test_comparison_expr_series() -> None:112df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])})113114assert_frame_equal(115df.select(116(pl.col("a") == df["b"]).alias("eq"), # False, False, True117(pl.col("a") != df["b"]).alias("ne"), # True, True, False118(pl.col("a") < df["b"]).alias("lt"), # True, False, False119(pl.col("a") <= df["b"]).alias("le"), # True, False, True120(pl.col("a") > df["b"]).alias("gt"), # False, True, False121(pl.col("a") >= df["b"]).alias("ge"), # False, True, True122),123pl.DataFrame(124{125"eq": [False, False, True],126"ne": [True, True, False],127"lt": [True, False, False],128"le": [True, False, True],129"gt": [False, True, False],130"ge": [False, True, True],131}132),133)134135136def test_offset_handling_arg_where_7863() -> None:137df_check = pl.DataFrame({"a": [0, 1]})138df_check.select((pl.lit(0).append(pl.col("a")).append(0)) != 0)139assert 
(140df_check.select((pl.lit(0).append(pl.col("a")).append(0)) != 0)141.select(pl.col("literal").arg_true())142.item()143== 2144)145146147def test_missing_equality_on_bools() -> None:148df = pl.DataFrame(149{150"a": [True, None, False],151}152)153154assert df.select(pl.col("a").ne_missing(True))["a"].to_list() == [False, True, True]155assert df.select(pl.col("a").ne_missing(False))["a"].to_list() == [156True,157True,158False,159]160161162def test_struct_equality_18870() -> None:163s = pl.Series([{"a": 1}, None])164165# eq166result = s.eq(s).to_list()167expected = [True, None]168assert result == expected169170# ne171result = s.ne(s).to_list()172expected = [False, None]173assert result == expected174175# eq_missing176result = s.eq_missing(s).to_list()177expected = [True, True]178assert result == expected179180# ne_missing181result = s.ne_missing(s).to_list()182expected = [False, False]183assert result == expected184185186def test_struct_nested_equality() -> None:187df = pl.DataFrame(188{189"a": [{"foo": 0, "bar": "1"}, {"foo": None, "bar": "1"}, None],190"b": [{"foo": 0, "bar": "1"}] * 3,191}192)193194# eq195ans = df.select(pl.col("a").eq(pl.col("b")))196expected = pl.DataFrame({"a": [True, False, None]})197assert_frame_equal(ans, expected)198199# ne200ans = df.select(pl.col("a").ne(pl.col("b")))201expected = pl.DataFrame({"a": [False, True, None]})202assert_frame_equal(ans, expected)203204205def isnan(x: Any) -> bool:206return isinstance(x, float) and math.isnan(x)207208209def reference_ordering_propagating(lhs: Any, rhs: Any) -> str | None:210# normal < nan, nan == nan, nulls propagate211if lhs is None or rhs is None:212return None213214if isnan(lhs) and isnan(rhs):215return "="216217if isnan(lhs) or lhs > rhs:218return ">"219220if isnan(rhs) or lhs < rhs:221return "<"222223return "="224225226def reference_ordering_missing(lhs: Any, rhs: Any) -> str:227# null < normal < nan, nan == nan, null == null228if lhs is None and rhs is None:229return "="230231if lhs is 
None:232return "<"233234if rhs is None:235return ">"236237if isnan(lhs) and isnan(rhs):238return "="239240if isnan(lhs) or lhs > rhs:241return ">"242243if isnan(rhs) or lhs < rhs:244return "<"245246return "="247248249def verify_total_ordering(250lhs: Any, rhs: Any, dummy: Any, ldtype: PolarsDataType, rdtype: PolarsDataType251) -> None:252ref = reference_ordering_propagating(lhs, rhs)253refmiss = reference_ordering_missing(lhs, rhs)254255# Add dummy variable so we don't broadcast or do full-null optimization.256assert dummy is not None257df = pl.DataFrame(258{"l": [lhs, dummy], "r": [rhs, dummy]}, schema={"l": ldtype, "r": rdtype}259)260261ans = df.select(262(pl.col("l") == pl.col("r")).alias("eq"),263(pl.col("l") != pl.col("r")).alias("ne"),264(pl.col("l") < pl.col("r")).alias("lt"),265(pl.col("l") <= pl.col("r")).alias("le"),266(pl.col("l") > pl.col("r")).alias("gt"),267(pl.col("l") >= pl.col("r")).alias("ge"),268pl.col("l").eq_missing(pl.col("r")).alias("eq_missing"),269pl.col("l").ne_missing(pl.col("r")).alias("ne_missing"),270)271272ans_correct_dict = {273"eq": [ref and ref == "="], # "ref and X" propagates ref is None274"ne": [ref and ref != "="],275"lt": [ref and ref == "<"],276"le": [ref and (ref == "<" or ref == "=")],277"gt": [ref and ref == ">"],278"ge": [ref and (ref == ">" or ref == "=")],279"eq_missing": [refmiss == "="],280"ne_missing": [refmiss != "="],281}282ans_correct = pl.DataFrame(283ans_correct_dict, schema=dict.fromkeys(ans_correct_dict, pl.Boolean)284)285286assert_frame_equal(ans[:1], ans_correct)287288289def verify_total_ordering_broadcast(290lhs: Any, rhs: Any, dummy: Any, ldtype: PolarsDataType, rdtype: PolarsDataType291) -> None:292ref = reference_ordering_propagating(lhs, rhs)293refmiss = reference_ordering_missing(lhs, rhs)294295# Add dummy variable so we don't broadcast inherently.296assert dummy is not None297df = pl.DataFrame(298{"l": [lhs, dummy], "r": [rhs, dummy]}, schema={"l": ldtype, "r": rdtype}299)300301ans_first = 
df.select(302(pl.col("l") == pl.col("r").first()).alias("eq"),303(pl.col("l") != pl.col("r").first()).alias("ne"),304(pl.col("l") < pl.col("r").first()).alias("lt"),305(pl.col("l") <= pl.col("r").first()).alias("le"),306(pl.col("l") > pl.col("r").first()).alias("gt"),307(pl.col("l") >= pl.col("r").first()).alias("ge"),308pl.col("l").eq_missing(pl.col("r").first()).alias("eq_missing"),309pl.col("l").ne_missing(pl.col("r").first()).alias("ne_missing"),310)311312ans_scalar = df.select(313(pl.col("l") == rhs).alias("eq"),314(pl.col("l") != rhs).alias("ne"),315(pl.col("l") < rhs).alias("lt"),316(pl.col("l") <= rhs).alias("le"),317(pl.col("l") > rhs).alias("gt"),318(pl.col("l") >= rhs).alias("ge"),319(pl.col("l").eq_missing(rhs)).alias("eq_missing"),320(pl.col("l").ne_missing(rhs)).alias("ne_missing"),321)322323ans_correct_dict = {324"eq": [ref and ref == "="], # "ref and X" propagates ref is None325"ne": [ref and ref != "="],326"lt": [ref and ref == "<"],327"le": [ref and (ref == "<" or ref == "=")],328"gt": [ref and ref == ">"],329"ge": [ref and (ref == ">" or ref == "=")],330"eq_missing": [refmiss == "="],331"ne_missing": [refmiss != "="],332}333ans_correct = pl.DataFrame(334ans_correct_dict, schema=dict.fromkeys(ans_correct_dict, pl.Boolean)335)336337assert_frame_equal(ans_first[:1], ans_correct)338assert_frame_equal(ans_scalar[:1], ans_correct)339340341INTERESTING_FLOAT_VALUES = [3420.0,343-0.0,344-1.0,3451.0,346-float("nan"),347float("nan"),348-float("inf"),349float("inf"),350None,351]352353354@pytest.mark.slow355@pytest.mark.parametrize("lhs", INTERESTING_FLOAT_VALUES)356@pytest.mark.parametrize("rhs", INTERESTING_FLOAT_VALUES)357def test_total_ordering_float_series(lhs: float | None, rhs: float | None) -> None:358verify_total_ordering(lhs, rhs, 0.0, pl.Float32, pl.Float32)359verify_total_ordering(lhs, rhs, 0.0, pl.Float64, pl.Float32)360context: pytest.WarningsRecorder | ContextManager[None] = (361pytest.warns(362UserWarning,363match=r"Consider using 
`\.is_null\(\)` or `\.is_not_null\(\)`",364)365if rhs is None366else nullcontext()367)368with context:369verify_total_ordering_broadcast(lhs, rhs, 0.0, pl.Float32, pl.Float32)370verify_total_ordering_broadcast(lhs, rhs, 0.0, pl.Float64, pl.Float32)371372373INTERESTING_STRING_VALUES = [374"",375"foo",376"bar",377"fooo",378"fooooooooooo",379"foooooooooooo",380"fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooom",381"foooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo",382"fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooop",383None,384]385386387@pytest.mark.slow388@pytest.mark.parametrize("lhs", INTERESTING_STRING_VALUES)389@pytest.mark.parametrize("rhs", INTERESTING_STRING_VALUES)390def test_total_ordering_string_series(lhs: str | None, rhs: str | None) -> None:391verify_total_ordering(lhs, rhs, "", pl.String, pl.String)392context: pytest.WarningsRecorder | ContextManager[None] = (393pytest.warns(394UserWarning,395match=r"Consider using `\.is_null\(\)` or `\.is_not_null\(\)`",396)397if rhs is None398else nullcontext()399)400with context:401verify_total_ordering_broadcast(lhs, rhs, "", pl.String, pl.String)402403404@pytest.mark.slow405@pytest.mark.parametrize("lhs", INTERESTING_STRING_VALUES)406@pytest.mark.parametrize("rhs", INTERESTING_STRING_VALUES)407@pytest.mark.parametrize("fresh_cat", [False, True])408def test_total_ordering_cat_series(409lhs: str | None, rhs: str | None, fresh_cat: bool410) -> None:411if fresh_cat:412c = [pl.Categorical(pl.Categories.random()) for _ in range(6)]413else:414c = [pl.Categorical() for _ in range(6)]415verify_total_ordering(lhs, rhs, "", c[0], c[0])416verify_total_ordering(lhs, rhs, "", pl.String, c[1])417verify_total_ordering(lhs, rhs, "", c[2], pl.String)418context: pytest.WarningsRecorder | ContextManager[None] = (419pytest.warns(420UserWarning,421match=r"Consider using `\.is_null\(\)` or `\.is_not_null\(\)`",422)423if rhs is None424else nullcontext()425)426with 
context:427verify_total_ordering_broadcast(lhs, rhs, "", c[3], c[3])428verify_total_ordering_broadcast(lhs, rhs, "", pl.String, c[4])429verify_total_ordering_broadcast(lhs, rhs, "", c[5], pl.String)430431432@pytest.mark.slow433@pytest.mark.parametrize("str_lhs", INTERESTING_STRING_VALUES)434@pytest.mark.parametrize("str_rhs", INTERESTING_STRING_VALUES)435def test_total_ordering_binary_series(str_lhs: str | None, str_rhs: str | None) -> None:436lhs = None if str_lhs is None else str_lhs.encode("utf-8")437rhs = None if str_rhs is None else str_rhs.encode("utf-8")438verify_total_ordering(lhs, rhs, b"", pl.Binary, pl.Binary)439context: pytest.WarningsRecorder | ContextManager[None] = (440pytest.warns(441UserWarning,442match=r"Consider using `\.is_null\(\)` or `\.is_not_null\(\)`",443)444if rhs is None445else nullcontext()446)447with context:448verify_total_ordering_broadcast(lhs, rhs, b"", pl.Binary, pl.Binary)449450451@pytest.mark.parametrize("lhs", [None, False, True])452@pytest.mark.parametrize("rhs", [None, False, True])453def test_total_ordering_bool_series(lhs: bool | None, rhs: bool | None) -> None:454verify_total_ordering(lhs, rhs, False, pl.Boolean, pl.Boolean)455context: pytest.WarningsRecorder | ContextManager[None] = (456pytest.warns(457UserWarning,458match=r"Consider using `\.is_null\(\)` or `\.is_not_null\(\)`",459)460if rhs is None461else nullcontext()462)463with context:464verify_total_ordering_broadcast(lhs, rhs, False, pl.Boolean, pl.Boolean)465466467def test_cat_compare_with_bool() -> None:468data = pl.DataFrame([pl.Series("col1", ["a", "b"], dtype=pl.Categorical)])469470with pytest.raises(ComputeError, match="cannot compare categorical with bool"):471data.filter(pl.col("col1") == True) # noqa: E712472473474def test_schema_ne_missing_9256() -> None:475df = pl.DataFrame({"a": [0, 1, None], "b": [True, False, True]})476477assert df.select(pl.col("a").ne_missing(0).or_(pl.col("b")))["a"].all()478479480def test_nested_binary_literal_super_type_12227() -> 
None:481# The `.alias` is important here to trigger the bug.482result = pl.select(x=1).select((pl.lit(0) + ((pl.col("x") > 0) * 0.1)).alias("x"))483assert result.item() == 0.1484485result = pl.select((pl.lit(0) + (pl.lit(0) == pl.lit(0)) * pl.lit(0.1)) + pl.lit(0))486assert result.item() == 0.1487488489def test_struct_broadcasting_comparison() -> None:490df = pl.DataFrame({"foo": [{"a": 1}, {"a": 2}, {"a": 1}]})491assert df.select(eq=pl.col.foo == pl.col.foo.last()).to_dict(as_series=False) == {492"eq": [True, False, True]493}494495496@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 1)])497def test_compare_list_broadcast_empty_first_chunk_20165(dtype: pl.DataType) -> None:498s = pl.concat(2 * [pl.Series([[1]], dtype=dtype)]).filter([False, True])499500assert s.len() == 1501assert s.n_chunks() == 2502503assert_series_equal(504pl.select(pl.lit(pl.Series([[1], [2]]), dtype=dtype) == pl.lit(s)).to_series(),505pl.Series([True, False]),506)507508509def test_date_duration_comparison_error_25517() -> None:510date = pl.Series("date", [1], pl.Date)511duration = pl.Series("duration", [1], pl.Duration("ns"))512513with pytest.raises(ComputeError, match="cannot compare date with duration"):514_ = date > duration515516with pytest.raises(ComputeError, match="cannot compare date with duration"):517_ = duration > date518519with pytest.raises(ComputeError, match="cannot compare date with duration"):520_ = date == duration521522523