# Path: blob/main/py-polars/tests/unit/series/test_equals.py
# 8430 views
"""Tests for ``Series.equals`` and Series equality comparisons.

Covers ``equals`` (dtype / name / null handling and rejection of non-Series
input) and ``==`` / ``eq`` / ``ne`` / ``eq_missing`` / ``ne_missing`` on flat,
``List`` and ``Array`` series, with and without length-1 broadcasting.
"""

from collections.abc import Callable
from datetime import datetime

import pytest

import polars as pl
from polars.testing import assert_series_equal


def test_equals() -> None:
    """``equals`` ignores dtype differences by default and rejects non-Series."""
    # The same values stored once as Float64 and once as Int64.
    floats = pl.Series("a", [1.0, 2.0, None], pl.Float64)
    ints = pl.Series("a", [1, 2, None], pl.Int64)

    assert floats.equals(ints) is True
    assert floats.equals(ints, check_dtypes=True) is False
    # With `null_equal=False` the trailing nulls no longer match each other.
    assert floats.equals(ints, null_equal=False) is False

    # A single UTC instant converted into two different time zones.
    zoned = pl.DataFrame(
        {"dtm": [datetime(2222, 2, 22, 22, 22, 22)]},
        schema_overrides={"dtm": pl.Datetime(time_zone="UTC")},
    ).with_columns(
        s3=pl.col("dtm").dt.convert_time_zone("Europe/London"),
        s4=pl.col("dtm").dt.convert_time_zone("Asia/Tokyo"),
    )
    london = zoned["s3"].rename("b")
    tokyo = zoned["s4"].rename("b")

    assert london.equals(tokyo) is False
    assert london.equals(tokyo, check_dtypes=True) is False
    assert london.equals(tokyo, null_equal=False) is False
    # Bringing both sides into the same time zone makes them compare equal.
    assert london.dt.convert_time_zone("Asia/Tokyo").equals(tokyo) is True

    # DataFrame / LazyFrame arguments are rejected with a TypeError...
    with pytest.raises(
        TypeError,
        match=r"expected `other` to be a 'Series'.* not 'DataFrame'",
    ):
        floats.equals(pl.DataFrame(ints), check_names=False)  # type: ignore[arg-type]

    with pytest.raises(
        TypeError,
        match=r"expected `other` to be a 'Series'.* not 'LazyFrame'",
    ):
        floats.equals(pl.DataFrame(ints).lazy(), check_names=False)  # type: ignore[arg-type]

    # ...whereas an instance of a Series subclass is accepted.
    class SeriesSubclass(pl.Series):
        pass

    base = pl.Series("a", [1, 2, 3])
    assert base.equals(SeriesSubclass(base)) is True


def test_series_equals_check_names() -> None:
    """Differing names only matter when ``check_names=True``."""
    foo = pl.Series("foo", [1, 2, 3])
    bar = pl.Series("bar", [1, 2, 3])

    assert foo.equals(bar) is True
    assert foo.equals(bar, check_names=True) is False


def test_eq_list_cmp_list() -> None:
    """A List series compared with a Python list matches whole elements."""
    nested = pl.Series([[1], [1, 2]])
    assert_series_equal(nested == [1, 2], pl.Series([False, True]))


def test_eq_list_cmp_int() -> None:
    """Comparing a List series with a scalar int is not implemented."""
    nested = pl.Series([[1], [1, 2]])
    with pytest.raises(
        NotImplementedError,
        match=r"Series of type List\(Int64\) does not have eq operator",
    ):
        nested == 1  # noqa: B015


def test_eq_array_cmp_list() -> None:
    """An Array series compared with a Python list matches whole elements."""
    fixed = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2))
    assert_series_equal(fixed == [1, 2], pl.Series([False, True]))


def test_eq_array_cmp_int() -> None:
    """Comparing an Array series with a scalar int is not implemented."""
    fixed = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2))
    with pytest.raises(
        NotImplementedError,
        match=r"Series of type Array\(Int16, shape=\(2,\)\) does not have eq operator",
    ):
        fixed == 1  # noqa: B015


def test_eq_list() -> None:
    """On a flat series a list compares element-wise and a scalar broadcasts."""
    ones = pl.Series([1, 1])

    assert_series_equal(ones == [1, 2], pl.Series([True, False]))
    assert_series_equal(ones == 1, pl.Series([True, True]))


def test_eq_missing_expr() -> None:
    """``eq_missing`` with an expression operand yields an ``Expr``."""
    lazy_cmp = pl.Series([1, None]).eq_missing(pl.lit(1))

    assert isinstance(lazy_cmp, pl.Expr)
    assert_series_equal(pl.select(lazy_cmp).to_series(), pl.Series([True, False]))


def test_ne_missing_expr() -> None:
    """``ne_missing`` with an expression operand yields an ``Expr``."""
    lazy_cmp = pl.Series([1, None]).ne_missing(pl.lit(1))

    assert isinstance(lazy_cmp, pl.Expr)
    assert_series_equal(pl.select(lazy_cmp).to_series(), pl.Series([False, True]))


def test_series_equals_strict_deprecated() -> None:
    """Passing the ``strict`` keyword to ``equals`` emits a deprecation warning."""
    floats = pl.Series("a", [1.0, 2.0, None], pl.Float64)
    ints = pl.Series("a", [1, 2, None], pl.Int64)

    with pytest.deprecated_call():
        assert not floats.equals(ints, strict=True)  # type: ignore[call-arg]


@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq", "cmp_ne"),
    [
        # Both operand orders are exercised because the implementation
        # dispatches on the operand lengths, roughly:
        # match (left.len(), right.len()) {
        #     (1, _) => ...,
        #     (_, 1) => ...,
        #     (_, _) => ...,
        # }
        (pl.Series.eq, pl.Series.ne),
        (
            lambda a, b: pl.Series.eq(b, a),
            lambda a, b: pl.Series.ne(b, a),
        ),
    ],
)
def test_eq_lists_arrays(
    dtype: pl.DataType,
    cmp_eq: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """``eq`` / ``ne`` on nested series.

    A null element on either side yields a null result, while nulls nested
    inside equal list/array values compare equal (``[1, None] == [1, None]``).
    """
    other = pl.Series([None, [1, None], [1, 1]], dtype=dtype)

    def check(
        left: pl.Series, expected_eq: list[object], expected_ne: list[object]
    ) -> None:
        # Compare `left` against `other` with both operators.
        assert_series_equal(
            cmp_eq(left, other), pl.Series(expected_eq, dtype=pl.Boolean)
        )
        assert_series_equal(
            cmp_ne(left, other), pl.Series(expected_ne, dtype=pl.Boolean)
        )

    # Broadcast NULL
    check(
        pl.Series([None], dtype=dtype),
        [None, None, None],
        [None, None, None],
    )
    # Non-broadcast full-NULL
    check(
        pl.Series(3 * [None], dtype=dtype),
        [None, None, None],
        [None, None, None],
    )
    # Broadcast valid
    check(
        pl.Series([[1, None]], dtype=dtype),
        [None, True, False],
        [None, False, True],
    )
    # Non-broadcast mixed
    check(
        pl.Series([None, [1, 1], [1, 1]], dtype=dtype),
        [None, False, True],
        [None, True, False],
    )


@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq_missing", "cmp_ne_missing"),
    [
        (pl.Series.eq_missing, pl.Series.ne_missing),
        (
            lambda a, b: pl.Series.eq_missing(b, a),
            lambda a, b: pl.Series.ne_missing(b, a),
        ),
    ],
)
def test_eq_missing_lists_arrays_19153(
    dtype: pl.DataType,
    cmp_eq_missing: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne_missing: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """``eq_missing`` / ``ne_missing`` on nested series never produce nulls.

    A null element compares equal to another null element and unequal to any
    valid element, so every result is a plain True/False.
    """
    other = pl.Series([None, [1, None], [1, 1]], dtype=dtype)

    def assert_equal_without_nulls(actual: pl.Series, expected: pl.Series) -> None:
        # `assert_series_equal` itself relies on `ne_missing`, so cross-check
        # via plain Python values and require that neither side holds a null.
        assert_series_equal(actual, expected)
        assert actual.to_list() == expected.to_list()
        assert actual.null_count() == 0
        assert expected.null_count() == 0

    def check(
        left: pl.Series, expected_eq: list[bool], expected_ne: list[bool]
    ) -> None:
        # Compare `left` against `other` with both null-aware operators.
        assert_equal_without_nulls(
            cmp_eq_missing(left, other), pl.Series(expected_eq)
        )
        assert_equal_without_nulls(
            cmp_ne_missing(left, other), pl.Series(expected_ne)
        )

    # Broadcast NULL
    check(
        pl.Series([None], dtype=dtype),
        [True, False, False],
        [False, True, True],
    )
    # Non-broadcast full-NULL
    check(
        pl.Series(3 * [None], dtype=dtype),
        [True, False, False],
        [False, True, True],
    )
    # Broadcast valid
    check(
        pl.Series([[1, None]], dtype=dtype),
        [False, True, False],
        [True, False, True],
    )
    # Non-broadcast mixed
    check(
        pl.Series([None, [1, 1], [1, 1]], dtype=dtype),
        [True, False, True],
        [False, True, False],
    )


def test_equals_nested_null_categorical_14875() -> None:
    """A List(Struct(Categorical)) series holding a null category equals itself."""
    nested = pl.Series(
        [[{"cat": None}]],
        dtype=pl.List(pl.Struct({"cat": pl.Categorical})),
    )
    assert nested.equals(nested)