# Path: blob/main/py-polars/tests/unit/testing/test_assert_frame_equal.py
# 8416 views
from __future__ import annotations

import math
from typing import Any

import pytest
from hypothesis import given

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_frame_not_equal
from polars.testing.parametric import dataframes

nan = float("nan")
# Enables the pytester/testdir fixtures used by `test_tracebackhide`.
pytest_plugins = ["pytester"]


@given(df=dataframes())
def test_equal(df: pl.DataFrame) -> None:
    """A generated frame must compare exactly equal to a clone of itself."""
    assert_frame_equal(df, df.clone(), check_exact=True)


@pytest.mark.parametrize(
    ("df1", "df2", "kwargs"),
    [
        pytest.param(
            pl.DataFrame({"a": [0.2, 0.3]}),
            pl.DataFrame({"a": [0.2, 0.3]}),
            {"abs_tol": 1e-15},
            id="equal_floats_low_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [0.2, 0.3]}),
            pl.DataFrame({"a": [0.2, 0.3000000000000001]}),
            {"abs_tol": 1e-15},
            id="approx_equal_float_low_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [0.2, 0.3]}),
            pl.DataFrame({"a": [0.2, 0.31]}),
            {"abs_tol": 0.1},
            id="approx_equal_float_high_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [0.2, 1.3]}),
            pl.DataFrame({"a": [0.2, 0.9]}),
            {"abs_tol": 1},
            id="approx_equal_float_integer_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [0.0, 1.0, 2.0]}, schema={"a": pl.Float64}),
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}),
            {"check_dtypes": False},
            id="equal_int_float_integer_no_check_dtype",
        ),
        pytest.param(
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float64}),
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float32}),
            {"check_dtypes": False},
            # Distinct id: this case compares Float64 with Float32, and ids
            # should be unique so a failure points at exactly one case.
            id="equal_float64_float32_no_check_dtype",
        ),
        pytest.param(
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}),
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}),
            {},
            id="equal_int",
        ),
        pytest.param(
            pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.String}),
            pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.String}),
            {},
            id="equal_str",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.300001]]}),
            {"abs_tol": 1e-5},
            id="list_of_float_low_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.31]]}),
            {"abs_tol": 0.1},
            id="list_of_float_high_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 1.3]]}),
            pl.DataFrame({"a": [[0.2, 0.9]]}),
            {"abs_tol": 1},
            id="list_of_float_integer_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.300000001]]}),
            {"rel_tol": 1e-5},
            id="list_of_float_low_rel_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.301]]}),
            {"rel_tol": 0.1},
            id="list_of_float_high_rel_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 1.3]]}),
            pl.DataFrame({"a": [[0.2, 0.9]]}),
            {"rel_tol": 1},
            id="list_of_float_integer_rel_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[None, 1.3]]}),
            pl.DataFrame({"a": [[None, 0.9]]}),
            {"rel_tol": 1},
            id="list_of_none_and_float_integer_rel_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[[0.2, 3.0]]]}),
            pl.DataFrame({"a": [[[0.2, 3.00000001]]]}),
            {"abs_tol": 0.1},
            id="nested_list_of_float_abs_tol_high",
        ),
    ],
)
def test_assert_frame_equal_passes_assertion(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    kwargs: dict[str, Any],
) -> None:
    """Each pair must pass `assert_frame_equal` and fail `assert_frame_not_equal`."""
    assert_frame_equal(df1, df2, **kwargs)
    with pytest.raises(AssertionError):
        assert_frame_not_equal(df1, df2, **kwargs)


@pytest.mark.parametrize(
    ("df1", "df2", "kwargs"),
    [
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.3, 0.4]]}),
            {},
            id="list_of_float_different_lengths",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.3000000000000001]]}),
            {"check_exact": True},
            id="list_of_float_check_exact",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.300001]]}),
            {"abs_tol": 1e-15, "rel_tol": 0},
            id="list_of_float_too_low_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.30000001]]}),
            {"abs_tol": -1, "rel_tol": 0},
            id="list_of_float_negative_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[2.0, 3.0]]}),
            pl.DataFrame({"a": [[2, 3]]}),
            {"check_exact": False, "check_dtypes": True},
            id="list_of_float_list_of_int_check_dtype_true",
        ),
        pytest.param(
            pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}),
            pl.DataFrame({"a": [[[0.2, math.nan, 3.11]]]}),
            {"abs_tol": 0.1, "rel_tol": 0},
            id="nested_list_of_float_and_nan_abs_tol_high",
        ),
        pytest.param(
            pl.DataFrame({"a": [[[[0.2, 3.0]]]]}),
            pl.DataFrame({"a": [[[[0.2, 3.11]]]]}),
            {"abs_tol": 0.1, "rel_tol": 0},
            id="double_nested_list_of_float_abs_tol_high",
        ),
        pytest.param(
            pl.DataFrame({"a": [[[[[0.2, 3.0]]]]]}),
            pl.DataFrame({"a": [[[[[0.2, 3.11]]]]]}),
            {"abs_tol": 0.1, "rel_tol": 0},
            id="triple_nested_list_of_float_abs_tol_high",
        ),
    ],
)
def test_assert_frame_equal_raises_assertion_error(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    kwargs: dict[str, Any],
) -> None:
    """Each pair must fail `assert_frame_equal` and pass `assert_frame_not_equal`."""
    with pytest.raises(AssertionError):
        assert_frame_equal(df1, df2, **kwargs)
    # Deliberately outside the `raises` block so that it is actually executed.
    assert_frame_not_equal(df1, df2, **kwargs)


def test_compare_frame_equal_nans() -> None:
    """NaNs in the same positions compare equal; NaN versus null does not."""
    df1 = pl.DataFrame(
        data={"x": [1.0, nan], "y": [nan, 2.0]},
        schema=[("x", pl.Float32), ("y", pl.Float64)],
    )
    assert_frame_equal(df1, df1, check_exact=True)

    # Same as df1 except that the NaN in "y" is replaced by a null.
    df2 = pl.DataFrame(
        data={"x": [1.0, nan], "y": [None, 2.0]},
        schema=[("x", pl.Float32), ("y", pl.Float64)],
    )
    assert_frame_not_equal(df1, df2)
    with pytest.raises(AssertionError, match='value mismatch for column "y"'):
        assert_frame_equal(df1, df2, check_exact=True)


def test_compare_frame_equal_nested_nans() -> None:
    """Repeat the NaN/null comparison for values nested in lists and structs."""
    # list dtype
    df1 = pl.DataFrame(
        data={"x": [[1.0, nan]], "y": [[nan, 2.0]]},
        schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))],
    )
    assert_frame_equal(df1, df1, check_exact=True)

    df2 = pl.DataFrame(
        data={"x": [[1.0, nan]], "y": [[None, 2.0]]},
        schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))],
    )
    assert_frame_not_equal(df1, df2)
    with pytest.raises(AssertionError, match='value mismatch for column "y"'):
        assert_frame_equal(df1, df2, check_exact=True)

    # struct dtype
    df3 = pl.from_dicts(
        [
            {
                "id": 1,
                "struct": [
                    {"x": "text", "y": [0.0, nan]},
                    {"x": "text", "y": [0.0, nan]},
                ],
            },
            {
                "id": 2,
                "struct": [
                    {"x": "text", "y": [1]},
                    {"x": "text", "y": [1]},
                ],
            },
        ]
    )
    # df4 adds a "z" field to every struct and changes the "y" values of id 2.
    df4 = pl.from_dicts(
        [
            {
                "id": 1,
                "struct": [
                    {"x": "text", "y": [0.0, nan], "z": ["$"]},
                    {"x": "text", "y": [0.0, nan], "z": ["$"]},
                ],
            },
            {
                "id": 2,
                "struct": [
                    {"x": "text", "y": [nan, 1.0], "z": ["!"]},
                    {"x": "text", "y": [nan, 1.0], "z": ["?"]},
                ],
            },
        ]
    )

    assert_frame_equal(df3, df3)
    assert_frame_equal(df4, df4)

    assert_frame_not_equal(df3, df4)
    # The frames must be reported as different with and without dtype checks.
    for check_dtype in (True, False):
        with pytest.raises(AssertionError, match=r"mismatch|different"):
            assert_frame_equal(df3, df4, check_dtypes=check_dtype)


def test_assert_frame_equal_pass() -> None:
    """Two separately built frames with identical content compare equal."""
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = pl.DataFrame({"a": [1, 2]})
    assert_frame_equal(df1, df2)


@pytest.mark.parametrize(
    "assert_function",
    [assert_frame_equal, assert_frame_not_equal],
)
def test_assert_frame_equal_types(assert_function: Any) -> None:
    """Comparing a DataFrame with a Series is rejected by both assertions."""
    df1 = pl.DataFrame({"a": [1, 2]})
    srs1 = pl.Series(values=[1, 2], name="a")
    with pytest.raises(
        AssertionError, match=r"inputs are different \(unexpected input types\)"
    ):
        assert_function(df1, srs1)


def test_assert_frame_equal_length_mismatch() -> None:
    """Frames of different height fail with a row-count mismatch message."""
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = pl.DataFrame({"a": [1, 2, 3]})
    with pytest.raises(
        AssertionError,
        match=r"DataFrames are different \(height \(row count\) mismatch\)",
    ):
        assert_frame_equal(df1, df2)
    assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_column_mismatch() -> None:
    """A column present only in the left frame is named in the error message."""
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = pl.DataFrame({"b": [1, 2]})
    with pytest.raises(
        AssertionError,
        match=r'DataFrames are different \(columns mismatch: \["a"\] in left, but not in right\)',
    ):
        assert_frame_equal(df1, df2)
    assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_column_mismatch2() -> None:
    """Extra columns in the right LazyFrame are reported as a columns mismatch."""
    df1 = pl.LazyFrame({"a": [1, 2]})
    df2 = pl.LazyFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
    with pytest.raises(
        AssertionError,
        match=r"columns mismatch.*in right.*but not in left",
    ):
        assert_frame_equal(df1, df2)
    assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_column_mismatch_order() -> None:
    """Column order matters unless `check_column_order=False` is passed."""
    df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]})
    df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
    with pytest.raises(AssertionError, match="columns are not in the same order"):
        assert_frame_equal(df1, df2)

    assert_frame_equal(df1, df2, check_column_order=False)
    assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_check_row_order() -> None:
    """Row order matters unless `check_row_order=False` is passed."""
    df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
    df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]})

    with pytest.raises(AssertionError, match='value mismatch for column "a"'):
        assert_frame_equal(df1, df2)

    assert_frame_equal(df1, df2, check_row_order=False)
    assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_check_row_col_order() -> None:
    """Row-order and column-order checks can be relaxed independently."""
    df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
    df2 = pl.DataFrame({"b": [3, 4], "a": [2, 1]})

    # Relaxing only the row order still leaves the column order mismatch.
    with pytest.raises(AssertionError, match="columns are not in the same order"):
        assert_frame_equal(df1, df2, check_row_order=False)

    assert_frame_equal(df1, df2, check_row_order=False, check_column_order=False)
    assert_frame_not_equal(df1, df2)


@pytest.mark.parametrize(
    "assert_function",
    [assert_frame_equal, assert_frame_not_equal],
)
def test_assert_frame_equal_check_row_order_unsortable(assert_function: Any) -> None:
    """`check_row_order=False` with an object column raises InvalidOperationError."""
    df1 = pl.DataFrame({"a": [object(), object()], "b": [3, 4]})
    df2 = pl.DataFrame({"a": [object(), object()], "b": [4, 3]})
    with pytest.raises(
        InvalidOperationError,
        match="`arg_sort_multiple` operation not supported for dtype `object`",
    ):
        assert_function(df1, df2, check_row_order=False)


def test_assert_frame_equal_dtypes_mismatch() -> None:
    """Same values with different column dtypes fail with a dtype message."""
    data = {"a": [1, 2], "b": [3, 4]}
    df1 = pl.DataFrame(data, schema={"a": pl.Int8, "b": pl.Int16})
    df2 = pl.DataFrame(data, schema={"b": pl.Int16, "a": pl.Int16})

    with pytest.raises(AssertionError, match="dtypes do not match"):
        assert_frame_equal(df1, df2, check_column_order=False)

    assert_frame_not_equal(df1, df2, check_column_order=False)
    assert_frame_not_equal(df1, df2)


def test_assert_frame_not_equal() -> None:
    """`assert_frame_not_equal` names the frame kind when the inputs are equal."""
    df = pl.DataFrame({"a": [1, 2]})
    with pytest.raises(AssertionError, match="DataFrames are equal"):
        assert_frame_not_equal(df, df)
    lf = df.lazy()
    with pytest.raises(AssertionError, match="LazyFrames are equal"):
        assert_frame_not_equal(lf, lf)


def test_assert_frame_equal_check_dtype_deprecated() -> None:
    """The legacy `check_dtype` keyword is accepted but emits a deprecation warning."""
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = pl.DataFrame({"a": [1.0, 2.0]})
    df3 = pl.DataFrame({"a": [2, 1]})

    with pytest.deprecated_call():
        assert_frame_equal(df1, df2, check_dtype=False)  # type: ignore[call-arg]

    with pytest.deprecated_call():
        assert_frame_not_equal(df1, df3, check_dtype=False)  # type: ignore[call-arg]


def test_assert_dataframe_equal_all_nulls_passes_when_ignoring_dtypes() -> None:
    """All-null columns of different dtypes compare equal with `check_dtypes=False`."""
    x = pl.from_dict({"A": [None, None, None]})
    y = pl.from_dict(
        {"A": [None, None, None]}, schema_overrides={"A": pl.List(pl.Float64())}
    )

    assert_frame_equal(x, y, check_dtypes=False)


def test_assert_dataframe_equal_all_nulls_fails_when_checking_dtypes() -> None:
    """All-null columns of different dtypes fail when dtypes are checked."""
    x = pl.from_dict({"A": [None, None, None]})
    y = pl.from_dict(
        {"A": [None, None, None]}, schema_overrides={"A": pl.List(pl.Float64())}
    )

    with pytest.raises(AssertionError, match="dtypes do not match"):
        assert_frame_equal(x, y, check_dtypes=True)


def test_tracebackhide(testdir: pytest.Testdir) -> None:
    """Check that polars.testing internals do not appear in failure tracebacks.

    Four deliberately failing tests are written to a temporary module and run
    in a nested pytest session; the captured output is then inspected.
    """
    testdir.makefile(
        ".py",
        test_path="""\
import polars as pl
from polars.testing import assert_frame_equal, assert_frame_not_equal

def test_frame_equal_fail():
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = pl.DataFrame({"a": [1, 3]})
    assert_frame_equal(df1, df2)

def test_frame_not_equal_fail():
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = pl.DataFrame({"a": [1, 2]})
    assert_frame_not_equal(df1, df2)

def test_frame_data_type_fail():
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = {"a": [1, 2]}
    assert_frame_equal(df1, df2)

def test_frame_schema_fail():
    df1 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int64})
    df2 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int32})
    assert_frame_equal(df1, df2)
""",
    )
    result = testdir.runpytest()
    result.assert_outcomes(passed=0, failed=4)
    stdout = "\n".join(result.outlines)

    assert "polars/py-polars/polars/testing" not in stdout

    # The above should catch any polars testing functions that appear in the
    # stack trace. But we keep the following checks (for specific function
    # names) just to double-check.

    assert "def assert_frame_equal" not in stdout
    assert "def assert_frame_not_equal" not in stdout
    assert "def _assert_correct_input_type" not in stdout

    assert "def assert_series_equal" not in stdout
    assert "def assert_series_not_equal" not in stdout

    # Make sure the tests are failing for the expected reason (e.g. not because
    # an import is missing or something like that):

    assert (
        'AssertionError: DataFrames are different (value mismatch for column "a")'
        in stdout
    )
    assert "AssertionError: DataFrames are equal" in stdout
    assert "AssertionError: inputs are different (unexpected input types)" in stdout
    assert "AssertionError: DataFrames are different (dtypes do not match)" in stdout