Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/series/test_equals.py
8430 views
1
from collections.abc import Callable
2
from datetime import datetime
3
4
import pytest
5
6
import polars as pl
7
from polars.testing import assert_series_equal
8
9
10
def test_equals() -> None:
    """Check `Series.equals` with mixed dtypes, time zones and non-Series input."""
    float_series = pl.Series("a", [1.0, 2.0, None], pl.Float64)
    int_series = pl.Series("a", [1, 2, None], pl.Int64)

    # Same values under different dtypes: equal by default, unequal once
    # dtypes are checked or nulls are no longer considered equal.
    assert float_series.equals(int_series) is True
    assert float_series.equals(int_series, check_dtypes=True) is False
    assert float_series.equals(int_series, null_equal=False) is False

    # A single UTC timestamp converted into two different time zones.
    utc_frame = pl.DataFrame(
        {"dtm": [datetime(2222, 2, 22, 22, 22, 22)]},
        schema_overrides={"dtm": pl.Datetime(time_zone="UTC")},
    )
    zoned_frame = utc_frame.with_columns(
        s3=pl.col("dtm").dt.convert_time_zone("Europe/London"),
        s4=pl.col("dtm").dt.convert_time_zone("Asia/Tokyo"),
    )
    london = zoned_frame["s3"].rename("b")
    tokyo = zoned_frame["s4"].rename("b")

    assert london.equals(tokyo) is False
    assert london.equals(tokyo, check_dtypes=True) is False
    assert london.equals(tokyo, null_equal=False) is False
    # Once both series carry the same time zone they compare equal.
    assert london.dt.convert_time_zone("Asia/Tokyo").equals(tokyo) is True

    # Passing a frame instead of a Series must be rejected with a TypeError.
    with pytest.raises(
        TypeError,
        match=r"expected `other` to be a 'Series'.* not 'DataFrame'",
    ):
        float_series.equals(pl.DataFrame(int_series), check_names=False)  # type: ignore[arg-type]

    with pytest.raises(
        TypeError,
        match=r"expected `other` to be a 'Series'.* not 'LazyFrame'",
    ):
        float_series.equals(pl.DataFrame(int_series).lazy(), check_names=False)  # type: ignore[arg-type]

    plain = pl.Series("a", [1, 2, 3])

    class DummySeriesSubclass(pl.Series):
        pass

    # A Series subclass instance is accepted as `other`.
    assert plain.equals(DummySeriesSubclass(plain)) is True
51
52
53
def test_series_equals_check_names() -> None:
    """Names are ignored by `Series.equals` unless `check_names=True` is given."""
    named_foo = pl.Series("foo", [1, 2, 3])
    named_bar = pl.Series("bar", [1, 2, 3])

    assert named_foo.equals(named_bar) is True
    assert named_foo.equals(named_bar, check_names=True) is False
58
59
60
def test_eq_list_cmp_list() -> None:
    """`==` between a series of lists and a Python list gives a per-row mask."""
    nested = pl.Series([[1], [1, 2]])
    # Only the second row matches `[1, 2]`.
    assert_series_equal(nested == [1, 2], pl.Series([False, True]))
65
66
67
def test_eq_list_cmp_int() -> None:
    """`==` between a series of lists and a scalar int raises NotImplementedError."""
    nested = pl.Series([[1], [1, 2]])
    expected_message = r"Series of type List\(Int64\) does not have eq operator"

    with pytest.raises(NotImplementedError, match=expected_message):
        nested == 1  # noqa: B015
74
75
76
def test_eq_array_cmp_list() -> None:
    """`==` between an Array series and a Python list gives a per-row mask."""
    fixed_width = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2))
    # Only the second row matches `[1, 2]`.
    assert_series_equal(fixed_width == [1, 2], pl.Series([False, True]))
81
82
83
def test_eq_array_cmp_int() -> None:
    """`==` between an Array series and a scalar int raises NotImplementedError."""
    fixed_width = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2))
    expected_message = (
        r"Series of type Array\(Int16, shape=\(2,\)\) does not have eq operator"
    )

    with pytest.raises(NotImplementedError, match=expected_message):
        fixed_width == 1  # noqa: B015
90
91
92
def test_eq_list() -> None:
    """`==` on an integer series works against both a list and a scalar."""
    ones = pl.Series([1, 1])

    # Against a list of the same length: compared position by position.
    assert_series_equal(ones == [1, 2], pl.Series([True, False]))

    # Against a scalar: every row is compared with that value.
    assert_series_equal(ones == 1, pl.Series([True, True]))
102
103
104
def test_eq_missing_expr() -> None:
    """`Series.eq_missing` given an expression returns an `Expr`, not a Series."""
    values = pl.Series([1, None])
    comparison = values.eq_missing(pl.lit(1))

    assert isinstance(comparison, pl.Expr)

    # Evaluating the expression gives a null-free mask: the null row is False.
    evaluated = pl.select(comparison).to_series()
    assert_series_equal(evaluated, pl.Series([True, False]))
112
113
114
def test_ne_missing_expr() -> None:
    """`Series.ne_missing` given an expression returns an `Expr`, not a Series."""
    values = pl.Series([1, None])
    comparison = values.ne_missing(pl.lit(1))

    assert isinstance(comparison, pl.Expr)

    # Evaluating the expression gives a null-free mask: the null row is True.
    evaluated = pl.select(comparison).to_series()
    assert_series_equal(evaluated, pl.Series([False, True]))
122
123
124
def test_series_equals_strict_deprecated() -> None:
    """Passing `strict=` to `Series.equals` emits a deprecation warning."""
    float_series = pl.Series("a", [1.0, 2.0, None], pl.Float64)
    int_series = pl.Series("a", [1, 2, None], pl.Int64)

    with pytest.deprecated_call():
        # With `strict=True` the differing dtypes make the series unequal.
        assert not float_series.equals(int_series, strict=True)  # type: ignore[call-arg]
129
130
131
@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq", "cmp_ne"),
    [
        # We parametrize the comparison sides as the impl looks like this:
        # match (left.len(), right.len()) {
        #     (1, _) => ...,
        #     (_, 1) => ...,
        #     (_, _) => ...,
        # }
        (pl.Series.eq, pl.Series.ne),
        (
            lambda a, b: pl.Series.eq(b, a),
            lambda a, b: pl.Series.ne(b, a),
        ),
    ],
)
def test_eq_lists_arrays(
    dtype: pl.DataType,
    cmp_eq: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """
    Check null-propagating `eq` / `ne` on List and Array series.

    Each scenario compares a left-hand series with the same three-row
    right-hand series, in both operand orders via the parametrized callables.
    """
    # Right-hand rows shared by every scenario.
    other_rows = [None, [1, None], [1, 1]]

    # (left-hand rows, expected `eq` result, expected `ne` result)
    scenarios = [
        # Broadcast NULL
        ([None], [None, None, None], [None, None, None]),
        # Non-broadcast full-NULL
        (3 * [None], [None, None, None], [None, None, None]),
        # Broadcast valid
        ([[1, None]], [None, True, False], [None, False, True]),
        # Non-broadcast mixed
        ([None, [1, 1], [1, 1]], [None, False, True], [None, True, False]),
    ]

    for lhs_rows, want_eq, want_ne in scenarios:
        assert_series_equal(
            cmp_eq(
                pl.Series(lhs_rows, dtype=dtype),
                pl.Series(other_rows, dtype=dtype),
            ),
            pl.Series(want_eq, dtype=pl.Boolean),
        )

        assert_series_equal(
            cmp_ne(
                pl.Series(lhs_rows, dtype=dtype),
                pl.Series(other_rows, dtype=dtype),
            ),
            pl.Series(want_ne, dtype=pl.Boolean),
        )
220
221
222
@pytest.mark.parametrize("dtype", [pl.List(pl.Int64), pl.Array(pl.Int64, 2)])
@pytest.mark.parametrize(
    ("cmp_eq_missing", "cmp_ne_missing"),
    [
        (pl.Series.eq_missing, pl.Series.ne_missing),
        (
            lambda a, b: pl.Series.eq_missing(b, a),
            lambda a, b: pl.Series.ne_missing(b, a),
        ),
    ],
)
def test_eq_missing_lists_arrays_19153(
    dtype: pl.DataType,
    cmp_eq_missing: Callable[[pl.Series, pl.Series], pl.Series],
    cmp_ne_missing: Callable[[pl.Series, pl.Series], pl.Series],
) -> None:
    """
    Check `eq_missing` / `ne_missing` on List and Array series.

    Each scenario compares a left-hand series with the same three-row
    right-hand series, in both operand orders via the parametrized callables.
    Every result is expected to be free of nulls.
    """

    def check_equal_without_nulls(actual: pl.Series, wanted: pl.Series) -> None:
        # `assert_series_equal` also uses `ne_missing` underneath so we have
        # some extra checks here to be sure.
        assert_series_equal(actual, wanted)
        assert actual.to_list() == wanted.to_list()
        assert actual.null_count() == 0
        assert wanted.null_count() == 0

    # Right-hand rows shared by every scenario.
    other_rows = [None, [1, None], [1, 1]]

    # (left-hand rows, expected `eq_missing` result, expected `ne_missing` result)
    scenarios = [
        # Broadcast NULL
        ([None], [True, False, False], [False, True, True]),
        # Non-broadcast full-NULL
        (3 * [None], [True, False, False], [False, True, True]),
        # Broadcast valid
        ([[1, None]], [False, True, False], [True, False, True]),
        # Non-broadcast mixed
        ([None, [1, 1], [1, 1]], [True, False, True], [False, True, False]),
    ]

    for lhs_rows, want_eq, want_ne in scenarios:
        check_equal_without_nulls(
            cmp_eq_missing(
                pl.Series(lhs_rows, dtype=dtype),
                pl.Series(other_rows, dtype=dtype),
            ),
            pl.Series(want_eq),
        )

        check_equal_without_nulls(
            cmp_ne_missing(
                pl.Series(lhs_rows, dtype=dtype),
                pl.Series(other_rows, dtype=dtype),
            ),
            pl.Series(want_ne),
        )
316
317
318
def test_equals_nested_null_categorical_14875() -> None:
    """
    A list-of-struct series with a null categorical field equals itself.

    NOTE(review): the numeric suffix presumably refers to issue 14875 - confirm.
    """
    nested_dtype = pl.List(pl.Struct({"cat": pl.Categorical}))
    nested = pl.Series([[{"cat": None}]], dtype=nested_dtype)

    assert nested.equals(nested)
322
323