Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_index_of.py
8433 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, time, timedelta
4
from decimal import Decimal
5
from typing import TYPE_CHECKING, Any
6
7
import numpy as np
8
import pytest
9
from hypothesis import example, given
10
from hypothesis import strategies as st
11
12
import polars as pl
13
from polars.exceptions import InvalidOperationError
14
from polars.testing import assert_frame_equal
15
from polars.testing.parametric import series
16
17
if TYPE_CHECKING:
18
from polars._typing import IntoExpr, PolarsDataType
19
from polars.datatypes import IntegerType
20
21
22
def isnan(value: object) -> bool:
    """Return ``True`` only when ``value`` is a floating-point NaN.

    Python ``int`` values, and anything that is neither a ``float`` nor a
    NumPy scalar number (``None``, strings, lists, ...), are reported as
    not-NaN without ever being passed to ``np.isnan``.
    """
    # ``bool`` is a subclass of ``int``, so True/False are rejected here too.
    if isinstance(value, int) or not isinstance(value, (np.number, float)):
        return False
    # np.isnan() yields np.bool_; coerce so the annotated ``bool`` is honoured.
    return bool(np.isnan(value))
28
29
30
def assert_index_of(
    series: pl.Series,
    value: IntoExpr,
    convert_to_literal: bool = False,
) -> None:
    """Check ``index_of()`` against a pure-Python reference, eagerly and lazily.

    The expected result is the position of the first element of
    ``series.to_list()`` that matches ``value`` (for a NaN needle: the first
    non-null element that is itself NaN), or ``None`` when nothing matches.

    With ``convert_to_literal=True`` the needle is wrapped in
    ``pl.lit(value, dtype=series.dtype)`` before it is searched for.
    """
    items = series.to_list()
    expected_index: int | None
    if isnan(value):
        # NaN compares unequal to itself, so list.index() is not reliable for
        # it; scan for the first NaN element by hand instead.
        expected_index = next(
            (
                i
                for i, item in enumerate(items)
                if item is not None and np.isnan(item)
            ),
            None,
        )
    else:
        try:
            expected_index = items.index(value)
        except ValueError:
            # list.index() reports "not found" by raising, never by returning
            # a sentinel such as -1.
            expected_index = None

    if convert_to_literal:
        value = pl.lit(value, dtype=series.dtype)

    # Eager API:
    assert series.index_of(value) == expected_index
    # Lazy API:
    lazy_result = (
        pl.LazyFrame({"series": series})
        .select(pl.col("series").index_of(value))
        .collect()
        .get_column("series")
        .to_list()
    )
    assert lazy_result == [expected_index]
59
60
61
@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_float(dtype: pl.DataType) -> None:
    """Look up float needles, including NaN, infinities, signed zeros and null.

    Each needle is checked against the plain Series, both sort orders and a
    multi-chunk Series, via ``assert_index_of``.
    """
    values = [1.5, np.nan, np.inf, 3.0, None, -np.inf, 0.0, -0.0, -np.nan]
    if dtype == pl.Float32:
        # Can't pass Python literals to index_of() for Float32
        values = [v if v is None else np.float32(v) for v in values]  # type: ignore[misc]

    series = pl.Series(values, dtype=dtype)
    leading_chunk = pl.Series([1, 7], dtype=dtype)
    variants = [
        series,
        series.sort(descending=False),
        series.sort(descending=True),
        # Not rechunked, so the search has to cross a chunk boundary.
        pl.concat([leading_chunk, series], rechunk=False),
    ]

    extra_values = [np.int8(3), np.float32(1.5), np.float32(2**10)]
    if dtype == pl.Float64:
        extra_values += [np.int32(2**10), np.float64(2**10), np.float64(1.5)]

    for s in variants:
        for value in values:
            for as_literal in (True, False):
                assert_index_of(s, value, convert_to_literal=as_literal)
        for value in extra_values:  # type: ignore[assignment]
            assert_index_of(s, value)

    # -np.nan should match np.nan:
    negative_nan = -np.float32("nan")
    assert series.index_of(negative_nan) == 1  # type: ignore[arg-type]
    # -0.0 should match 0.0:
    negative_zero = -np.float32(0.0)
    assert series.index_of(negative_zero) == 6  # type: ignore[arg-type]
91
92
93
def test_null() -> None:
    """Search for ``None`` in a ``Null``-typed Series holding only nulls."""
    all_null = pl.Series([None, None], dtype=pl.Null)
    assert_index_of(all_null, None)
96
97
98
def test_empty() -> None:
    """Searching empty Series (``Null`` and ``Int64``), sorted or not."""
    assert_index_of(pl.Series([], dtype=pl.Null), None)

    empty_ints = pl.Series([], dtype=pl.Int64)
    assert_index_of(empty_ints, None)
    assert_index_of(empty_ints, 12)
    for descending in (True, False):
        assert_index_of(empty_ints.sort(descending=descending), 12)
106
107
108
@pytest.mark.parametrize(
    "dtype",
    [
        pl.Int8,
        pl.Int16,
        pl.Int32,
        pl.Int64,
        pl.Int128,
        pl.UInt8,
        pl.UInt16,
        pl.UInt32,
        pl.UInt64,
        pl.UInt128,
    ],
)
def test_integer(dtype: IntegerType) -> None:
    """Look up integers, including the dtype's extremes, for every int dtype.

    Each needle is checked against the plain Series, both sort orders and a
    multi-chunk Series.  Float needles with a fractional part must be
    rejected with ``InvalidOperationError``.
    """
    dtype_min = dtype.min()
    # For UInt128 the Int128 maximum is used instead of the dtype's own
    # maximum. NOTE(review): presumably a workaround for a limit elsewhere;
    # confirm before changing.
    dtype_max = pl.Int128.max() if dtype == pl.UInt128 else dtype.max()

    values = [
        51,
        3,
        None,
        4,
        pl.select(dtype_max).item(),
        pl.select(dtype_min).item(),
    ]
    series = pl.Series(values, dtype=dtype)
    sorted_series_asc = series.sort(descending=False)
    sorted_series_desc = series.sort(descending=True)
    chunked_series = pl.concat(
        [pl.Series([100, 7], dtype=dtype), series], rechunk=False
    )

    # In-range values that are not stored in the Series.
    extra_values = [pl.select(v).item() for v in [dtype_max - 1, dtype_min + 1]]
    for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]:
        value: IntoExpr
        for value in values:
            assert_index_of(s, value, convert_to_literal=True)
            assert_index_of(s, value, convert_to_literal=False)
        for value in extra_values:
            assert_index_of(s, value, convert_to_literal=True)
            assert_index_of(s, value, convert_to_literal=False)

        # Can't cast floats:
        for f in [np.float32(3.1), np.float64(3.1), 50.9]:
            with pytest.raises(InvalidOperationError, match=r"cannot cast.*"):
                s.index_of(f)  # type: ignore[arg-type]
157
158
159
def test_integer_upcast() -> None:
    """Needles typed as narrower integer dtypes are found in an Int64 Series."""
    series = pl.Series([0, 123, 456, 789], dtype=pl.Int64)
    narrower_dtypes = [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16, pl.Int32, pl.UInt32]
    for needle_dtype in narrower_dtypes:
        needle = pl.lit(123, dtype=needle_dtype)
        assert series.index_of(needle) == 1
163
164
165
def test_groupby() -> None:
    """``index_of`` inside ``group_by().agg()`` yields one position per group."""
    df = pl.DataFrame(
        {
            "label": ["a", "b", "a", "b", "a", "b"],
            "value": [10, 3, 20, 2, 40, 20],
        }
    )
    expected = pl.DataFrame(
        {"label": ["a", "b"], "value": [1, 2]},
        schema={"label": pl.String, "value": pl.get_index_type()},
    )

    # Eager API:
    eager_grouped = df.group_by("label", maintain_order=True)
    eager_result = eager_grouped.agg(pl.col("value").index_of(20))
    assert_frame_equal(eager_result, expected)

    # Lazy API:
    lazy_result = (
        df.lazy()
        .group_by("label", maintain_order=True)
        .agg(pl.col("value").index_of(20))
        .collect()
    )
    assert_frame_equal(lazy_result, expected)
184
185
186
# Lists of at most ten elements, each either None or an integer in [10, 50].
LISTS_STRATEGY = st.lists(
    elements=st.one_of(st.none(), st.integers(min_value=10, max_value=50)),
    max_size=10,
)
189
190
191
@given(list1=LISTS_STRATEGY, list2=LISTS_STRATEGY, list3=LISTS_STRATEGY)
# The examples are cases where this test previously caught bugs:
@example([], [], [None])
@pytest.mark.slow
def test_randomized(
    list1: list[int | None], list2: list[int | None], list3: list[int | None]
) -> None:
    """Property test over an Int8 Series concatenated from three lists.

    The pieces are concatenated without rechunking, and every candidate
    needle is checked against the result and both of its sorted forms.
    """
    pieces = [pl.Series(part, dtype=pl.Int8) for part in (list1, list2, list3)]
    series = pl.concat(pieces, rechunk=False)
    variants = (
        series,
        series.sort(descending=False),
        series.sort(descending=True),
    )

    # Values are between 10 and 50, plus add None and max/min range values:
    needles = set(range(10, 51)) | {-128, 127, None}
    for needle in needles:
        for variant in variants:
            assert_index_of(variant, needle)
214
215
216
# Module-level Enum dtype with categories "a", "b", "c".
# NOTE(review): not referenced by any test visible in this file (test_enum
# builds its own Enum) — confirm whether it is still needed.
ENUM = pl.Enum(["a", "b", "c"])
217
218
219
@pytest.mark.parametrize(
    ("series", "extra_values", "sortable"),
    [
        (pl.Series(["abc", None, "bb"]), ["", "🚲"], True),
        (pl.Series([True, None, False, True, False]), [], True),
        (
            pl.Series([datetime(1997, 12, 31), datetime(1996, 1, 1)]),
            [datetime(2023, 12, 12, 16, 12, 39)],
            True,
        ),
        (
            pl.Series([date(1997, 12, 31), None, date(1996, 1, 1)]),
            [date(2023, 12, 12)],
            True,
        ),
        (
            pl.Series([time(16, 12, 31), None, time(11, 10, 53)]),
            [time(11, 12, 16)],
            True,
        ),
        (
            pl.Series(
                [timedelta(hours=12), None, timedelta(minutes=3)],
            ),
            [timedelta(minutes=17)],
            True,
        ),
        (pl.Series([[1, 2], None, [4, 5], [6], [None, 3, 5]]), [[5, 7], []], True),
        (
            pl.Series([[[1, 2]], None, [[4, 5]], [[6]], [[None, 3, 5]], [None]]),
            [[[5, 7]], []],
            True,
        ),
        (
            pl.Series([[1, 2], None, [4, 5], [None, 3]], dtype=pl.Array(pl.Int64(), 2)),
            [[5, 7], [None, None]],
            True,
        ),
        (
            pl.Series(
                [[[1, 2]], [None], [[4, 5]], None, [[None, 3]]],
                dtype=pl.Array(pl.Array(pl.Int64(), 2), 1),
            ),
            [[[5, 7]], [[None, None]]],
            True,
        ),
        (
            pl.Series(
                [{"a": 1, "b": 2}, None, {"a": 3, "b": 4}, {"a": None, "b": 2}],
                dtype=pl.Struct({"a": pl.Int64(), "b": pl.Int64()}),
            ),
            [{"a": 7, "b": None}, {"a": 6, "b": 4}],
            False,
        ),
        (pl.Series([b"abc", None, b"xxx"]), [b"\x0025"], True),
        (
            pl.Series(
                [Decimal(12), None, Decimal(3), Decimal(-12), Decimal(1) / Decimal(10)],
                dtype=pl.Decimal(20, 4),
            ),
            [Decimal(4), Decimal(-2), Decimal(1) / Decimal(4), Decimal(1) / Decimal(8)],
            True,
        ),
    ],
)
def test_other_types(
    series: pl.Series, extra_values: list[Any], sortable: bool
) -> None:
    """Exercise ``index_of`` on non-numeric, temporal, nested and decimal data.

    Every element of ``series`` is used as a needle (both raw and as a typed
    literal) against the Series itself and its null-free form; when
    ``sortable`` is true the two sorted forms are searched as well.
    ``extra_values`` are additional needles that are only passed raw.
    """
    needles = series.to_list()
    variants = [series, series.drop_nulls()]
    if sortable:
        variants += [series.sort(descending=False), series.sort(descending=True)]

    for variant in variants:
        for needle in needles:
            for as_literal in (True, False):
                assert_index_of(variant, needle, convert_to_literal=as_literal)
        # Extra values may not be expressible as literal of correct dtype, so
        # don't try:
        for needle in extra_values:
            assert_index_of(variant, needle)
304
305
306
# Before the output type would be list[idx-type] when no item was found
def test_non_found_correct_type() -> None:
    """A group without a match yields null in the aggregated column."""
    df = pl.DataFrame(
        [
            pl.Series("a", [0, 1], pl.Int32),
            pl.Series("b", [1, 2], pl.Int32),
        ]
    )

    result = df.group_by("a", maintain_order=True).agg(pl.col.b.index_of(1))
    expected = pl.DataFrame({"a": [0, 1], "b": [0, None]})

    # Only values are compared; dtypes are deliberately not checked.
    assert_frame_equal(result, expected, check_dtypes=False)
320
321
322
def test_error_on_multiple_values() -> None:
    """A needle Series holding more than one value is rejected."""
    haystack = pl.Series("a", [1, 2, 3])
    multi_value_needle = pl.Series([2, 3])
    with pytest.raises(
        InvalidOperationError, match="needle of `index_of` can only contain"
    ):
        haystack.index_of(multi_value_needle)
328
329
330
@pytest.mark.parametrize("convert_to_literal", [True, False])
def test_enum(convert_to_literal: bool) -> None:
    """Look up every element of an Enum Series, null included.

    The Enum's category order (``c, b, a``) differs from the order in which
    the values appear in the data.
    """
    series = pl.Series(["a", "c", None, "b"], dtype=pl.Enum(["c", "b", "a"]))
    needles = series.to_list()
    variants = [
        series,
        series.drop_nulls(),
        series.sort(descending=False),
        series.sort(descending=True),
    ]
    for variant in variants:
        for needle in needles:
            assert_index_of(variant, needle, convert_to_literal=convert_to_literal)
348
349
350
@pytest.mark.parametrize("convert_to_literal", [True, False])
def test_categorical(convert_to_literal: bool) -> None:
    """Look up every element of a Categorical Series, null included."""
    series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)
    needles = series.to_list()
    variants = [
        series,
        series.drop_nulls(),
        series.sort(descending=False),
        series.sort(descending=True),
    ]
    for variant in variants:
        for needle in needles:
            assert_index_of(variant, needle, convert_to_literal=convert_to_literal)
365
366
367
@pytest.mark.parametrize("value", [0, 0.1])
def test_categorical_wrong_type_keys_dont_work(value: int | float) -> None:
    """Numeric needles raise for a Categorical Series, eagerly and in a frame."""
    series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)
    msg = "cannot cast.*losslessly.*"

    # Series API:
    with pytest.raises(InvalidOperationError, match=msg):
        series.index_of(value)

    # Expression API:
    frame = pl.DataFrame({"s": series})
    needle_expr = pl.col("s").index_of(value)
    with pytest.raises(InvalidOperationError, match=msg):
        frame.select(needle_expr)
376
377
378
@given(s=series(name="s", allow_chunks=True, max_size=10))
def test_index_of_null_parametric(s: pl.Series) -> None:
    """Check ``index_of(None)`` on generated Series for the clear-cut cases.

    Only two situations are asserted: no nulls at all (which includes the
    empty Series) and all-null.  Series with a mix of nulls and values are
    not checked here.
    """
    idx_null = s.index_of(None)
    null_count = s.null_count()
    if null_count == 0:
        # Covers the empty Series too, whose null count is necessarily zero.
        assert idx_null is None
    elif null_count == len(s):
        assert idx_null == 0
387
388
389
def test_out_of_range_integers() -> None:
    """Integer needles outside the Int8 range raise instead of returning null."""
    series = pl.Series([0, 100, None, 1, 2], dtype=pl.Int8)
    # No ``assert`` around the call: only the raised exception matters, and a
    # bare call keeps pytest's "DID NOT RAISE" failure message intact.
    with pytest.raises(InvalidOperationError, match="cannot cast 128 losslessly to i8"):
        series.index_of(128)
    with pytest.raises(
        InvalidOperationError, match="cannot cast -200 losslessly to i8"
    ):
        series.index_of(-200)
397
398
399
def test_out_of_range_decimal() -> None:
    """Integer needles are accepted by ``Decimal(36, 2)`` only up to 34 digits."""
    # Up to 34 digits of integers:
    series = pl.Series([1, None], dtype=pl.Decimal(36, 2))
    assert series.index_of(10**34 - 1) is None
    assert series.index_of(-(10**34 - 1)) is None

    out_of_range = 10**34
    # No ``assert`` around the calls: only the raised exception matters, and a
    # bare call keeps pytest's "DID NOT RAISE" failure message intact.
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {out_of_range} losslessly"
    ):
        series.index_of(out_of_range)
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {-out_of_range} losslessly"
    ):
        series.index_of(-out_of_range)
413
414
415
def test_out_of_range_float64() -> None:
    """Integer needles of magnitude 2**53 are rejected for a Float64 Series."""
    series = pl.Series([0, 255, None], dtype=pl.Float64)
    # Small numbers are fine:
    assert series.index_of(1_000_000) is None
    assert series.index_of(-1_000_000) is None

    # No ``assert`` around the calls: only the raised exception matters, and a
    # bare call keeps pytest's "DID NOT RAISE" failure message intact.
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {2**53} losslessly to f64"
    ):
        series.index_of(2**53)
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {-(2**53)} losslessly to f64"
    ):
        series.index_of(-(2**53))
428
429
430
def test_out_of_range_float32() -> None:
    """Integer needles of magnitude 2**24 are rejected for a Float32 Series."""
    series = pl.Series([0, 255, None], dtype=pl.Float32)
    # Small numbers are fine:
    assert series.index_of(1_000_000) is None
    assert series.index_of(-1_000_000) is None

    # No ``assert`` around the calls: only the raised exception matters, and a
    # bare call keeps pytest's "DID NOT RAISE" failure message intact.
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {2**24} losslessly to f32"
    ):
        series.index_of(2**24)
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {-(2**24)} losslessly to f32"
    ):
        series.index_of(-(2**24))
443
444
445
def assert_lossy_cast_rejected(
    series_dtype: PolarsDataType, value: Any, value_dtype: PolarsDataType
) -> None:
    """Assert that a ``value_dtype`` needle is refused by a ``series_dtype`` Series.

    The needle is ``pl.lit(value, dtype=value_dtype)``; looking it up must
    raise ``InvalidOperationError`` matching "cannot cast losslessly".
    """
    # The Series holds a single null on purpose: previously such casts would
    # sometimes get turned into nulls, and the lookup would then return an
    # answer instead of failing.
    haystack = pl.Series([None], dtype=series_dtype)
    needle = pl.lit(value, dtype=value_dtype)
    with pytest.raises(InvalidOperationError, match="cannot cast losslessly"):
        haystack.index_of(needle)
453
454
455
@pytest.mark.parametrize(
    ("series_dtype", "value", "value_dtype"),
    [
        # Completely incompatible:
        (pl.String, 1, pl.UInt8),
        (pl.UInt8, "1", pl.String),
        # Larger integer doesn't fit in smaller integer:
        (pl.UInt8, 17, pl.UInt16),
        # Can't find negative numbers in unsigned integers:
        (pl.UInt16, -1, pl.Int8),
        # Values after the decimal point that can't be represented:
        (pl.Decimal(3, 1), 1, pl.Decimal(4, 2)),
        # Can't fit in Decimal:
        (pl.Decimal(3, 0), 1, pl.Decimal(4, 0)),
        (pl.Decimal(5, 2), 1, pl.Decimal(5, 1)),
        (pl.Decimal(5, 2), 1, pl.UInt16),
        # Can't fit nanoseconds in milliseconds:
        (pl.Duration("ms"), 1, pl.Duration("ns")),
        # Arrays that are the wrong length:
        (pl.Array(pl.Int64, 2), [1], pl.Array(pl.Int64, 1)),
        # Struct with wrong number of fields:
        (
            pl.Struct({"a": pl.Int64, "b": pl.Int64}),
            {"a": 1},
            pl.Struct({"a": pl.Int64}),
        ),
        # Struct with different field name:
        (pl.Struct({"a": pl.Int64}), {"x": 1}, pl.Struct({"x": pl.Int64})),
    ],
)
def test_lossy_casts_are_rejected(
    series_dtype: PolarsDataType, value: Any, value_dtype: PolarsDataType
) -> None:
    """Each needle/Series dtype pairing listed above must be refused."""
    assert_lossy_cast_rejected(
        series_dtype=series_dtype, value=value, value_dtype=value_dtype
    )
489
490
491
def test_lossy_casts_are_rejected_nested_dtypes() -> None:
    """A UInt16 needle inside a List, Array or Struct is refused for UInt8."""
    # Make sure casting rules are applied recursively for Lists, Arrays,
    # Struct:
    inner_series_dtype, inner_value, inner_value_dtype = pl.UInt8, 17, pl.UInt16
    nested_cases = [
        (
            pl.List(inner_series_dtype),
            [inner_value],
            pl.List(inner_value_dtype),
        ),
        (
            pl.Array(inner_series_dtype, 1),
            [inner_value],
            pl.Array(inner_value_dtype, 1),
        ),
        (
            pl.Struct({"key": inner_series_dtype}),
            {"key": inner_value},
            pl.Struct({"key": inner_value_dtype}),
        ),
    ]
    for series_dtype, value, value_dtype in nested_cases:
        assert_lossy_cast_rejected(series_dtype, value, value_dtype)
504
505
506
def test_decimal_search_for_int() -> None:
    """Integer needles (Python and NumPy) are found in a ``Decimal(4, 1)`` Series."""
    values = [Decimal(-12), Decimal(12), Decimal(30)]
    series = pl.Series(values, dtype=pl.Decimal(4, 1))
    for position, stored in enumerate(values):
        # The same number as a Decimal, a Python int and a NumPy int8.
        needles = (stored, int(stored), np.int8(stored))
        for needle in needles:
            assert series.index_of(needle) == position  # type: ignore[arg-type]

    # Decimal's integer range is 3 digits (3 == 4 - 1), so int8 fits:
    for absent in (np.int8(127), np.int8(-128)):
        assert series.index_of(absent) is None  # type: ignore[arg-type]
516
517