Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/series/test_equals.py
6939 views
1
from datetime import datetime
2
from typing import Callable
3
4
import pytest
5
6
import polars as pl
7
from polars.testing import assert_series_equal
8
9
10
def test_equals() -> None:
    """Exercise ``Series.equals`` flags and its rejection of non-Series input."""
    as_float = pl.Series("a", [1.0, 2.0, None], pl.Float64)
    as_int = pl.Series("a", [1, 2, None], pl.Int64)

    # Float64 vs Int64 with matching values is equal by default; requesting a
    # dtype check, or treating the shared null as unequal, flips the result.
    assert as_float.equals(as_int) is True
    assert as_float.equals(as_int, check_dtypes=True) is False
    assert as_float.equals(as_int, null_equal=False) is False

    # One UTC instant converted into two different time zones.
    zoned = pl.DataFrame(
        {"dtm": [datetime(2222, 2, 22, 22, 22, 22)]},
        schema_overrides={"dtm": pl.Datetime(time_zone="UTC")},
    ).with_columns(
        s3=pl.col("dtm").dt.convert_time_zone("Europe/London"),
        s4=pl.col("dtm").dt.convert_time_zone("Asia/Tokyo"),
    )
    london = zoned["s3"].rename("b")
    tokyo = zoned["s4"].rename("b")

    # Differing time zones make the series unequal even without
    # `check_dtypes`; converting to a common zone makes them equal.
    assert london.equals(tokyo) is False
    assert london.equals(tokyo, check_dtypes=True) is False
    assert london.equals(tokyo, null_equal=False) is False
    assert london.dt.convert_time_zone("Asia/Tokyo").equals(tokyo) is True

    # Eager and lazy frames are rejected with a TypeError naming their type.
    with pytest.raises(
        TypeError,
        match="expected `other` to be a 'Series'.* not 'DataFrame'",
    ):
        as_float.equals(pl.DataFrame(as_int), check_names=False)  # type: ignore[arg-type]

    with pytest.raises(
        TypeError,
        match="expected `other` to be a 'Series'.* not 'LazyFrame'",
    ):
        as_float.equals(pl.DataFrame(as_int).lazy(), check_names=False)  # type: ignore[arg-type]

    # A subclass of Series is still accepted as the comparison target.
    plain = pl.Series("a", [1, 2, 3])

    class SeriesSubclass(pl.Series):
        pass

    assert plain.equals(SeriesSubclass(plain)) is True
51
52
53
def test_series_equals_check_names() -> None:
    """Names are ignored by ``equals`` unless ``check_names=True``."""
    foo = pl.Series("foo", [1, 2, 3])
    bar = pl.Series("bar", [1, 2, 3])

    assert foo.equals(bar) is True
    assert foo.equals(bar, check_names=True) is False
58
59
60
def test_eq_list_cmp_list() -> None:
    """``==`` between a List series and a Python list compares each row to it."""
    rows = pl.Series([[1], [1, 2]])
    assert_series_equal(rows == [1, 2], pl.Series([False, True]))
65
66
67
def test_eq_list_cmp_int() -> None:
    """Comparing a List series against a scalar int raises NotImplementedError."""
    rows = pl.Series([[1], [1, 2]])
    expected_msg = r"Series of type List\(Int64\) does not have eq operator"

    with pytest.raises(NotImplementedError, match=expected_msg):
        rows == 1  # noqa: B015
74
75
76
def test_eq_array_cmp_list() -> None:
    """``==`` between an Array series and a Python list compares each row to it."""
    rows = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2))
    assert_series_equal(rows == [1, 2], pl.Series([False, True]))
81
82
83
def test_eq_array_cmp_int() -> None:
    """Comparing an Array series against a scalar int raises NotImplementedError."""
    rows = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2))
    expected_msg = (
        r"Series of type Array\(Int16, shape=\(2,\)\) does not have eq operator"
    )

    with pytest.raises(NotImplementedError, match=expected_msg):
        rows == 1  # noqa: B015
90
91
92
def test_eq_list() -> None:
    """On a flat series, ``==`` with a list is element-wise; a scalar broadcasts."""
    ones = pl.Series([1, 1])

    assert_series_equal(ones == [1, 2], pl.Series([True, False]))
    assert_series_equal(ones == 1, pl.Series([True, True]))
102
103
104
def test_eq_missing_expr() -> None:
    """``Series.eq_missing`` with an expression operand yields a ``pl.Expr``."""
    expr = pl.Series([1, None]).eq_missing(pl.lit(1))
    assert isinstance(expr, pl.Expr)

    # Evaluated: 1 == 1 is True, and null vs 1 is False rather than null.
    assert_series_equal(pl.select(expr).to_series(), pl.Series([True, False]))
112
113
114
def test_ne_missing_expr() -> None:
    """``Series.ne_missing`` with an expression operand yields a ``pl.Expr``."""
    expr = pl.Series([1, None]).ne_missing(pl.lit(1))
    assert isinstance(expr, pl.Expr)

    # Evaluated: 1 != 1 is False, and null vs 1 is True rather than null.
    assert_series_equal(pl.select(expr).to_series(), pl.Series([False, True]))
122
123
124
def test_series_equals_strict_deprecated() -> None:
    """``strict=`` is deprecated; with ``strict=True`` Float64 vs Int64 is unequal."""
    as_float = pl.Series("a", [1.0, 2.0, None], pl.Float64)
    as_int = pl.Series("a", [1, 2, None], pl.Int64)

    with pytest.deprecated_call():
        assert not as_float.equals(as_int, strict=True)  # type: ignore[call-arg]
129
130
131
@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq", "cmp_ne"),
    [
        # We parametrize the comparison sides as the impl looks like this:
        # match (left.len(), right.len()) {
        #     (1, _) => ...,
        #     (_, 1) => ...,
        #     (_, _) => ...,
        # }
        (pl.Series.eq, pl.Series.ne),
        (
            lambda a, b: pl.Series.eq(b, a),
            lambda a, b: pl.Series.ne(b, a),
        ),
    ],
)
def test_eq_lists_arrays(
    dtype: pl.DataType,
    cmp_eq: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """``eq``/``ne`` on List/Array series propagate nulls from either operand."""
    right_values = [None, [1, None], [1, 1]]

    # Each entry: (left operand values, expected `eq`, expected `ne`).
    cases = [
        # Broadcast NULL
        ([None], [None, None, None], [None, None, None]),
        # Non-broadcast full-NULL
        (3 * [None], [None, None, None], [None, None, None]),
        # Broadcast valid
        ([[1, None]], [None, True, False], [None, False, True]),
        # Non-broadcast mixed
        ([None, [1, 1], [1, 1]], [None, False, True], [None, True, False]),
    ]

    for left_values, expected_eq, expected_ne in cases:
        for compare, expected in ((cmp_eq, expected_eq), (cmp_ne, expected_ne)):
            assert_series_equal(
                compare(
                    pl.Series(left_values, dtype=dtype),
                    pl.Series(right_values, dtype=dtype),
                ),
                pl.Series(expected, dtype=pl.Boolean),
            )
220
221
222
@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq_missing", "cmp_ne_missing"),
    [
        (pl.Series.eq_missing, pl.Series.ne_missing),
        # Swapped operands: the length-1 inputs below are then broadcast from
        # the right-hand side instead of the left-hand side.
        (
            lambda a, b: pl.Series.eq_missing(b, a),
            lambda a, b: pl.Series.ne_missing(b, a),
        ),
    ],
)
def test_eq_missing_lists_arrays_19153(
    dtype: pl.DataType,
    cmp_eq_missing: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne_missing: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """``eq_missing``/``ne_missing`` on List/Array series never yield nulls.

    A null operand is compared as a value: null vs null is equal, and null vs
    a non-null row is unequal.
    """

    def assert_equal_without_nulls(left: pl.Series, right: pl.Series) -> None:
        # `assert_series_equal` also uses `ne_missing` underneath (the very
        # operation under test), so additionally compare as Python lists and
        # require that no nulls appear on either side.
        assert_series_equal(left, right)
        assert left.to_list() == right.to_list()
        assert left.null_count() == 0
        assert right.null_count() == 0

    right_values = [None, [1, None], [1, 1]]

    # Each entry: (left operand values, expected eq_missing, expected ne_missing).
    cases = [
        # Broadcast NULL
        ([None], [True, False, False], [False, True, True]),
        # Non-broadcast full-NULL
        (3 * [None], [True, False, False], [False, True, True]),
        # Broadcast valid
        ([[1, None]], [False, True, False], [True, False, True]),
        # Non-broadcast mixed
        ([None, [1, 1], [1, 1]], [True, False, True], [False, True, False]),
    ]

    for left_values, expected_eq, expected_ne in cases:
        for compare, expected in (
            (cmp_eq_missing, expected_eq),
            (cmp_ne_missing, expected_ne),
        ):
            result = compare(
                pl.Series(left_values, dtype=dtype),
                pl.Series(right_values, dtype=dtype),
            )
            assert_equal_without_nulls(result, pl.Series(expected))
320
321
322
def test_equals_nested_null_categorical_14875() -> None:
    """A List(Struct) series holding a null Categorical field equals itself."""
    nested = pl.Series(
        [[{"cat": None}]],
        dtype=pl.List(pl.Struct({"cat": pl.Categorical})),
    )
    assert nested.equals(nested)
326
327