Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/testing/test_assert_frame_equal.py
8416 views
1
from __future__ import annotations
2
3
import math
4
from typing import Any
5
6
import pytest
7
from hypothesis import given
8
9
import polars as pl
10
from polars.exceptions import InvalidOperationError
11
from polars.testing import assert_frame_equal, assert_frame_not_equal
12
from polars.testing.parametric import dataframes
13
14
nan = float("nan")
15
pytest_plugins = ["pytester"]
16
17
18
@given(df=dataframes())
def test_equal(df: pl.DataFrame) -> None:
    """A generated frame must compare exactly equal to a clone of itself."""
    duplicate = df.clone()
    assert_frame_equal(df, duplicate, check_exact=True)
21
22
23
@pytest.mark.parametrize(
    ("df1", "df2", "kwargs"),
    [
        pytest.param(
            pl.DataFrame({"a": [0.2, 0.3]}),
            pl.DataFrame({"a": [0.2, 0.3]}),
            {"abs_tol": 1e-15},
            id="equal_floats_low_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [0.2, 0.3]}),
            pl.DataFrame({"a": [0.2, 0.3000000000000001]}),
            {"abs_tol": 1e-15},
            id="approx_equal_float_low_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [0.2, 0.3]}),
            pl.DataFrame({"a": [0.2, 0.31]}),
            {"abs_tol": 0.1},
            id="approx_equal_float_high_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [0.2, 1.3]}),
            pl.DataFrame({"a": [0.2, 0.9]}),
            {"abs_tol": 1},
            id="approx_equal_float_integer_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [0.0, 1.0, 2.0]}, schema={"a": pl.Float64}),
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}),
            {"check_dtypes": False},
            id="equal_int_float_integer_no_check_dtype",
        ),
        pytest.param(
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float64}),
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float32}),
            {"check_dtypes": False},
            # This case previously reused the id of the Float64-vs-Int64 case
            # above; it compares two float widths, so it gets its own id.
            id="equal_float64_float32_no_check_dtype",
        ),
        pytest.param(
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}),
            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}),
            {},
            id="equal_int",
        ),
        pytest.param(
            pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.String}),
            pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.String}),
            {},
            id="equal_str",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.300001]]}),
            {"abs_tol": 1e-5},
            id="list_of_float_low_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.31]]}),
            {"abs_tol": 0.1},
            id="list_of_float_high_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 1.3]]}),
            pl.DataFrame({"a": [[0.2, 0.9]]}),
            {"abs_tol": 1},
            id="list_of_float_integer_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.300000001]]}),
            {"rel_tol": 1e-5},
            id="list_of_float_low_rel_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.301]]}),
            {"rel_tol": 0.1},
            id="list_of_float_high_rel_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 1.3]]}),
            pl.DataFrame({"a": [[0.2, 0.9]]}),
            {"rel_tol": 1},
            id="list_of_float_integer_rel_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[None, 1.3]]}),
            pl.DataFrame({"a": [[None, 0.9]]}),
            {"rel_tol": 1},
            id="list_of_none_and_float_integer_rel_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[[0.2, 3.0]]]}),
            pl.DataFrame({"a": [[[0.2, 3.00000001]]]}),
            {"abs_tol": 0.1},
            id="nested_list_of_float_abs_tol_high",
        ),
    ],
)
def test_assert_frame_equal_passes_assertion(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    kwargs: dict[str, Any],
) -> None:
    """Check frame pairs that must be treated as equal under the given options.

    For every case, `assert_frame_equal` has to succeed and the inverse
    assertion, `assert_frame_not_equal`, has to raise `AssertionError` when
    called with the same keyword arguments.
    """
    assert_frame_equal(df1, df2, **kwargs)
    with pytest.raises(AssertionError):
        assert_frame_not_equal(df1, df2, **kwargs)
132
133
134
@pytest.mark.parametrize(
    ("df1", "df2", "kwargs"),
    [
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.3, 0.4]]}),
            {},
            id="list_of_float_different_lengths",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.3000000000000001]]}),
            {"check_exact": True},
            id="list_of_float_check_exact",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.300001]]}),
            {"abs_tol": 1e-15, "rel_tol": 0},
            id="list_of_float_too_low_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[0.2, 0.3]]}),
            pl.DataFrame({"a": [[0.2, 0.30000001]]}),
            {"abs_tol": -1, "rel_tol": 0},
            id="list_of_float_negative_abs_tol",
        ),
        pytest.param(
            pl.DataFrame({"a": [[2.0, 3.0]]}),
            pl.DataFrame({"a": [[2, 3]]}),
            {"check_exact": False, "check_dtypes": True},
            id="list_of_float_list_of_int_check_dtype_true",
        ),
        pytest.param(
            pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}),
            pl.DataFrame({"a": [[[0.2, math.nan, 3.11]]]}),
            {"abs_tol": 0.1, "rel_tol": 0},
            id="nested_list_of_float_and_nan_abs_tol_high",
        ),
        pytest.param(
            pl.DataFrame({"a": [[[[0.2, 3.0]]]]}),
            pl.DataFrame({"a": [[[[0.2, 3.11]]]]}),
            {"abs_tol": 0.1, "rel_tol": 0},
            id="double_nested_list_of_float_abs_tol_high",
        ),
        pytest.param(
            pl.DataFrame({"a": [[[[[0.2, 3.0]]]]]}),
            pl.DataFrame({"a": [[[[[0.2, 3.11]]]]]}),
            {"abs_tol": 0.1, "rel_tol": 0},
            id="triple_nested_list_of_float_abs_tol_high",
        ),
    ],
)
def test_assert_frame_equal_raises_assertion_error(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    kwargs: dict[str, Any],
) -> None:
    """Check frame pairs that must be reported as different.

    `assert_frame_equal` has to raise `AssertionError` for each case, while
    `assert_frame_not_equal` has to pass with the same keyword arguments.
    """
    frames = (df1, df2)
    with pytest.raises(AssertionError):
        assert_frame_equal(*frames, **kwargs)
    # The inverse assertion runs outside the `raises` block so it is executed.
    assert_frame_not_equal(*frames, **kwargs)
195
196
197
def test_compare_frame_equal_nans() -> None:
    """NaN values compare equal to themselves but not to nulls."""
    float_schema = [("x", pl.Float32), ("y", pl.Float64)]

    with_nans = pl.DataFrame(
        data={"x": [1.0, nan], "y": [nan, 2.0]},
        schema=float_schema,
    )
    assert_frame_equal(with_nans, with_nans, check_exact=True)

    # Same layout, but the NaN in column "y" is replaced by a null.
    with_null = pl.DataFrame(
        data={"x": [1.0, nan], "y": [None, 2.0]},
        schema=float_schema,
    )
    assert_frame_not_equal(with_nans, with_null)
    with pytest.raises(AssertionError, match='value mismatch for column "y"'):
        assert_frame_equal(with_nans, with_null, check_exact=True)
211
212
213
def test_compare_frame_equal_nested_nans() -> None:
    """Exercise NaN handling inside list columns and list-of-struct columns."""
    # list dtype
    list_schema = [("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))]
    list_nans = pl.DataFrame(
        data={"x": [[1.0, nan]], "y": [[nan, 2.0]]},
        schema=list_schema,
    )
    assert_frame_equal(list_nans, list_nans, check_exact=True)

    list_null = pl.DataFrame(
        data={"x": [[1.0, nan]], "y": [[None, 2.0]]},
        schema=list_schema,
    )
    assert_frame_not_equal(list_nans, list_null)
    with pytest.raises(AssertionError, match='value mismatch for column "y"'):
        assert_frame_equal(list_nans, list_null, check_exact=True)

    # struct dtype
    two_field_structs = pl.from_dicts(
        [
            {
                "id": 1,
                "struct": [
                    {"x": "text", "y": [0.0, nan]},
                    {"x": "text", "y": [0.0, nan]},
                ],
            },
            {
                "id": 2,
                "struct": [
                    {"x": "text", "y": [1]},
                    {"x": "text", "y": [1]},
                ],
            },
        ]
    )
    three_field_structs = pl.from_dicts(
        [
            {
                "id": 1,
                "struct": [
                    {"x": "text", "y": [0.0, nan], "z": ["$"]},
                    {"x": "text", "y": [0.0, nan], "z": ["$"]},
                ],
            },
            {
                "id": 2,
                "struct": [
                    {"x": "text", "y": [nan, 1.0], "z": ["!"]},
                    {"x": "text", "y": [nan, 1.0], "z": ["?"]},
                ],
            },
        ]
    )

    # Each frame equals itself ...
    for frame in (two_field_structs, three_field_structs):
        assert_frame_equal(frame, frame)

    # ... but the two differ from each other, with or without dtype checks.
    assert_frame_not_equal(two_field_structs, three_field_structs)
    for check_dtype in (True, False):
        with pytest.raises(AssertionError, match=r"mismatch|different"):
            assert_frame_equal(
                two_field_structs, three_field_structs, check_dtypes=check_dtype
            )
274
275
276
def test_assert_frame_equal_pass() -> None:
    """Two independently built frames with identical data compare equal."""
    left = pl.DataFrame({"a": [1, 2]})
    right = pl.DataFrame({"a": [1, 2]})
    assert_frame_equal(left, right)
280
281
282
@pytest.mark.parametrize(
    "assert_function",
    [assert_frame_equal, assert_frame_not_equal],
)
def test_assert_frame_equal_types(assert_function: Any) -> None:
    """Both assertions reject a DataFrame/Series pair as unexpected input types."""
    frame = pl.DataFrame({"a": [1, 2]})
    series = pl.Series(values=[1, 2], name="a")
    expected_message = r"inputs are different \(unexpected input types\)"
    with pytest.raises(AssertionError, match=expected_message):
        assert_function(frame, series)
293
294
295
def test_assert_frame_equal_length_mismatch() -> None:
    """Frames with different row counts fail with a height-mismatch message."""
    two_rows = pl.DataFrame({"a": [1, 2]})
    three_rows = pl.DataFrame({"a": [1, 2, 3]})
    expected_message = r"DataFrames are different \(height \(row count\) mismatch\)"
    with pytest.raises(AssertionError, match=expected_message):
        assert_frame_equal(two_rows, three_rows)
    assert_frame_not_equal(two_rows, three_rows)
304
305
306
def test_assert_frame_equal_column_mismatch() -> None:
    """Differently named columns are reported as missing from the right frame."""
    left = pl.DataFrame({"a": [1, 2]})
    right = pl.DataFrame({"b": [1, 2]})
    expected_message = (
        r'DataFrames are different \(columns mismatch: \["a"\] in left, but not in right\)'
    )
    with pytest.raises(AssertionError, match=expected_message):
        assert_frame_equal(left, right)
    assert_frame_not_equal(left, right)
315
316
317
def test_assert_frame_equal_column_mismatch2() -> None:
    """Extra columns on the right LazyFrame are reported as a columns mismatch."""
    narrow = pl.LazyFrame({"a": [1, 2]})
    wide = pl.LazyFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
    expected_message = r"columns mismatch.*in right.*but not in left"
    with pytest.raises(AssertionError, match=expected_message):
        assert_frame_equal(narrow, wide)
    assert_frame_not_equal(narrow, wide)
326
327
328
def test_assert_frame_equal_column_mismatch_order() -> None:
    """Column order matters unless `check_column_order=False` is passed."""
    b_then_a = pl.DataFrame({"b": [3, 4], "a": [1, 2]})
    a_then_b = pl.DataFrame({"a": [1, 2], "b": [3, 4]})

    with pytest.raises(AssertionError, match="columns are not in the same order"):
        assert_frame_equal(b_then_a, a_then_b)

    # Ignoring column order makes the frames equal; by default they are not.
    assert_frame_equal(b_then_a, a_then_b, check_column_order=False)
    assert_frame_not_equal(b_then_a, a_then_b)
336
337
338
def test_assert_frame_equal_check_row_order() -> None:
    """Row order matters unless `check_row_order=False` is passed."""
    original = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
    swapped_rows = pl.DataFrame({"a": [2, 1], "b": [3, 4]})

    with pytest.raises(AssertionError, match='value mismatch for column "a"'):
        assert_frame_equal(original, swapped_rows)

    # Ignoring row order makes the frames equal; by default they are not.
    assert_frame_equal(original, swapped_rows, check_row_order=False)
    assert_frame_not_equal(original, swapped_rows)
347
348
349
def test_assert_frame_equal_check_row_col_order() -> None:
    """Row and column order checks are relaxed independently of each other."""
    original = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
    reordered = pl.DataFrame({"b": [3, 4], "a": [2, 1]})

    # Relaxing only the row order still trips the column order check.
    with pytest.raises(AssertionError, match="columns are not in the same order"):
        assert_frame_equal(original, reordered, check_row_order=False)

    assert_frame_equal(
        original, reordered, check_row_order=False, check_column_order=False
    )
    assert_frame_not_equal(original, reordered)
358
359
360
@pytest.mark.parametrize(
    "assert_function",
    [assert_frame_equal, assert_frame_not_equal],
)
def test_assert_frame_equal_check_row_order_unsortable(assert_function: Any) -> None:
    """Ignoring row order on frames holding Python objects raises an error."""
    left = pl.DataFrame({"a": [object(), object()], "b": [3, 4]})
    right = pl.DataFrame({"a": [object(), object()], "b": [4, 3]})
    expected_message = (
        "`arg_sort_multiple` operation not supported for dtype `object`"
    )
    with pytest.raises(InvalidOperationError, match=expected_message):
        assert_function(left, right, check_row_order=False)
372
373
374
def test_assert_frame_equal_dtypes_mismatch() -> None:
    """Frames whose column "a" differs in integer width fail the dtype check."""
    values = {"a": [1, 2], "b": [3, 4]}
    int8_a = pl.DataFrame(values, schema={"a": pl.Int8, "b": pl.Int16})
    int16_a = pl.DataFrame(values, schema={"b": pl.Int16, "a": pl.Int16})

    with pytest.raises(AssertionError, match="dtypes do not match"):
        assert_frame_equal(int8_a, int16_a, check_column_order=False)

    assert_frame_not_equal(int8_a, int16_a, check_column_order=False)
    assert_frame_not_equal(int8_a, int16_a)
384
385
386
def test_assert_frame_not_equal() -> None:
    """Comparing a frame with itself fails, for both eager and lazy frames."""
    eager = pl.DataFrame({"a": [1, 2]})
    with pytest.raises(AssertionError, match="DataFrames are equal"):
        assert_frame_not_equal(eager, eager)

    lazy = eager.lazy()
    with pytest.raises(AssertionError, match="LazyFrames are equal"):
        assert_frame_not_equal(lazy, lazy)
393
394
395
def test_assert_frame_equal_check_dtype_deprecated() -> None:
    """Passing the old `check_dtype` keyword emits a deprecation warning."""
    ints = pl.DataFrame({"a": [1, 2]})
    floats = pl.DataFrame({"a": [1.0, 2.0]})
    reversed_ints = pl.DataFrame({"a": [2, 1]})

    with pytest.deprecated_call():
        assert_frame_equal(ints, floats, check_dtype=False)  # type: ignore[call-arg]

    with pytest.deprecated_call():
        assert_frame_not_equal(ints, reversed_ints, check_dtype=False)  # type: ignore[call-arg]
405
406
407
def test_assert_dataframe_equal_all_nulls_passes_when_ignoring_dtypes() -> None:
    """All-null columns compare equal across dtypes when dtypes are ignored."""
    null_values = {"A": [None, None, None]}
    untyped = pl.from_dict(null_values)
    typed = pl.from_dict(
        null_values, schema_overrides={"A": pl.List(pl.Float64())}
    )

    assert_frame_equal(untyped, typed, check_dtypes=False)
414
415
416
def test_assert_dataframe_equal_all_nulls_fails_when_checking_dtypes() -> None:
    """All-null columns with different dtypes fail when dtypes are checked."""
    null_values = {"A": [None, None, None]}
    untyped = pl.from_dict(null_values)
    typed = pl.from_dict(
        null_values, schema_overrides={"A": pl.List(pl.Float64())}
    )

    with pytest.raises(AssertionError, match="dtypes do not match"):
        assert_frame_equal(untyped, typed, check_dtypes=True)
424
425
426
def test_tracebackhide(testdir: pytest.Testdir) -> None:
    """Check that polars testing internals stay out of pytest failure output.

    A throwaway module with four deliberately failing tests is written and run
    in a nested pytest session; the captured output lines are then inspected.
    """
    testdir.makefile(
        ".py",
        test_path="""\
import polars as pl
from polars.testing import assert_frame_equal, assert_frame_not_equal

def test_frame_equal_fail():
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = pl.DataFrame({"a": [1, 3]})
    assert_frame_equal(df1, df2)

def test_frame_not_equal_fail():
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = pl.DataFrame({"a": [1, 2]})
    assert_frame_not_equal(df1, df2)

def test_frame_data_type_fail():
    df1 = pl.DataFrame({"a": [1, 2]})
    df2 = {"a": [1, 2]}
    assert_frame_equal(df1, df2)

def test_frame_schema_fail():
    df1 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int64})
    df2 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int32})
    assert_frame_equal(df1, df2)
""",
    )
    outcome = testdir.runpytest()
    outcome.assert_outcomes(passed=0, failed=4)
    stdout = "\n".join(outcome.outlines)

    assert "polars/py-polars/polars/testing" not in stdout

    # The above should catch any polars testing functions that appear in the
    # stack trace. But we keep the following checks (for specific function
    # names) just to double-check.
    hidden_definitions = (
        "def assert_frame_equal",
        "def assert_frame_not_equal",
        "def _assert_correct_input_type",
        "def assert_series_equal",
        "def assert_series_not_equal",
    )
    for definition in hidden_definitions:
        assert definition not in stdout

    # Make sure the tests are failing for the expected reason (e.g. not because
    # an import is missing or something like that):
    expected_errors = (
        'AssertionError: DataFrames are different (value mismatch for column "a")',
        "AssertionError: DataFrames are equal",
        "AssertionError: inputs are different (unexpected input types)",
        "AssertionError: DataFrames are different (dtypes do not match)",
    )
    for message in expected_errors:
        assert message in stdout
481
482