Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_comparison.py
6939 views
1
from __future__ import annotations
2
3
import math
4
from contextlib import nullcontext
5
from typing import TYPE_CHECKING, Any
6
7
import pytest
8
9
import polars as pl
10
from polars.exceptions import ComputeError
11
from polars.testing import assert_frame_equal, assert_series_equal
12
13
if TYPE_CHECKING:
14
from contextlib import AbstractContextManager as ContextManager
15
16
from polars._typing import PolarsDataType
17
18
19
def test_comparison_order_null_broadcasting() -> None:
    """Ordering comparisons against an all-null column are null on every row.

    Each of ``<``, ``<=``, ``>`` and ``>=`` is exercised with the null column
    on either side, first on a single-row frame and then on a two-row frame.
    """
    # see more: 8183
    exprs = [
        pl.col("v") < pl.col("null"),
        pl.col("null") < pl.col("v"),
        pl.col("v") <= pl.col("null"),
        pl.col("null") <= pl.col("v"),
        pl.col("v") > pl.col("null"),
        pl.col("null") > pl.col("v"),
        pl.col("v") >= pl.col("null"),
        pl.col("null") >= pl.col("v"),
    ]

    # Give every expression a distinct output name: out0, out1, ...
    kwargs = {f"out{i}": e for i, e in enumerate(exprs)}

    # single value, hits broadcasting branch
    df = pl.DataFrame({"v": [42], "null": [None]})
    assert all((df.select(**kwargs).null_count() == 1).rows()[0])

    # multiple values, hits default branch
    df = pl.DataFrame({"v": [42, 42], "null": [None, None]})
    assert all((df.select(**kwargs).null_count() == 2).rows()[0])
41
42
43
def test_comparison_nulls_single() -> None:
    """Comparing two single-row, all-null frames gives null in every column."""

    def single_null_frame() -> pl.DataFrame:
        # One null row per column, with String, Int64 and Boolean dtypes.
        return pl.DataFrame(
            {
                "a": pl.Series([None], dtype=pl.String),
                "b": pl.Series([None], dtype=pl.Int64),
                "c": pl.Series([None], dtype=pl.Boolean),
            }
        )

    df1 = single_null_frame()
    df2 = single_null_frame()
    all_null = (None, None, None)

    # Both equality and inequality propagate the nulls.
    assert (df1 == df2).row(0) == all_null
    assert (df1 != df2).row(0) == all_null
60
61
62
def test_comparison_series_expr() -> None:
    """Compare a Series (left operand) with an expression (right operand)."""
    df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])})

    lhs = df["a"]  # Series
    rhs = pl.col("b")  # expression

    result = df.select(
        (lhs == rhs).alias("eq"),
        (lhs != rhs).alias("ne"),
        (lhs < rhs).alias("lt"),
        (lhs <= rhs).alias("le"),
        (lhs > rhs).alias("gt"),
        (lhs >= rhs).alias("ge"),
    )

    # Row-wise outcome of comparing a=[1, 2, 3] against b=[2, 1, 3].
    expected = pl.DataFrame(
        {
            "eq": [False, False, True],
            "ne": [True, True, False],
            "lt": [True, False, False],
            "le": [True, False, True],
            "gt": [False, True, False],
            "ge": [False, True, True],
        }
    )
    assert_frame_equal(result, expected)
85
86
87
def test_comparison_expr_expr() -> None:
    """Compare two column expressions with every comparison operator."""
    df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])})

    a = pl.col("a")
    b = pl.col("b")

    # Keyword arguments name the output columns, like `.alias(...)` would.
    result = df.select(
        eq=a == b,
        ne=a != b,
        lt=a < b,
        le=a <= b,
        gt=a > b,
        ge=a >= b,
    )

    # Row-wise outcome of comparing a=[1, 2, 3] against b=[2, 1, 3].
    expected = pl.DataFrame(
        {
            "eq": [False, False, True],
            "ne": [True, True, False],
            "lt": [True, False, False],
            "le": [True, False, True],
            "gt": [False, True, False],
            "ge": [False, True, True],
        }
    )
    assert_frame_equal(result, expected)
110
111
112
def test_comparison_expr_series() -> None:
    """Compare an expression (left operand) with a Series (right operand)."""
    df = pl.DataFrame({"a": pl.Series([1, 2, 3]), "b": pl.Series([2, 1, 3])})

    lhs = pl.col("a")  # expression
    rhs = df["b"]  # Series

    result = df.select(
        (lhs == rhs).alias("eq"),
        (lhs != rhs).alias("ne"),
        (lhs < rhs).alias("lt"),
        (lhs <= rhs).alias("le"),
        (lhs > rhs).alias("gt"),
        (lhs >= rhs).alias("ge"),
    )

    # Row-wise outcome of comparing a=[1, 2, 3] against b=[2, 1, 3].
    expected = pl.DataFrame(
        {
            "eq": [False, False, True],
            "ne": [True, True, False],
            "lt": [True, False, False],
            "le": [True, False, True],
            "gt": [False, True, False],
            "ge": [False, True, True],
        }
    )
    assert_frame_equal(result, expected)
135
136
137
def test_offset_handling_arg_where_7863() -> None:
    """``arg_true`` reports the right index on a mask built from appended parts.

    The mask is ``[0, *a, 0] != 0`` with ``a = [0, 1]``, so the only true
    entry sits at index 2.
    """
    df_check = pl.DataFrame({"a": [0, 1]})

    # Evaluate the comparison once and reuse it, rather than running the same
    # select twice and throwing the first result away.
    nonzero = (pl.lit(0).append(pl.col("a")).append(0)) != 0
    mask = df_check.select(nonzero)

    # The output column takes its name from the leading literal.
    assert mask.select(pl.col("literal").arg_true()).item() == 2
146
147
148
def test_missing_equality_on_bools() -> None:
    """``ne_missing`` on a boolean column returns a non-null answer for null rows."""
    df = pl.DataFrame({"a": [True, None, False]})

    against_true = df.select(pl.col("a").ne_missing(True))["a"].to_list()
    assert against_true == [False, True, True]

    against_false = df.select(pl.col("a").ne_missing(False))["a"].to_list()
    assert against_false == [True, True, False]
161
162
163
def test_struct_equality_18870() -> None:
    """Self-comparison of a struct Series holding one value and one null."""
    s = pl.Series([{"a": 1}, None])

    # eq / ne propagate the null row.
    assert s.eq(s).to_list() == [True, None]
    assert s.ne(s).to_list() == [False, None]

    # eq_missing / ne_missing treat null as equal to null.
    assert s.eq_missing(s).to_list() == [True, True]
    assert s.ne_missing(s).to_list() == [False, False]
185
186
187
def test_struct_nested_equality() -> None:
    """Struct eq/ne with a null field (compares unequal) and a null struct (null)."""
    df = pl.DataFrame(
        {
            "a": [{"foo": 0, "bar": "1"}, {"foo": None, "bar": "1"}, None],
            "b": [{"foo": 0, "bar": "1"}] * 3,
        }
    )
    lhs = pl.col("a")
    rhs = pl.col("b")

    # eq
    assert_frame_equal(
        df.select(lhs.eq(rhs)),
        pl.DataFrame({"a": [True, False, None]}),
    )

    # ne
    assert_frame_equal(
        df.select(lhs.ne(rhs)),
        pl.DataFrame({"a": [False, True, None]}),
    )
204
205
206
def isnan(x: Any) -> bool:
    """Return True only when ``x`` is a float holding NaN.

    Non-float inputs (None, str, bytes, int, ...) are never NaN, so this is
    safe to call on any value.
    """
    if not isinstance(x, float):
        return False
    return math.isnan(x)
208
209
210
def reference_ordering_propagating(lhs: Any, rhs: Any) -> str | None:
    """Return the reference ordering of ``lhs`` vs ``rhs`` with null propagation.

    The result is one of ``"<"``, ``"="`` or ``">"``, or ``None`` when either
    operand is ``None``.
    """
    # normal < nan, nan == nan, nulls propagate
    if lhs is None or rhs is None:
        return None

    lhs_nan = isnan(lhs)
    rhs_nan = isnan(rhs)
    if lhs_nan and rhs_nan:
        return "="

    # NaN sorts above every other value; otherwise use the native comparison.
    if lhs_nan or lhs > rhs:
        return ">"

    return "<" if (rhs_nan or lhs < rhs) else "="
225
226
227
def reference_ordering_missing(lhs: Any, rhs: Any) -> str:
    """Return the reference ordering of ``lhs`` vs ``rhs`` with null as a value.

    The result is always one of ``"<"``, ``"="`` or ``">"``; ``None`` is
    ordered below everything else and equal to itself.
    """
    # null < normal < nan, nan == nan, null == null
    lhs_null = lhs is None
    rhs_null = rhs is None
    if lhs_null or rhs_null:
        if lhs_null and rhs_null:
            return "="
        return "<" if lhs_null else ">"

    lhs_nan = isnan(lhs)
    rhs_nan = isnan(rhs)
    if lhs_nan and rhs_nan:
        return "="

    # NaN sorts above every other value; otherwise use the native comparison.
    if lhs_nan or lhs > rhs:
        return ">"

    return "<" if (rhs_nan or lhs < rhs) else "="
248
249
250
def verify_total_ordering(
    lhs: Any, rhs: Any, dummy: Any, ldtype: PolarsDataType, rdtype: PolarsDataType
) -> None:
    """Check all column-vs-column comparisons of ``lhs`` and ``rhs``.

    A two-row frame is built with ``lhs``/``rhs`` in the first row and
    ``dummy`` in the second. Only the first row of the result is compared
    with the answers from the two reference ordering functions.

    ``ldtype`` and ``rdtype`` are the dtypes of the left and right columns;
    ``dummy`` must not be ``None``.
    """
    ref = reference_ordering_propagating(lhs, rhs)
    refmiss = reference_ordering_missing(lhs, rhs)

    # Add dummy variable so we don't broadcast or do full-null optimization.
    assert dummy is not None
    frame = pl.DataFrame(
        {"l": [lhs, dummy], "r": [rhs, dummy]}, schema={"l": ldtype, "r": rdtype}
    )

    left = pl.col("l")
    right = pl.col("r")
    observed = frame.select(
        eq=left == right,
        ne=left != right,
        lt=left < right,
        le=left <= right,
        gt=left > right,
        ge=left >= right,
        eq_missing=left.eq_missing(right),
        ne_missing=left.ne_missing(right),
    )

    # `ref and X` evaluates to None when `ref` is None, so the expected value
    # of the propagating operators is null whenever an operand is null.
    wanted = {
        "eq": [ref and ref == "="],
        "ne": [ref and ref != "="],
        "lt": [ref and ref == "<"],
        "le": [ref and ref in ("<", "=")],
        "gt": [ref and ref == ">"],
        "ge": [ref and ref in (">", "=")],
        "eq_missing": [refmiss == "="],
        "ne_missing": [refmiss != "="],
    }
    expected = pl.DataFrame(wanted, schema=dict.fromkeys(wanted, pl.Boolean))

    # Only the first row carries the pair under test.
    assert_frame_equal(observed.head(1), expected)
288
289
290
def verify_total_ordering_broadcast(
    lhs: Any, rhs: Any, dummy: Any, ldtype: PolarsDataType, rdtype: PolarsDataType
) -> None:
    """Check all comparisons of column ``l`` against a single right-hand value.

    The right-hand side is supplied in two ways: as ``pl.col("r").first()``
    and as the plain Python value ``rhs``. For both, the first row of the
    result is compared with the answers from the two reference ordering
    functions.

    ``ldtype`` and ``rdtype`` are the dtypes of the left and right columns;
    ``dummy`` must not be ``None``.
    """
    ref = reference_ordering_propagating(lhs, rhs)
    refmiss = reference_ordering_missing(lhs, rhs)

    # Add dummy variable so we don't broadcast inherently.
    assert dummy is not None
    frame = pl.DataFrame(
        {"l": [lhs, dummy], "r": [rhs, dummy]}, schema={"l": ldtype, "r": rdtype}
    )

    left = pl.col("l")

    # Right-hand side reduced to its first element inside the query.
    r_first = pl.col("r").first()
    ans_first = frame.select(
        (left == r_first).alias("eq"),
        (left != r_first).alias("ne"),
        (left < r_first).alias("lt"),
        (left <= r_first).alias("le"),
        (left > r_first).alias("gt"),
        (left >= r_first).alias("ge"),
        left.eq_missing(r_first).alias("eq_missing"),
        left.ne_missing(r_first).alias("ne_missing"),
    )

    # Right-hand side given directly as a Python value.
    ans_scalar = frame.select(
        (left == rhs).alias("eq"),
        (left != rhs).alias("ne"),
        (left < rhs).alias("lt"),
        (left <= rhs).alias("le"),
        (left > rhs).alias("gt"),
        (left >= rhs).alias("ge"),
        (left.eq_missing(rhs)).alias("eq_missing"),
        (left.ne_missing(rhs)).alias("ne_missing"),
    )

    # `ref and X` evaluates to None when `ref` is None, so the expected value
    # of the propagating operators is null whenever an operand is null.
    wanted = {
        "eq": [ref and ref == "="],
        "ne": [ref and ref != "="],
        "lt": [ref and ref == "<"],
        "le": [ref and ref in ("<", "=")],
        "gt": [ref and ref == ">"],
        "ge": [ref and ref in (">", "=")],
        "eq_missing": [refmiss == "="],
        "ne_missing": [refmiss != "="],
    }
    expected = pl.DataFrame(wanted, schema=dict.fromkeys(wanted, pl.Boolean))

    # Only the first row carries the pair under test.
    assert_frame_equal(ans_first.head(1), expected)
    assert_frame_equal(ans_scalar.head(1), expected)
340
341
342
# Float operands for the total-ordering tests: both signed zeros, ordinary
# values, NaN with either sign, both infinities, and null.
INTERESTING_FLOAT_VALUES = [
    0.0,
    -0.0,
    -1.0,
    1.0,
    -math.nan,
    math.nan,
    -math.inf,
    math.inf,
    None,
]
353
354
355
@pytest.mark.slow
@pytest.mark.parametrize("lhs", INTERESTING_FLOAT_VALUES)
@pytest.mark.parametrize("rhs", INTERESTING_FLOAT_VALUES)
def test_total_ordering_float_series(lhs: float | None, rhs: float | None) -> None:
    """Check float ordering for Float32/Float32 and Float64/Float32 column pairs."""
    dtype_pairs = [(pl.Float32, pl.Float32), (pl.Float64, pl.Float32)]

    for ldtype, rdtype in dtype_pairs:
        verify_total_ordering(lhs, rhs, 0.0, ldtype, rdtype)

    # A UserWarning is expected from the broadcast checks when rhs is None.
    context: pytest.WarningsRecorder | ContextManager[None]
    if rhs is None:
        context = pytest.warns(UserWarning)
    else:
        context = nullcontext()

    with context:
        for ldtype, rdtype in dtype_pairs:
            verify_total_ordering_broadcast(lhs, rhs, 0.0, ldtype, rdtype)
367
368
369
# String operands for the total-ordering tests: the empty string, short
# strings, progressively longer runs of "o", long strings that differ only
# near the end, and null.
INTERESTING_STRING_VALUES = [
    "",
    "foo",
    "bar",
    "fooo",
    "fooooooooooo",
    "foooooooooooo",
    "fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooom",
    "foooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo",
    "fooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooop",
    None,
]
381
382
383
@pytest.mark.slow
@pytest.mark.parametrize("lhs", INTERESTING_STRING_VALUES)
@pytest.mark.parametrize("rhs", INTERESTING_STRING_VALUES)
def test_total_ordering_string_series(lhs: str | None, rhs: str | None) -> None:
    """Check string ordering for a String/String column pair."""
    dtypes = (pl.String, pl.String)

    verify_total_ordering(lhs, rhs, "", *dtypes)

    # A UserWarning is expected from the broadcast checks when rhs is None.
    context: pytest.WarningsRecorder | ContextManager[None]
    if rhs is None:
        context = pytest.warns(UserWarning)
    else:
        context = nullcontext()

    with context:
        verify_total_ordering_broadcast(lhs, rhs, "", *dtypes)
393
394
395
@pytest.mark.slow
@pytest.mark.parametrize("lhs", INTERESTING_STRING_VALUES)
@pytest.mark.parametrize("rhs", INTERESTING_STRING_VALUES)
@pytest.mark.parametrize("fresh_cat", [False, True])
def test_total_ordering_cat_series(
    lhs: str | None, rhs: str | None, fresh_cat: bool
) -> None:
    """Check ordering for Categorical/Categorical and Categorical/String pairs.

    With ``fresh_cat`` each dtype is built on ``pl.Categories.random()``;
    otherwise a default ``pl.Categorical()`` is used. Six dtype instances are
    created so that every check below gets its own.
    """

    def new_categorical() -> Any:
        if fresh_cat:
            return pl.Categorical(pl.Categories.random())
        return pl.Categorical()

    c = [new_categorical() for _ in range(6)]

    verify_total_ordering(lhs, rhs, "", c[0], c[0])
    verify_total_ordering(lhs, rhs, "", pl.String, c[1])
    verify_total_ordering(lhs, rhs, "", c[2], pl.String)

    # A UserWarning is expected from the broadcast checks when rhs is None.
    context: pytest.WarningsRecorder | ContextManager[None]
    if rhs is None:
        context = pytest.warns(UserWarning)
    else:
        context = nullcontext()

    with context:
        verify_total_ordering_broadcast(lhs, rhs, "", c[3], c[3])
        verify_total_ordering_broadcast(lhs, rhs, "", pl.String, c[4])
        verify_total_ordering_broadcast(lhs, rhs, "", c[5], pl.String)
416
417
418
@pytest.mark.slow
@pytest.mark.parametrize("str_lhs", INTERESTING_STRING_VALUES)
@pytest.mark.parametrize("str_rhs", INTERESTING_STRING_VALUES)
def test_total_ordering_binary_series(str_lhs: str | None, str_rhs: str | None) -> None:
    """Check ordering for a Binary/Binary pair, using UTF-8 encoded test strings."""
    # Nulls stay null; every other value becomes its UTF-8 byte string.
    lhs = str_lhs.encode("utf-8") if str_lhs is not None else None
    rhs = str_rhs.encode("utf-8") if str_rhs is not None else None

    verify_total_ordering(lhs, rhs, b"", pl.Binary, pl.Binary)

    # A UserWarning is expected from the broadcast checks when rhs is None.
    context: pytest.WarningsRecorder | ContextManager[None]
    if rhs is None:
        context = pytest.warns(UserWarning)
    else:
        context = nullcontext()

    with context:
        verify_total_ordering_broadcast(lhs, rhs, b"", pl.Binary, pl.Binary)
430
431
432
@pytest.mark.parametrize("lhs", [None, False, True])
@pytest.mark.parametrize("rhs", [None, False, True])
def test_total_ordering_bool_series(lhs: bool | None, rhs: bool | None) -> None:
    """Check ordering for a Boolean/Boolean pair, including nulls."""
    dtypes = (pl.Boolean, pl.Boolean)

    verify_total_ordering(lhs, rhs, False, *dtypes)

    # A UserWarning is expected from the broadcast checks when rhs is None.
    context: pytest.WarningsRecorder | ContextManager[None]
    if rhs is None:
        context = pytest.warns(UserWarning)
    else:
        context = nullcontext()

    with context:
        verify_total_ordering_broadcast(lhs, rhs, False, *dtypes)
441
442
443
def test_cat_compare_with_bool() -> None:
    """Filtering a Categorical column with ``== True`` raises ``ComputeError``."""
    col1 = pl.Series("col1", ["a", "b"], dtype=pl.Categorical)
    data = pl.DataFrame([col1])

    expected_message = "cannot compare categorical with bool"
    with pytest.raises(ComputeError, match=expected_message):
        data.filter(pl.col("col1") == True)  # noqa: E712
448
449
450
def test_schema_ne_missing_9256() -> None:
    """``ne_missing(...).or_(...)`` yields an all-true column named ``a``."""
    df = pl.DataFrame({"a": [0, 1, None], "b": [True, False, True]})

    # a != 0 (null counted as a value) is [False, True, True]; OR-ing with b
    # makes every row true.
    either = pl.col("a").ne_missing(0).or_(pl.col("b"))
    result = df.select(either)

    assert result["a"].all()
454
455
456
def test_nested_binary_literal_super_type_12227() -> None:
    """Integer literal plus (boolean times float) evaluates to the float 0.1."""
    # The `.alias` is important here to trigger the bug.
    positive = pl.col("x") > 0
    aliased = (pl.lit(0) + (positive * 0.1)).alias("x")
    result = pl.select(x=1).select(aliased)
    assert result.item() == 0.1

    # Same shape built purely from literals, with a trailing `+ 0`.
    always_true = pl.lit(0) == pl.lit(0)
    scaled = always_true * pl.lit(0.1)
    result = pl.select((pl.lit(0) + scaled) + pl.lit(0))
    assert result.item() == 0.1
463
464
465
def test_struct_broadcasting_comparison() -> None:
    """A struct column compared with its own last element is checked row by row."""
    df = pl.DataFrame({"foo": [{"a": 1}, {"a": 2}, {"a": 1}]})

    # The last struct is {"a": 1}, which matches rows 0 and 2.
    result = df.select(eq=pl.col.foo == pl.col.foo.last())

    assert result.to_dict(as_series=False) == {"eq": [True, False, True]}
470
471
472
@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 1)])
def test_compare_list_broadcast_empty_first_chunk_20165(dtype: pl.DataType) -> None:
    """Compare a two-row list/array literal with a one-row, two-chunk Series."""
    single = pl.Series([[1]], dtype=dtype)

    # Concatenate two one-row Series, then filter away the first row. The
    # checks below confirm one row remains while both chunks are kept.
    s = pl.concat([single, single]).filter([False, True])
    assert s.len() == 1
    assert s.n_chunks() == 2

    lhs = pl.lit(pl.Series([[1], [2]]), dtype=dtype)
    result = pl.select(lhs == pl.lit(s)).to_series()

    assert_series_equal(result, pl.Series([True, False]))
483
484