Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_index_of.py
6939 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, time, timedelta
4
from decimal import Decimal
5
from typing import TYPE_CHECKING, Any
6
7
import numpy as np
8
import pytest
9
from hypothesis import example, given
10
from hypothesis import strategies as st
11
12
import polars as pl
13
from polars.exceptions import InvalidOperationError
14
from polars.testing import assert_frame_equal
15
from polars.testing.parametric import series
16
17
if TYPE_CHECKING:
18
from polars._typing import IntoExpr
19
20
21
def isnan(value: object) -> bool:
    """Return ``True`` if ``value`` is a floating-point NaN.

    Only Python ``float`` objects and NumPy scalar numbers (``np.number``) are
    handed to ``np.isnan``; every other object -- ``None``, plain ``int``,
    strings, containers, ... -- yields ``False`` without touching NumPy.

    The result is always a built-in ``bool`` (never ``numpy.bool_``), so it
    matches the annotated return type.
    """
    # A plain ``int`` is neither ``float`` nor ``np.number``, so this single
    # guard also covers Python integers.
    if not isinstance(value, (float, np.number)):
        return False
    return bool(np.isnan(value))
27
28
29
def assert_index_of(
    series: pl.Series,
    value: IntoExpr,
    convert_to_literal: bool = False,
) -> None:
    """
    Assert that ``index_of()`` agrees with a pure-Python reference lookup.

    ``Series.index_of()`` returns the index, or ``None`` if it can't be found.
    The expected index is computed from ``series.to_list()`` and compared with
    the result of both the eager (``Series.index_of``) and the lazy
    (``pl.col(...).index_of`` on a ``LazyFrame``) API.

    Parameters
    ----------
    series
        The series to search in.
    value
        The needle to look for.
    convert_to_literal
        If true, the needle is wrapped in ``pl.lit(value, dtype=series.dtype)``
        before being passed to ``index_of()``; the expected index is still
        computed from the raw Python value.
    """
    items = series.to_list()
    expected_index: int | None
    if isnan(value):
        # NaN does not compare equal to itself, so an equality-based lookup is
        # not reliable here; find the first NaN element explicitly instead.
        expected_index = next(
            (
                i
                for i, item in enumerate(items)
                if item is not None and np.isnan(item)
            ),
            None,
        )
    else:
        # ``list.index`` signals "not found" by raising ``ValueError``.
        try:
            expected_index = items.index(value)
        except ValueError:
            expected_index = None

    if convert_to_literal:
        value = pl.lit(value, dtype=series.dtype)

    # Eager API:
    assert series.index_of(value) == expected_index
    # Lazy API:
    lazy_result = (
        pl.LazyFrame({"series": series})
        .select(pl.col("series").index_of(value))
        .collect()
        .get_column("series")
        .to_list()
    )
    assert lazy_result == [expected_index]
58
59
60
@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_float(dtype: pl.DataType) -> None:
    """Check ``index_of()`` on float series with NaN, infinities, zeros and null.

    Every needle is looked up in the series as constructed, in its ascending
    and descending sorted forms, and in a version concatenated without
    rechunking.
    """
    values = [1.5, np.nan, np.inf, 3.0, None, -np.inf, 0.0, -0.0, -np.nan]
    series = pl.Series(values, dtype=dtype)
    variants = [
        series,
        series.sort(descending=False),
        series.sort(descending=True),
        pl.concat([pl.Series([1, 7], dtype=dtype), series], rechunk=False),
    ]

    # NumPy scalars of assorted widths; these are only passed as raw values,
    # never converted to a literal of the series dtype.
    extra_values = [
        np.int8(3),
        np.int64(2**42),
        np.float64(1.5),
        np.float32(1.5),
        np.float32(2**37),
        np.float64(2**100),
    ]

    for variant in variants:
        for needle in values:
            for as_literal in (True, False):
                assert_index_of(variant, needle, convert_to_literal=as_literal)
        for extra in extra_values:
            assert_index_of(variant, extra)

    # Explicitly check some extra-tricky edge cases:
    assert series.index_of(-np.nan) == 1  # -np.nan should match np.nan
    assert series.index_of(-0.0) == 6  # -0.0 should match 0.0
86
87
88
def test_null() -> None:
    """Look up ``None`` in a two-element series of dtype ``Null``."""
    assert_index_of(pl.Series([None, None], dtype=pl.Null), None)
91
92
93
def test_empty() -> None:
    """Check ``index_of()`` on empty ``Null`` and ``Int64`` series."""
    assert_index_of(pl.Series([], dtype=pl.Null), None)

    empty_ints = pl.Series([], dtype=pl.Int64)
    assert_index_of(empty_ints, None)
    assert_index_of(empty_ints, 12)
    # Sorted (still empty) variants, descending first and then ascending.
    for descending in (True, False):
        assert_index_of(empty_ints.sort(descending=descending), 12)
101
102
103
@pytest.mark.parametrize(
    "dtype",
    [
        pl.Int8,
        pl.Int16,
        pl.Int32,
        pl.Int64,
        pl.UInt8,
        pl.UInt16,
        pl.UInt32,
        pl.UInt64,
        pl.Int128,
    ],
)
def test_integer(dtype: pl.DataType) -> None:
    """Check ``index_of()`` on integer series, including the dtype's extremes.

    Needles are the series' own values (with a null and the dtype's max/min)
    plus the values one step inside each extreme. Non-integral floats must be
    rejected with a "cannot cast lossless" error.
    """
    dtype_max = pl.select(dtype.max()).item()  # type: ignore[attr-defined]
    dtype_min = pl.select(dtype.min()).item()  # type: ignore[attr-defined]
    values = [51, 3, None, 4, dtype_max, dtype_min]
    series = pl.Series(values, dtype=dtype)
    variants = [
        series,
        series.sort(descending=False),
        series.sort(descending=True),
        pl.concat([pl.Series([100, 7], dtype=dtype), series], rechunk=False),
    ]

    extra_values = [pl.select(v).item() for v in [dtype.max() - 1, dtype.min() + 1]]  # type: ignore[attr-defined]
    for variant in variants:
        needle: IntoExpr
        # Series values first, then the extra values; each one is tried both
        # as a literal of the series dtype and as a raw Python value.
        for needle in [*values, *extra_values]:
            for as_literal in (True, False):
                assert_index_of(variant, needle, convert_to_literal=as_literal)

        # Can't cast floats:
        for non_integral in [np.float32(3.1), np.float64(3.1), 50.9]:
            with pytest.raises(InvalidOperationError, match="cannot cast lossless"):
                variant.index_of(non_integral)  # type: ignore[arg-type]
147
148
149
def test_groupby() -> None:
    """Check ``index_of()`` as a ``group_by`` aggregation, eagerly and lazily."""
    df = pl.DataFrame(
        {
            "label": ["a", "b", "a", "b", "a", "b"],
            "value": [10, 3, 20, 2, 40, 20],
        }
    )
    # Expected per-group position of the value 20, as UInt32.
    expected = pl.DataFrame(
        {"label": ["a", "b"], "value": [1, 2]},
        schema={"label": pl.String, "value": pl.UInt32},
    )

    eager_result = df.group_by("label", maintain_order=True).agg(
        pl.col("value").index_of(20)
    )
    assert_frame_equal(eager_result, expected)

    lazy_result = (
        df.lazy()
        .group_by("label", maintain_order=True)
        .agg(pl.col("value").index_of(20))
        .collect()
    )
    assert_frame_equal(lazy_result, expected)
168
169
170
# Hypothesis strategy producing lists of at most 10 items, each of which is
# either ``None`` or an integer in the inclusive range 10..50.
LISTS_STRATEGY = st.lists(
    st.none() | st.integers(min_value=10, max_value=50),
    max_size=10,
)
173
174
175
@given(
    list1=LISTS_STRATEGY,
    list2=LISTS_STRATEGY,
    list3=LISTS_STRATEGY,
)
# The examples are cases where this test previously caught bugs:
@example([], [], [None])
@pytest.mark.slow
def test_randomized(
    list1: list[int | None], list2: list[int | None], list3: list[int | None]
) -> None:
    """Check ``index_of()`` on Int8 series built from three generated lists.

    The lists are concatenated without rechunking, and the result is searched
    as-is and in both sort orders.
    """
    chunks = [pl.Series(chunk, dtype=pl.Int8) for chunk in (list1, list2, list3)]
    series = pl.concat(chunks, rechunk=False)
    variants = (
        series,
        series.sort(descending=False),
        series.sort(descending=True),
    )

    # Values are between 10 and 50, plus add None and max/min range values:
    needles = set(range(10, 51)) | {-128, 127, None}
    for needle in needles:
        for variant in variants:
            assert_index_of(variant, needle)
198
199
200
# Enum dtype with the categories "a", "b" and "c".
# NOTE(review): not referenced by any test in the visible part of this file;
# confirm it is still needed.
ENUM = pl.Enum(["a", "b", "c"])
201
202
203
@pytest.mark.parametrize(
    ("series", "extra_values", "sortable"),
    [
        (pl.Series(["abc", None, "bb"]), ["", "🚲"], True),
        (pl.Series([True, None, False, True, False]), [], True),
        (
            pl.Series([datetime(1997, 12, 31), datetime(1996, 1, 1)]),
            [datetime(2023, 12, 12, 16, 12, 39)],
            True,
        ),
        (
            pl.Series([date(1997, 12, 31), None, date(1996, 1, 1)]),
            [date(2023, 12, 12)],
            True,
        ),
        (
            pl.Series([time(16, 12, 31), None, time(11, 10, 53)]),
            [time(11, 12, 16)],
            True,
        ),
        (
            pl.Series(
                [timedelta(hours=12), None, timedelta(minutes=3)],
            ),
            [timedelta(minutes=17)],
            True,
        ),
        (pl.Series([[1, 2], None, [4, 5], [6], [None, 3, 5]]), [[5, 7], []], True),
        (
            pl.Series([[[1, 2]], None, [[4, 5]], [[6]], [[None, 3, 5]], [None]]),
            [[[5, 7]], []],
            True,
        ),
        (
            pl.Series([[1, 2], None, [4, 5], [None, 3]], dtype=pl.Array(pl.Int64(), 2)),
            [[5, 7], [None, None]],
            True,
        ),
        (
            pl.Series(
                [[[1, 2]], [None], [[4, 5]], None, [[None, 3]]],
                dtype=pl.Array(pl.Array(pl.Int64(), 2), 1),
            ),
            [[[5, 7]], [[None, None]]],
            True,
        ),
        (
            pl.Series(
                [{"a": 1, "b": 2}, None, {"a": 3, "b": 4}, {"a": None, "b": 2}],
                dtype=pl.Struct({"a": pl.Int64(), "b": pl.Int64()}),
            ),
            [{"a": 7, "b": None}, {"a": 6, "b": 4}],
            False,
        ),
        (pl.Series([b"abc", None, b"xxx"]), [b"\x0025"], True),
        (pl.Series([Decimal(12), None, Decimal(3)]), [Decimal(4)], True),
    ],
)
def test_other_types(
    series: pl.Series, extra_values: list[Any], sortable: bool
) -> None:
    """Check ``index_of()`` for non-numeric and nested dtypes.

    ``series`` supplies the needles expected to be searchable both as raw
    values and as literals; ``extra_values`` are additional needles passed only
    as raw values; ``sortable`` controls whether sorted variants of the series
    are searched as well.
    """
    present_values = series.to_list()
    variants = [series, series.drop_nulls()]
    if sortable:
        variants += [series.sort(descending=False), series.sort(descending=True)]

    for variant in variants:
        for needle in present_values:
            for as_literal in (True, False):
                assert_index_of(variant, needle, convert_to_literal=as_literal)
        # Extra values may not be expressible as literal of correct dtype, so
        # don't try:
        for extra in extra_values:
            assert_index_of(variant, extra)
281
282
283
# Previously the output type would be list[idx-type] when no item was found.
def test_non_found_correct_type() -> None:
    """Check the aggregated ``index_of()`` output when a group has no match.

    Group ``a == 0`` contains the needle at position 0; group ``a == 1`` does
    not contain it, so its result is null. Dtypes are not compared.
    """
    df = pl.DataFrame(
        [
            pl.Series("a", [0, 1], pl.Int32),
            pl.Series("b", [1, 2], pl.Int32),
        ]
    )

    result = df.group_by("a", maintain_order=True).agg(pl.col.b.index_of(1))
    expected = pl.DataFrame({"a": [0, 1], "b": [0, None]})
    assert_frame_equal(result, expected, check_dtypes=False)
297
298
299
def test_error_on_multiple_values() -> None:
    """Passing a two-element series as the needle must raise."""
    haystack = pl.Series("a", [1, 2, 3])
    needles = pl.Series([2, 3])
    with pytest.raises(
        InvalidOperationError,
        match="needle of `index_of` can only contain",
    ):
        haystack.index_of(needles)
305
306
307
@pytest.mark.parametrize("convert_to_literal", [True, False])
def test_enum(convert_to_literal: bool) -> None:
    """Check ``index_of()`` on an Enum series, with and without literal needles.

    The Enum categories are declared in the order ``c, b, a``, which differs
    from the order in which the values appear in the series.
    """
    series = pl.Series(["a", "c", None, "b"], dtype=pl.Enum(["c", "b", "a"]))
    needles = series.to_list()
    variants = (
        series,
        series.drop_nulls(),
        series.sort(descending=False),
        series.sort(descending=True),
    )
    for variant in variants:
        for needle in needles:
            assert_index_of(variant, needle, convert_to_literal=convert_to_literal)
325
326
327
@pytest.mark.parametrize("convert_to_literal", [True, False])
def test_categorical(convert_to_literal: bool) -> None:
    """Check ``index_of()`` on a Categorical series, with and without literals."""
    series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)
    needles = series.to_list()

    # The series as constructed, without nulls, and in both sort orders.
    variants = [series, series.drop_nulls()]
    variants.extend(series.sort(descending=flag) for flag in (False, True))

    for variant in variants:
        for needle in needles:
            assert_index_of(variant, needle, convert_to_literal=convert_to_literal)
342
343
344
@pytest.mark.parametrize("value", [0, 0.1])
def test_categorical_wrong_type_keys_dont_work(value: int | float) -> None:
    """Numeric needles on a Categorical series must raise a cast error.

    Both the eager ``Series.index_of`` call and the expression evaluated via
    ``DataFrame.select`` are expected to fail with "cannot cast lossless".
    """
    categorical = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)
    frame = pl.DataFrame({"s": categorical})

    with pytest.raises(InvalidOperationError, match="cannot cast lossless"):
        categorical.index_of(value)

    with pytest.raises(InvalidOperationError, match="cannot cast lossless"):
        frame.select(pl.col("s").index_of(value))
353
354
355
@given(s=series(name="s", allow_chunks=True, max_size=10))
def test_index_of_null_parametric(s: pl.Series) -> None:
    """Check ``index_of(None)`` on generated series.

    An empty series or one without nulls must give ``None``; an all-null series
    must give index 0. Series that are only partially null are not checked.
    """
    idx_null = s.index_of(None)
    null_count = s.null_count()
    if s.len() == 0 or null_count == 0:
        assert idx_null is None
    elif null_count == len(s):
        assert idx_null == 0
364
365