# Path: blob/main/py-polars/tests/unit/series/test_equals.py
# 6939 views
"""Tests for `Series.equals` and the Series equality comparison operators."""

from datetime import datetime
from typing import Callable

import pytest

import polars as pl
from polars.testing import assert_series_equal


def test_equals() -> None:
    """Check `Series.equals` defaults and its type validation of `other`.

    By default dtypes are not compared and nulls compare equal; passing
    `check_dtypes=True` or `null_equal=False` makes the comparison stricter.
    """
    s1 = pl.Series("a", [1.0, 2.0, None], pl.Float64)
    s2 = pl.Series("a", [1, 2, None], pl.Int64)

    assert s1.equals(s2) is True
    assert s1.equals(s2, check_dtypes=True) is False
    assert s1.equals(s2, null_equal=False) is False

    df = pl.DataFrame(
        {"dtm": [datetime(2222, 2, 22, 22, 22, 22)]},
        schema_overrides={"dtm": pl.Datetime(time_zone="UTC")},
    ).with_columns(
        s3=pl.col("dtm").dt.convert_time_zone("Europe/London"),
        s4=pl.col("dtm").dt.convert_time_zone("Asia/Tokyo"),
    )
    s3 = df["s3"].rename("b")
    s4 = df["s4"].rename("b")

    # The same instant in two different time zones is not considered equal,
    # even without dtype checking; converting to a common zone fixes that.
    assert s3.equals(s4) is False
    assert s3.equals(s4, check_dtypes=True) is False
    assert s3.equals(s4, null_equal=False) is False
    assert s3.dt.convert_time_zone("Asia/Tokyo").equals(s4) is True

    with pytest.raises(
        TypeError,
        match="expected `other` to be a 'Series'.* not 'DataFrame'",
    ):
        s1.equals(pl.DataFrame(s2), check_names=False)  # type: ignore[arg-type]

    with pytest.raises(
        TypeError,
        match="expected `other` to be a 'Series'.* not 'LazyFrame'",
    ):
        s1.equals(pl.DataFrame(s2).lazy(), check_names=False)  # type: ignore[arg-type]

    s5 = pl.Series("a", [1, 2, 3])

    class DummySeriesSubclass(pl.Series):
        pass

    # Subclasses of Series are accepted as `other`.
    assert s5.equals(DummySeriesSubclass(s5)) is True


def test_series_equals_check_names() -> None:
    """Names are ignored by `Series.equals` unless `check_names=True`."""
    s1 = pl.Series("foo", [1, 2, 3])
    s2 = pl.Series("bar", [1, 2, 3])
    assert s1.equals(s2) is True
    assert s1.equals(s2, check_names=True) is False


def test_eq_list_cmp_list() -> None:
    """A List Series compared with a Python list treats the list as one value."""
    s = pl.Series([[1], [1, 2]])
    result = s == [1, 2]
    expected = pl.Series([False, True])
    assert_series_equal(result, expected)


def test_eq_list_cmp_int() -> None:
    """Comparing a List Series with an int scalar is not supported."""
    s = pl.Series([[1], [1, 2]])
    with pytest.raises(
        NotImplementedError,
        match=r"Series of type List\(Int64\) does not have eq operator",
    ):
        s == 1  # noqa: B015


def test_eq_array_cmp_list() -> None:
    """An Array Series compared with a Python list treats the list as one value."""
    s = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2))
    result = s == [1, 2]
    expected = pl.Series([False, True])
    assert_series_equal(result, expected)


def test_eq_array_cmp_int() -> None:
    """Comparing an Array Series with an int scalar is not supported."""
    s = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2))
    with pytest.raises(
        NotImplementedError,
        match=r"Series of type Array\(Int16, shape=\(2,\)\) does not have eq operator",
    ):
        s == 1  # noqa: B015


def test_eq_list() -> None:
    """A numeric Series compares elementwise with a list and broadcasts a scalar."""
    s = pl.Series([1, 1])

    result = s == [1, 2]
    expected = pl.Series([True, False])
    assert_series_equal(result, expected)

    result = s == 1
    expected = pl.Series([True, True])
    assert_series_equal(result, expected)


def test_eq_missing_expr() -> None:
    """`Series.eq_missing` with an Expr operand returns an Expr."""
    s = pl.Series([1, None])
    result = s.eq_missing(pl.lit(1))

    assert isinstance(result, pl.Expr)
    result_evaluated = pl.select(result).to_series()
    # The null compares as False (not null) under `eq_missing`.
    expected = pl.Series([True, False])
    assert_series_equal(result_evaluated, expected)


def test_ne_missing_expr() -> None:
    """`Series.ne_missing` with an Expr operand returns an Expr."""
    s = pl.Series([1, None])
    result = s.ne_missing(pl.lit(1))

    assert isinstance(result, pl.Expr)
    result_evaluated = pl.select(result).to_series()
    # The null compares as True (not null) under `ne_missing`.
    expected = pl.Series([False, True])
    assert_series_equal(result_evaluated, expected)


def test_series_equals_strict_deprecated() -> None:
    """The `strict` keyword of `Series.equals` emits a deprecation warning.

    Here it yields False for Series that differ only in dtype.
    """
    s1 = pl.Series("a", [1.0, 2.0, None], pl.Float64)
    s2 = pl.Series("a", [1, 2, None], pl.Int64)
    with pytest.deprecated_call():
        assert not s1.equals(s2, strict=True)  # type: ignore[call-arg]


@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq", "cmp_ne"),
    [
        # We parametrize the comparison sides as the impl looks like this:
        # match (left.len(), right.len()) {
        #     (1, _) => ...,
        #     (_, 1) => ...,
        #     (_, _) => ...,
        # }
        (pl.Series.eq, pl.Series.ne),
        (
            lambda a, b: pl.Series.eq(b, a),
            lambda a, b: pl.Series.ne(b, a),
        ),
    ],
)
def test_eq_lists_arrays(
    dtype: pl.DataType,
    cmp_eq: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """Check `eq`/`ne` on List and Array Series propagate nulls.

    Both broadcast (length-1) and full-length operands are covered, on either
    side of the comparison.
    """
    # Broadcast NULL
    assert_series_equal(
        cmp_eq(
            pl.Series([None], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([None, None, None], dtype=pl.Boolean),
    )

    assert_series_equal(
        cmp_ne(
            pl.Series([None], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([None, None, None], dtype=pl.Boolean),
    )

    # Non-broadcast full-NULL
    assert_series_equal(
        cmp_eq(
            pl.Series(3 * [None], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([None, None, None], dtype=pl.Boolean),
    )

    assert_series_equal(
        cmp_ne(
            pl.Series(3 * [None], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([None, None, None], dtype=pl.Boolean),
    )

    # Broadcast valid
    assert_series_equal(
        cmp_eq(
            pl.Series([[1, None]], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([None, True, False], dtype=pl.Boolean),
    )

    assert_series_equal(
        cmp_ne(
            pl.Series([[1, None]], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([None, False, True], dtype=pl.Boolean),
    )

    # Non-broadcast mixed
    assert_series_equal(
        cmp_eq(
            pl.Series([None, [1, 1], [1, 1]], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([None, False, True], dtype=pl.Boolean),
    )

    assert_series_equal(
        cmp_ne(
            pl.Series([None, [1, 1], [1, 1]], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([None, True, False], dtype=pl.Boolean),
    )


@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq_missing", "cmp_ne_missing"),
    [
        (pl.Series.eq_missing, pl.Series.ne_missing),
        (
            lambda a, b: pl.Series.eq_missing(b, a),
            lambda a, b: pl.Series.ne_missing(b, a),
        ),
    ],
)
def test_eq_missing_lists_arrays_19153(
    dtype: pl.DataType,
    cmp_eq_missing: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne_missing: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """Check `eq_missing`/`ne_missing` on List and Array Series.

    Nulls are compared as ordinary values, so results never contain nulls,
    for both broadcast and full-length operands on either side.
    """

    def assert_series_equal_no_nulls(left: pl.Series, right: pl.Series) -> None:
        """Assert equality via `assert_series_equal` plus null-free list checks."""
        # `assert_series_equal` also uses `ne_missing` underneath so we have
        # some extra checks here to be sure.
        assert_series_equal(left, right)
        assert left.to_list() == right.to_list()
        assert left.null_count() == 0
        assert right.null_count() == 0

    # Broadcast NULL
    assert_series_equal_no_nulls(
        cmp_eq_missing(
            pl.Series([None], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([True, False, False]),
    )

    assert_series_equal_no_nulls(
        cmp_ne_missing(
            pl.Series([None], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([False, True, True]),
    )

    # Non-broadcast full-NULL
    assert_series_equal_no_nulls(
        cmp_eq_missing(
            pl.Series(3 * [None], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([True, False, False]),
    )

    assert_series_equal_no_nulls(
        cmp_ne_missing(
            pl.Series(3 * [None], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([False, True, True]),
    )

    # Broadcast valid
    assert_series_equal_no_nulls(
        cmp_eq_missing(
            pl.Series([[1, None]], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([False, True, False]),
    )

    assert_series_equal_no_nulls(
        cmp_ne_missing(
            pl.Series([[1, None]], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([True, False, True]),
    )

    # Non-broadcast mixed
    assert_series_equal_no_nulls(
        cmp_eq_missing(
            pl.Series([None, [1, 1], [1, 1]], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([True, False, True]),
    )

    assert_series_equal_no_nulls(
        cmp_ne_missing(
            pl.Series([None, [1, 1], [1, 1]], dtype=dtype),
            pl.Series([None, [1, None], [1, 1]], dtype=dtype),
        ),
        pl.Series([False, True, False]),
    )


def test_equals_nested_null_categorical_14875() -> None:
    """A list-of-struct Series with a null Categorical field equals itself."""
    dtype = pl.List(pl.Struct({"cat": pl.Categorical}))
    s = pl.Series([[{"cat": None}]], dtype=dtype)
    assert s.equals(s)