Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/testing/test_assert_series_equal.py
8410 views
1
from __future__ import annotations
2
3
import math
4
from datetime import datetime, time, timedelta
5
from decimal import Decimal as D
6
from typing import Any
7
8
import hypothesis.strategies as st
9
import pytest
10
from hypothesis import given
11
12
import polars as pl
13
from polars.exceptions import InvalidOperationError
14
from polars.testing import assert_series_equal, assert_series_not_equal
15
from polars.testing.parametric import dtypes, series
16
17
# Shared NaN value used to build float Series throughout this module.
nan: float = math.nan

# Enable the ``pytester``/``testdir`` fixtures used by ``test_tracebackhide``.
pytest_plugins: list[str] = ["pytester"]
19
20
21
@given(s=series())
def test_assert_series_equal_parametric(s: pl.Series) -> None:
    """A hypothesis-generated Series of any dtype compares equal to itself."""
    assert_series_equal(s, s)
24
25
26
@given(data=st.data())
def test_assert_series_equal_parametric_array(data: st.DataObject) -> None:
    """A generated fixed-width Array Series compares equal to itself.

    The inner dtype is drawn with Categorical excluded, and the array width is
    drawn from 1 to 3.
    """
    inner_dtype = data.draw(dtypes(excluded_dtypes=[pl.Categorical]))
    width = data.draw(st.integers(min_value=1, max_value=3))
    arr = data.draw(series(dtype=pl.Array(inner_dtype, shape=width)))
    assert_series_equal(arr, arr)
34
35
36
def test_compare_series_value_mismatch() -> None:
    """Integer Series with differing values are reported as an exact mismatch."""
    left, right = pl.Series([1, 2, 3]), pl.Series([2, 3, 4])
    assert_series_not_equal(left, right)

    expected = r"Series are different \(exact value mismatch\)"
    with pytest.raises(AssertionError, match=expected):
        assert_series_equal(left, right)
46
47
48
def test_compare_series_empty_equal() -> None:
    """Empty Series built from an empty list and an empty tuple compare equal."""
    from_list = pl.Series([])
    from_tuple = pl.Series(())
    assert_series_equal(from_list, from_tuple)

    equal_msg = r"Series are equal \(but are expected not to be\)"
    with pytest.raises(AssertionError, match=equal_msg):
        assert_series_not_equal(from_list, from_tuple)
58
59
60
def test_assert_series_equal_check_order() -> None:
    """With ``check_order=False`` a permutation (nulls included) counts as equal."""
    original = pl.Series([1, 2, 3, None])
    shuffled = pl.Series([2, None, 3, 1])
    assert_series_equal(original, shuffled, check_order=False)

    equal_msg = r"Series are equal \(but are expected not to be\)"
    with pytest.raises(AssertionError, match=equal_msg):
        assert_series_not_equal(original, shuffled, check_order=False)
70
71
72
def test_assert_series_equal_check_order_unsortable_type() -> None:
    """Order-insensitive comparison of Object Series raises, as they cannot be sorted."""
    lhs = pl.Series([object(), object()])
    rhs = pl.Series([object(), object()])

    unsupported = "`sort_with` operation not supported for dtype `object`"
    with pytest.raises(InvalidOperationError, match=unsupported):
        assert_series_equal(lhs, rhs, check_order=False)
80
81
82
def test_compare_series_nans_assert_equal() -> None:
    """NaN handling in flat and nested float Series.

    A Series containing NaN equals itself. NaN in one position never matches
    a number in that position, with or without ``check_exact``, and NaN does
    not match a null even when dtypes are ignored.
    """
    nan_at_2 = pl.Series([1.0, 2.0, nan, 4.0, None, 6.0])
    nan_at_1 = pl.Series([1.0, nan, 3.0, 4.0, None, 6.0])
    no_nan = pl.Series([1.0, 2.0, 3.0, 4.0, None, 6.0])

    for candidate in (nan_at_2, nan_at_1, no_nan):
        assert_series_equal(candidate, candidate)
        assert_series_equal(candidate, candidate, check_exact=True)

    # Expected failure message pattern, keyed by the ``check_exact`` flag.
    messages = {
        False: "Series are different.*value mismatch.*",
        True: "exact value mismatch",
    }
    for check_exact, pattern in messages.items():
        with pytest.raises(AssertionError, match=pattern):
            assert_series_equal(nan_at_2, nan_at_1, check_exact=check_exact)
        with pytest.raises(AssertionError, match=pattern):
            assert_series_equal(nan_at_2, no_nan, check_exact=check_exact)

    floats_with_null = pl.Series([1.0, 2.0, 3.0, 4.0, None, 6.0])
    floats_with_nan = pl.Series([1.0, 2.0, 3.0, 4.0, nan, 6.0])
    ints_with_null = pl.Series([1, 2, 3, 4, None, 6])

    # Ignoring dtypes, a null matches a null but NaN does not match a null.
    assert_series_equal(floats_with_null, ints_with_null, check_dtypes=False)
    with pytest.raises(AssertionError):
        assert_series_equal(floats_with_nan, ints_with_null, check_dtypes=False)
    assert_series_not_equal(floats_with_nan, ints_with_null, check_dtypes=True)

    # A NaN inside a list column equals itself for both float widths.
    for inner in (pl.Float32, pl.Float64):
        nested = pl.Series([[0.0, nan]], dtype=pl.List(inner))
        assert nested.dtype == pl.List(inner)
        assert_series_equal(nested, nested)
116
117
118
def test_compare_series_nulls() -> None:
    """Nulls in matching positions are equal; nulls against values are not."""
    assert_series_equal(pl.Series([1, 2, None]), pl.Series([1, 2, None]))

    values = pl.Series([1, 2, 3])
    mostly_null = pl.Series([1, None, None])
    assert_series_not_equal(values, mostly_null)

    with pytest.raises(AssertionError, match="value mismatch"):
        assert_series_equal(values, mostly_null)
129
130
131
def test_compare_series_value_mismatch_string() -> None:
    """A differing String element is reported as an exact value mismatch."""
    answer_no = pl.Series(["hello", "no"])
    answer_yes = pl.Series(["hello", "yes"])

    assert_series_not_equal(answer_no, answer_yes)

    pattern = r"Series are different \(exact value mismatch\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_equal(answer_no, answer_yes)
141
142
143
def test_compare_series_dtype_mismatch() -> None:
    """Integer vs float Series with equal values differ by dtype by default."""
    ints = pl.Series([1, 2, 3])
    floats = pl.Series([1.0, 2.0, 3.0])
    assert_series_not_equal(ints, floats)

    pattern = r"Series are different \(dtype mismatch\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_equal(ints, floats)
153
154
155
@pytest.mark.parametrize(
    "assert_function", [assert_series_equal, assert_series_not_equal]
)
def test_compare_series_input_type_mismatch(assert_function: Any) -> None:
    """Both helpers reject a DataFrame passed where a Series is expected."""
    series_input = pl.Series([1, 2, 3])
    frame_input = pl.DataFrame({"col1": [2, 3, 4]})

    pattern = r"inputs are different \(unexpected input types\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_function(series_input, frame_input)
167
168
169
def test_compare_series_name_mismatch() -> None:
    """Identical values under different Series names are not equal."""
    first = pl.Series("srs1", [1, 2, 3])
    second = pl.Series("srs2", [1, 2, 3])

    pattern = r"Series are different \(name mismatch\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_equal(first, second)
177
178
179
def test_compare_series_length_mismatch() -> None:
    """Series of different lengths are reported as a length mismatch.

    The names also differ; the length mismatch is the one expected to be
    reported.
    """
    longer = pl.Series("srs1", [1, 2, 3, 4])
    shorter = pl.Series("srs2", [1, 2, 3])

    assert_series_not_equal(longer, shorter)

    pattern = r"Series are different \(length mismatch\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_equal(longer, shorter)
189
190
191
def test_compare_series_value_exact_mismatch() -> None:
    """A 1e-7 float difference is rejected when ``check_exact=True``."""
    base = pl.Series([1.0, 2.0, 3.0])
    nudged = pl.Series([1.0, 2.0 + 1e-7, 3.0])

    pattern = r"Series are different \(exact value mismatch\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_equal(base, nudged, check_exact=True)
199
200
201
def test_assert_series_equal_int_overflow() -> None:
    """Comparing Int8 minimum values must not overflow.

    Non-exact comparison may internally call ``abs``, and ``abs(-128)`` does not
    fit in a signed 8-bit integer.
    """
    only_min = pl.Series([-128], dtype=pl.Int8)
    zero_then_min = pl.Series([0, -128], dtype=pl.Int8)
    one_then_min = pl.Series([1, -128], dtype=pl.Int8)

    for exact in (True, False):
        assert_series_equal(only_min, only_min, check_exact=exact)
        with pytest.raises(AssertionError):
            assert_series_equal(zero_then_min, one_then_min, check_exact=exact)
211
212
213
@pytest.mark.parametrize(
    ("data1", "data2"),
    [
        ([datetime(2022, 10, 2, 12)], [datetime(2022, 10, 2, 13)]),
        ([time(10, 0, 0)], [time(10, 0, 10)]),
        ([timedelta(days=10)], [timedelta(days=10, microseconds=10)]),
    ],
)
def test_assert_series_equal_temporal(data1: Any, data2: Any) -> None:
    """Datetime, time and duration Series with different values are not equal."""
    assert_series_not_equal(pl.Series(data1), pl.Series(data2))
225
226
227
@pytest.mark.parametrize(
    ("s1", "s2", "kwargs"),
    [
        pytest.param(
            pl.Series([0.2, 0.3]),
            pl.Series([0.2, 0.3]),
            {"abs_tol": 1e-15},
            id="equal_floats_low_abs_tol",
        ),
        pytest.param(
            pl.Series([0.2, 0.3]),
            pl.Series([0.2, 0.3000000000000001]),
            {"abs_tol": 1e-15},
            id="approx_equal_float_low_abs_tol",
        ),
        pytest.param(
            pl.Series([0.2, 0.3]),
            pl.Series([0.2, 0.31]),
            {"abs_tol": 0.1},
            id="approx_equal_float_high_abs_tol",
        ),
        pytest.param(
            pl.Series([0.2, 1.3]),
            pl.Series([0.2, 0.9]),
            {"abs_tol": 1},
            id="approx_equal_float_integer_abs_tol",
        ),
        pytest.param(
            pl.Series([1.0, 2.0, nan]),
            pl.Series([1.005, 2.005, nan]),
            {"abs_tol": 1e-2, "rel_tol": 0.0},
            id="approx_equal_float_nan_abs_tol",
        ),
        pytest.param(
            pl.Series([1.0, 2.0, None]),
            pl.Series([1.005, 2.005, None]),
            {"abs_tol": 1e-2},
            id="approx_equal_float_none_abs_tol",
        ),
        pytest.param(
            pl.Series([1.0, 2.0, nan]),
            pl.Series([1.005, 2.015, nan]),
            {"abs_tol": 0.0, "rel_tol": 1e-2},
            id="approx_equal_float_nan_rel_tol",
        ),
        pytest.param(
            pl.Series([1.0, 2.0, None]),
            pl.Series([1.005, 2.015, None]),
            {"rel_tol": 1e-2},
            id="approx_equal_float_none_rel_tol",
        ),
        pytest.param(
            pl.Series([0.0, 1.0, 2.0], dtype=pl.Float64),
            pl.Series([0, 1, 2], dtype=pl.Int64),
            {"check_dtypes": False},
            id="equal_int_float_integer_no_check_dtype",
        ),
        pytest.param(
            pl.Series([0, 1, 2], dtype=pl.Float64),
            pl.Series([0, 1, 2], dtype=pl.Float32),
            {"check_dtypes": False},
            # Previously reused the Float64-vs-Int64 id above.
            id="equal_float64_float32_no_check_dtype",
        ),
        pytest.param(
            pl.Series([0, 1, 2], dtype=pl.Int64),
            pl.Series([0, 1, 2], dtype=pl.Int64),
            {},
            id="equal_int",
        ),
        pytest.param(
            pl.Series(["a", "b", "c"], dtype=pl.String),
            pl.Series(["a", "b", "c"], dtype=pl.String),
            {},
            id="equal_str",
        ),
        pytest.param(
            pl.Series([[0.2, 0.3]]),
            pl.Series([[0.2, 0.31]]),
            {"abs_tol": 0.1},
            id="list_of_float_high_abs_tol",
        ),
        pytest.param(
            pl.Series([[0.2, 1.3]]),
            pl.Series([[0.2, 0.9]]),
            {"abs_tol": 1},
            id="list_of_float_integer_abs_tol",
        ),
        pytest.param(
            pl.Series([[0.2, 0.3]]),
            pl.Series([[0.2, 0.300000001]]),
            {"rel_tol": 1e-15},
            id="list_of_float_low_rel_tol",
        ),
        pytest.param(
            pl.Series([[0.2, 0.3]]),
            pl.Series([[0.2, 0.301]]),
            {"rel_tol": 0.1},
            id="list_of_float_high_rel_tol",
        ),
        pytest.param(
            pl.Series([[0.2, 1.3]]),
            pl.Series([[0.2, 0.9]]),
            {"rel_tol": 1},
            id="list_of_float_integer_rel_tol",
        ),
        pytest.param(
            pl.Series([[None, 1.3]]),
            pl.Series([[None, 0.9]]),
            {"rel_tol": 1},
            id="list_of_none_and_float_integer_rel_tol",
        ),
        pytest.param(
            pl.Series([[None, 1]], dtype=pl.List(pl.Int64)),
            pl.Series([[None, 1]], dtype=pl.List(pl.Int64)),
            {"rel_tol": 1},
            id="list_of_none_and_int_integer_rel_tol",
        ),
        pytest.param(
            pl.Series([[math.nan, 1.3]]),
            pl.Series([[math.nan, 0.9]]),
            {"rel_tol": 1},
            # Previously reused the None-based id above; this case uses NaN.
            id="list_of_nan_and_float_integer_rel_tol",
        ),
        pytest.param(
            pl.Series([[2.0, 3.0]]),
            pl.Series([[2, 3]]),
            {"check_exact": False, "check_dtypes": False},
            id="list_of_float_list_of_int_check_dtype_false",
        ),
        pytest.param(
            pl.Series([[[0.2, 3.0]]]),
            pl.Series([[[0.2, 3.00000001]]]),
            {"abs_tol": 0.1},
            id="nested_list_of_float_abs_tol_high",
        ),
        pytest.param(
            pl.Series([[[0.2, math.nan, 3.0]]]),
            pl.Series([[[0.2, math.nan, 3.00000001]]]),
            {"abs_tol": 0.1},
            id="nested_list_of_float_and_nan_abs_tol_high",
        ),
        pytest.param(
            pl.Series([[[[0.2, 3.0]]]]),
            pl.Series([[[[0.2, 3.00000001]]]]),
            {"abs_tol": 0.1},
            id="double_nested_list_of_float_abs_tol_high",
        ),
        pytest.param(
            pl.Series([[[[0.2, math.nan, 3.0]]]]),
            pl.Series([[[[0.2, math.nan, 3.00000001]]]]),
            {"abs_tol": 0.1},
            id="double_nested_list_of_float_and_nan_abs_tol_high",
        ),
        pytest.param(
            pl.Series([[[[[0.2, 3.0]]]]]),
            pl.Series([[[[[0.2, 3.00000001]]]]]),
            {"abs_tol": 0.1},
            id="triple_nested_list_of_float_abs_tol_high",
        ),
        pytest.param(
            pl.Series([[[[[0.2, math.nan, 3.0]]]]]),
            pl.Series([[[[[0.2, math.nan, 3.00000001]]]]]),
            {"abs_tol": 0.1},
            id="triple_nested_list_of_float_and_nan_abs_tol_high",
        ),
        pytest.param(
            pl.struct(a=0, b=1, eager=True),
            pl.struct(a=0, b=1, eager=True),
            {},
            id="struct_equal",
        ),
        pytest.param(
            pl.struct(a=0, b=1.1, eager=True),
            pl.struct(a=0, b=1.01, eager=True),
            {"abs_tol": 0.1, "rel_tol": 0},
            id="struct_approx_equal",
        ),
        pytest.param(
            pl.struct(a=0, b=[0.0, 1.1], eager=True),
            pl.struct(a=0, b=[0.0, 1.11], eager=True),
            {"abs_tol": 0.1},
            id="struct_with_list_approx_equal",
        ),
        pytest.param(
            pl.struct(a=0, b=[0.0, math.nan], eager=True),
            pl.struct(a=0, b=[0.0, math.nan], eager=True),
            {"abs_tol": 0.1},
            id="struct_with_list_with_nan_compare_equal_true",
        ),
    ],
)
def test_assert_series_equal_passes_assertion(
    s1: pl.Series,
    s2: pl.Series,
    kwargs: Any,
) -> None:
    """Each pair is equal under its options, so the inverse assertion must fail.

    Every case has a unique id so it can be selected individually
    (e.g. with ``pytest -k``) without pytest disambiguating duplicates.
    """
    assert_series_equal(s1, s2, **kwargs)
    with pytest.raises(AssertionError):
        assert_series_not_equal(s1, s2, **kwargs)
426
427
428
@pytest.mark.parametrize(
    ("s1", "s2", "kwargs"),
    [
        pytest.param(
            pl.Series([0.2, 0.3]),
            pl.Series([0.2, 0.39]),
            {"abs_tol": 0.09, "rel_tol": 0},
            id="approx_equal_float_high_abs_tol_zero_rel_tol",
        ),
        pytest.param(
            pl.Series([0.2, 1.3]),
            pl.Series([0.2, 2.31]),
            {"abs_tol": 1, "rel_tol": 0},
            id="approx_equal_float_integer_abs_tol_zero_rel_tol",
        ),
        pytest.param(
            pl.Series([0, 1, 2], dtype=pl.Float64),
            pl.Series([0, 1, 2], dtype=pl.Int64),
            {"check_dtypes": True},
            id="equal_int_float_integer_check_dtype",
        ),
        pytest.param(
            pl.Series([0, 1, 2], dtype=pl.Float64),
            pl.Series([0, 1, 2], dtype=pl.Float32),
            {"check_dtypes": True},
            # Previously reused the Float64-vs-Int64 id above.
            id="equal_float64_float32_check_dtype",
        ),
        pytest.param(
            pl.Series([1.0, 2.0, nan]),
            pl.Series([1.005, 2.005, 3.005]),
            {"abs_tol": 1e-2, "rel_tol": 0.0},
            id="approx_equal_float_left_nan_abs_tol",
        ),
        pytest.param(
            pl.Series([1.0, 2.0, 3.0]),
            pl.Series([1.005, 2.005, nan]),
            {"abs_tol": 1e-2, "rel_tol": 0.0},
            id="approx_equal_float_right_nan_abs_tol",
        ),
        pytest.param(
            pl.Series([1.0, 2.0, nan]),
            pl.Series([1.005, 2.015, 3.025]),
            {"abs_tol": 0.0, "rel_tol": 1e-2},
            id="approx_equal_float_left_nan_rel_tol",
        ),
        pytest.param(
            pl.Series([1.0, 2.0, 3.0]),
            pl.Series([1.005, 2.015, nan]),
            {"abs_tol": 0.0, "rel_tol": 1e-2},
            id="approx_equal_float_right_nan_rel_tol",
        ),
        pytest.param(
            pl.Series([[0.2, 0.3]]),
            pl.Series([[0.2, 0.3, 0.4]]),
            {},
            id="list_of_float_different_lengths",
        ),
        pytest.param(
            pl.Series([[0.2, 0.3]]),
            pl.Series([[0.2, 0.3000000000000001]]),
            {"check_exact": True},
            id="list_of_float_check_exact",
        ),
        pytest.param(
            pl.Series([[0.2, 0.3]]),
            pl.Series([[0.2, 0.300001]]),
            {"abs_tol": 1e-15, "rel_tol": 0},
            id="list_of_float_too_low_abs_tol",
        ),
        pytest.param(
            pl.Series([[0.2, 0.3]]),
            pl.Series([[0.2, 0.30000001]]),
            {"abs_tol": -1, "rel_tol": 0},
            id="list_of_float_negative_abs_tol",
        ),
        pytest.param(
            pl.Series([[2.0, 3.0]]),
            pl.Series([[2, 3]]),
            {"check_exact": False, "check_dtypes": True},
            id="list_of_float_list_of_int_check_dtype_true",
        ),
        pytest.param(
            pl.struct(a=0, b=1.1, eager=True),
            pl.struct(a=0, b=1, eager=True),
            {"abs_tol": 0.1, "rel_tol": 0, "check_dtypes": True},
            id="struct_approx_equal_different_type",
        ),
        pytest.param(
            pl.struct(a=0, b=1.09, eager=True),
            pl.struct(a=0, b=1, eager=True),
            {"abs_tol": 0.1, "rel_tol": 0, "check_dtypes": False},
            # Previously reused the check_dtypes=True id above.
            id="struct_approx_equal_different_type_no_check_dtype",
        ),
    ],
)
def test_assert_series_equal_raises_assertion_error(
    s1: pl.Series,
    s2: pl.Series,
    kwargs: Any,
) -> None:
    """Each pair differs under its options: equality fails, inequality passes.

    Every case has a unique id so it can be selected individually
    (e.g. with ``pytest -k``) without pytest disambiguating duplicates.
    """
    with pytest.raises(AssertionError):
        assert_series_equal(s1, s2, **kwargs)
    assert_series_not_equal(s1, s2, **kwargs)
531
532
533
def test_assert_series_equal_categorical_vs_str() -> None:
    """Categorical vs String: ``categorical_as_str`` alone still fails on dtype.

    Once dtype checking is disabled, the two Series compare equal in either
    argument order.
    """
    cats = pl.Series(["a", "b", "a"], dtype=pl.Categorical)
    strs = pl.Series(["a", "b", "a"], dtype=pl.String)

    with pytest.raises(AssertionError, match="dtype mismatch"):
        assert_series_equal(cats, strs, categorical_as_str=True)

    for left, right in ((cats, strs), (strs, cats)):
        assert_series_equal(left, right, check_dtypes=False, categorical_as_str=True)
542
543
544
def test_assert_series_equal_incompatible_data_types() -> None:
    """Categorical vs Int8 cannot be compared even when dtypes are ignored."""
    categories = pl.Series(["a", "b", "a"], dtype=pl.Categorical)
    codes = pl.Series([0, 1, 0], dtype=pl.Int8)

    with pytest.raises(AssertionError, match="incompatible data types"):
        assert_series_equal(categories, codes, check_dtypes=False)
550
551
552
def test_assert_series_equal_full_series() -> None:
    """A difference only in the last element is an exact value mismatch."""
    expected, actual = pl.Series([1, 2, 3]), pl.Series([1, 2, 4])

    pattern = r"Series are different \(exact value mismatch\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_equal(expected, actual)
559
560
561
def test_assert_series_not_equal() -> None:
    """``assert_series_not_equal`` fails when a Series is compared to itself."""
    ser = pl.Series("a", [1, 2])

    pattern = r"Series are equal \(but are expected not to be\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_not_equal(ser, ser)
568
569
570
def test_assert_series_equal_nested_list_float() -> None:
    """A differing float inside a List column is reported as a nested mismatch."""
    # Both Series are List(Float64); only the last inner element differs
    # (4.0 vs 4.9).
    s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float64))
    s2 = pl.Series([[1.0, 2.0], [3.0, 4.9]], dtype=pl.List(pl.Float64))

    with pytest.raises(
        AssertionError,
        match=r"Series are different \(nested value mismatch\)",
    ):
        assert_series_equal(s1, s2)
580
581
582
def test_assert_series_equal_nested_struct_float() -> None:
    """A differing float field inside a Struct column is a nested mismatch."""
    struct_dtype = pl.Struct({"a": pl.Float64, "b": pl.Float64})
    left = pl.Series(
        [{"a": 1.0, "b": 2.0}, {"a": 3.0, "b": 4.0}], dtype=struct_dtype
    )
    right = pl.Series(
        [{"a": 1.0, "b": 2.0}, {"a": 3.0, "b": 4.9}], dtype=struct_dtype
    )

    pattern = r"Series are different \(nested value mismatch\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_equal(left, right)
597
598
599
def test_assert_series_equal_all_null_different_dtypes_fails_with_check_dtypes_true() -> (
    None
):
    """All-null Series still fail the dtype check when their dtypes differ."""
    null_cats = pl.Series([None, None], dtype=pl.Categorical)
    null_ints = pl.Series([None, None], dtype=pl.Int16)

    with pytest.raises(AssertionError, match="dtype mismatch"):
        assert_series_equal(null_cats, null_ints, check_dtypes=True)
610
611
612
def test_assert_series_equal_all_null_different_dtypes_passes_with_check_dtypes_false() -> (
    None
):
    """With dtype checking off, all-null Series of unrelated dtypes are equal."""
    assert_series_equal(
        pl.Series([None, None], dtype=pl.Categorical),
        pl.Series([None, None], dtype=pl.Int16),
        check_dtypes=False,
    )
619
620
621
def test_assert_series_equal_full_null_nested_list() -> None:
    """An all-null List(Float64) Series equals itself."""
    all_null = pl.Series([None, None], dtype=pl.List(pl.Float64))
    assert_series_equal(all_null, all_null)
624
625
626
def test_assert_series_equal_nested_list_nan() -> None:
    """A NaN inside a List(Float64) element does not break self-equality."""
    with_nan = pl.Series([[1.0, 2.0], [3.0, nan]], dtype=pl.List(pl.Float64))
    assert_series_equal(with_nan, with_nan)
629
630
631
def test_assert_series_equal_nested_list_none() -> None:
    """Null list entries in matching positions compare equal."""
    list_dtype = pl.List(pl.Float64)
    lhs = pl.Series([[1.0, 2.0], None], dtype=list_dtype)
    rhs = pl.Series([[1.0, 2.0], None], dtype=list_dtype)
    assert_series_equal(lhs, rhs)
636
637
638
def test_assert_series_equal_uint_overflow() -> None:
    """Differences between unsigned integers are detected without wrap-around.

    ``abs_tol`` does not loosen the comparison for UInt8 values. A UInt64 pair
    whose naive subtraction would underflow is still reported as unequal.
    """
    small = pl.Series([1, 2, 3], dtype=pl.UInt8)
    shifted = pl.Series([2, 3, 4], dtype=pl.UInt8)

    exact_msg = r"Series are different \(exact value mismatch\)"
    for tolerance in (0, 1):
        with pytest.raises(AssertionError, match=exact_msg):
            assert_series_equal(small, shifted, abs_tol=tolerance)

    low = pl.Series(values=[2810428175213635359], dtype=pl.UInt64)
    high = pl.Series(values=[15807433754238349345], dtype=pl.UInt64)
    with pytest.raises(AssertionError):
        assert_series_equal(low, high)
664
665
666
def test_assert_series_equal_uint_always_checked_exactly() -> None:
    """Unsigned vs signed integers ignore ``abs_tol`` and are compared exactly."""
    unsigned = pl.Series([1, 3], dtype=pl.UInt8)
    signed = pl.Series([2, 4], dtype=pl.Int64)

    pattern = r"Series are different \(exact value mismatch\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_equal(unsigned, signed, abs_tol=1, check_dtypes=False)
675
676
677
def test_assert_series_equal_nested_int_always_checked_exactly() -> None:
    """Integer lists are compared exactly, whether or not a tolerance is given."""
    reference = pl.Series([[1, 2], [3, 4]])
    off_by_one = pl.Series([[1, 2], [3, 5]])

    exact_msg = r"Series are different \(exact value mismatch\)"
    option_sets: tuple[dict[str, Any], ...] = (
        {"abs_tol": 1},
        {"check_exact": True},
    )
    for options in option_sets:
        with pytest.raises(AssertionError, match=exact_msg):
            assert_series_equal(reference, off_by_one, **options)
691
692
693
@pytest.mark.parametrize("check_exact", [True, False])
def test_assert_series_equal_array_equal(check_exact: bool) -> None:
    """Differing Array(Float64, 2) elements are a nested value mismatch.

    Despite the name, this exercises the failing case, both with and without
    ``check_exact``.
    """
    array_dtype = pl.Array(pl.Float64, 2)
    expected = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=array_dtype)
    actual = pl.Series([[1.0, 2.0], [3.0, 4.2]], dtype=array_dtype)

    pattern = r"Series are different \(nested value mismatch\)"
    with pytest.raises(AssertionError, match=pattern):
        assert_series_equal(expected, actual, check_exact=check_exact)
702
703
704
def test_series_equal_nested_lengths_mismatch() -> None:
    """Equal outer lengths but different inner list lengths is a nested mismatch."""
    list_dtype = pl.List(pl.Float64)
    two_by_two = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=list_dtype)
    three_then_one = pl.Series([[1.0, 2.0, 3.0], [4.0]], dtype=list_dtype)

    with pytest.raises(AssertionError, match="nested value mismatch"):
        assert_series_equal(two_by_two, three_then_one)
710
711
712
@pytest.mark.parametrize("check_exact", [True, False])
def test_series_equal_decimals(check_exact: bool) -> None:
    """Decimal Series report an exact mismatch whatever ``check_exact`` is."""
    baseline = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal)
    tweaked = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal)

    for ser in (baseline, tweaked):
        assert_series_equal(ser, ser, check_exact=check_exact)

    with pytest.raises(AssertionError, match="exact value mismatch"):
        assert_series_equal(baseline, tweaked, check_exact=check_exact)
722
723
724
def test_assert_series_equal_w_large_integers_12328() -> None:
    """Large integers that differ only in their low digits are not equal."""
    rounded = pl.Series([1577840521123000])
    precise = pl.Series([1577840521123543])

    with pytest.raises(AssertionError):
        assert_series_equal(rounded, precise)
729
730
731
def test_assert_series_equal_check_dtype_deprecated() -> None:
    """The old ``check_dtype`` keyword is accepted but emits a deprecation warning."""
    lhs = pl.Series("a", [1, 2])
    flt = pl.Series("a", [1.0, 2.0])
    rev = pl.Series("a", [2, 1])

    with pytest.deprecated_call():
        assert_series_equal(lhs, flt, check_dtype=False)  # type: ignore[call-arg]

    with pytest.deprecated_call():
        assert_series_not_equal(lhs, rev, check_dtype=False)  # type: ignore[call-arg]
741
742
743
def test_assert_series_equal_nested_categorical_as_str_independently_constructed() -> (
    None
):
    """Structs of independently built Categorical columns compare equal.

    Regression test for https://github.com/pola-rs/polars/issues/16196, checked
    both with and without ``categorical_as_str``.
    """

    def build_struct() -> pl.Series:
        # Fresh Categorical Series each call, so the two structs share nothing.
        first = pl.Series(["c0"], dtype=pl.Categorical)
        second = pl.Series(["c1"], dtype=pl.Categorical)
        return pl.DataFrame([first, second]).to_struct("col0")

    a = build_struct()
    b = build_struct()

    assert_series_equal(a, b, categorical_as_str=True)
    assert_series_equal(a, b, categorical_as_str=False)
757
758
759
@pytest.mark.parametrize(
    "s",
    [
        pl.Series([["a", "b"], ["a"]], dtype=pl.List(pl.Categorical)),
        pl.Series([{"a": "x"}, {"a": "y"}], dtype=pl.Struct({"a": pl.Categorical})),
    ],
)
def test_assert_series_equal_nested_categorical_as_str(s: pl.Series) -> None:
    """Categoricals nested in List or Struct columns equal themselves as strings."""
    assert_series_equal(s, s, categorical_as_str=True)
768
769
770
def test_tracebackhide(testdir: pytest.Testdir) -> None:
    """Failures from the polars assert helpers do not show the helpers' source.

    A generated test module with eight deliberately failing tests is run via
    ``testdir.runpytest``. Its output must contain the expected assertion
    messages but no frames or definitions from ``polars.testing``.
    """
    testdir.makepyfile(
        test_path="""\
import polars as pl
from polars.testing import assert_series_equal, assert_series_not_equal

nan = float("nan")

def test_series_equal_fail():
    s1 = pl.Series([1, 2])
    s2 = pl.Series([1, 3])
    assert_series_equal(s1, s2)

def test_series_not_equal_fail():
    s1 = pl.Series([1, 2])
    s2 = pl.Series([1, 2])
    assert_series_not_equal(s1, s2)

def test_series_nested_fail():
    s1 = pl.Series([[1, 2], [3, 4]])
    s2 = pl.Series([[1, 2], [3, 5]])
    assert_series_equal(s1, s2)

def test_series_null_fail():
    s1 = pl.Series([1, 2])
    s2 = pl.Series([1, None])
    assert_series_equal(s1, s2)

def test_series_nan_fail():
    s1 = pl.Series([1.0, 2.0])
    s2 = pl.Series([1.0, nan])
    assert_series_equal(s1, s2)

def test_series_float_tolerance_fail():
    s1 = pl.Series([1.0, 2.0])
    s2 = pl.Series([1.0, 2.1])
    assert_series_equal(s1, s2)

def test_series_schema_fail():
    s1 = pl.Series([1, 2], dtype=pl.Int64)
    s2 = pl.Series([1, 2], dtype=pl.Int32)
    assert_series_equal(s1, s2)

def test_series_data_type_fail():
    s1 = pl.Series([1, 2])
    s2 = [1, 2]
    assert_series_equal(s1, s2)
""",
    )
    outcome = testdir.runpytest()
    outcome.assert_outcomes(passed=0, failed=8)
    output = "\n".join(outcome.outlines)

    # No frame from the polars testing package should reach the report.
    assert "polars/py-polars/polars/testing" not in output

    # Belt and braces: also look for the helper definitions by name.
    for hidden in ("def assert_series_equal", "def assert_series_not_equal"):
        assert hidden not in output

    # Confirm the tests failed for the intended reasons rather than, say, an
    # import error in the generated module.
    expected_messages = (
        "AssertionError: Series are different (exact value mismatch)",
        "AssertionError: Series are equal",
        "AssertionError: Series are different (nan value mismatch)",
        "AssertionError: Series are different (dtype mismatch)",
        "AssertionError: inputs are different (unexpected input types)",
    )
    for message in expected_messages:
        assert message in output
841
842
843
def test_assert_series_equal_inf() -> None:
    """Infinities equal themselves, but a different sign, position or value does not."""
    plus_inf, minus_inf = float("inf"), float("-inf")

    for special in (plus_inf, minus_inf):
        assert_series_equal(pl.Series([1.0, special]), pl.Series([1.0, special]))

    reference = pl.Series([1.0, plus_inf])
    for other in (
        [plus_inf, 1.0],  # same values, swapped positions
        [1.0, minus_inf],  # opposite sign
        [1.0, 2.0],  # finite value
        [1.0, float("nan")],  # NaN
    ):
        assert_series_not_equal(reference, pl.Series(other))
867
868