Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/series/test_getitem.py
6939 views
1
from __future__ import annotations
2
3
from typing import Any
4
5
import hypothesis.strategies as st
6
import numpy as np
7
import pytest
8
from hypothesis import given
9
10
import polars as pl
11
from polars.testing import assert_series_equal
12
from polars.testing.parametric import series
13
14
15
@given(
    srs=series(max_size=10, dtype=pl.Int64),
    start=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]),
    stop=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]),
    step=st.sampled_from([-5, -4, -3, -2, -1, None, 1, 2, 3, 4, 5]),
)
def test_series_getitem(
    srs: pl.Series,
    start: int | None,
    stop: int | None,
    step: int | None,
) -> None:
    """Slicing a Series with a Python ``slice`` must agree with slicing a list.

    ``step`` never samples ``0``: ``slice(..., 0)`` is rejected by Python lists,
    so there would be no reference result to compare against.
    """
    # Snapshot taken before slicing, used below to detect in-place mutation.
    original = srs.clone()
    py_data = srs.to_list()

    s = slice(start, stop, step)
    sliced_py_data = py_data[s]
    sliced_pl = srs[s]
    sliced_pl_data = sliced_pl.to_list()

    assert sliced_py_data == sliced_pl_data, f"slice [{start}:{stop}:{step}] failed"
    # Slicing must not change the element type.
    assert sliced_pl.dtype == srs.dtype
    # `__getitem__` must leave the source Series untouched. Comparing `srs` with
    # itself would be vacuous, so compare it with the pre-slice snapshot instead.
    assert_series_equal(srs, original, check_exact=True)
35
36
37
@pytest.mark.parametrize(
    ("rng", "expected_values"),
    [
        # positions 0, 1
        (range(2), [1, 2]),
        # positions 1, 2, 3
        (range(1, 4), [2, 3, 4]),
        # descending with step -2: positions 3, 1
        (range(3, 0, -2), [4, 2]),
    ],
)
def test_series_getitem_range(rng: range, expected_values: list[int]) -> None:
    """Indexing with a ``range`` gathers the rows at the positions it yields."""
    assert_series_equal(pl.Series([1, 2, 3, 4])[rng], pl.Series(expected_values))
50
51
52
@pytest.mark.parametrize(
    "mask",
    [
        [True, False, True],
        pl.Series([True, False, True]),
        np.array([True, False, True]),
    ],
)
def test_series_getitem_boolean_mask(mask: Any) -> None:
    """Boolean masks are rejected by ``Series.__getitem__`` with a ``TypeError``.

    This applies whether the mask is a list, a Series or a NumPy array.
    """
    s = pl.Series([1, 2, 3])
    with pytest.raises(
        TypeError,
        match="selecting rows by passing a boolean mask to `__getitem__` is not supported",
    ):
        s[mask]
68
69
70
@pytest.mark.parametrize(
    "input",
    [
        [],
        (),
        pl.Series(dtype=pl.Int64),
        np.array([], dtype=np.uint32),
    ],
)
def test_series_getitem_empty_inputs(input: Any) -> None:
    """Indexing with an empty collection of indices yields an empty Series.

    The empty result keeps the source's name (``"a"``) and dtype (``String``).
    """
    source = pl.Series("a", ["x", "y", "z"], dtype=pl.String)
    empty = pl.Series("a", dtype=pl.String)
    assert_series_equal(source[input], empty)
78
79
80
@pytest.mark.parametrize(
    "indices",
    [
        [0, 2],
        pl.Series([0, 2]),
        np.array([0, 2]),
    ],
)
def test_series_getitem_multiple_indices(indices: Any) -> None:
    """A list, Series or NumPy array of positions gathers those rows in order."""
    letters = pl.Series(["x", "y", "z"])
    assert_series_equal(letters[indices], pl.Series(["x", "z"]))
86
87
88
def test_series_getitem_numpy() -> None:
    """NumPy integer arrays index a Series by position."""
    s = pl.Series([9, 8, 7])

    cases = [
        (np.array([0, 2]), [9, 7]),
        # negative positions count from the end
        (np.array([-1, -3]), [7, 9]),
        # a 0-d array selects a single row, yielding a one-element Series
        (np.array(-2), [8]),
    ]
    for key, expected in cases:
        assert s[key].to_list() == expected
94
95
96
@pytest.mark.parametrize(
    ("input", "match"),
    [
        # Sequences whose elements are not integers.
        (
            [0.0, 1.0],
            "cannot select elements using Sequence with elements of type 'float'",
        ),
        (
            "foobar",
            "cannot select elements using Sequence with elements of type 'str'",
        ),
        # Series / NumPy arrays whose dtype cannot serve as indices.
        (
            pl.Series([[1, 2], [3, 4]]),
            r"cannot treat Series of type List\(Int64\) as indices",
        ),
        (np.array([0.0, 1.0]), "cannot treat NumPy array of type float64 as indices"),
        # Objects that are not index-like at all.
        (object(), "cannot select elements using key of type 'object'"),
    ],
)
def test_series_getitem_col_invalid_inputs(input: Any, match: str) -> None:
    """Unsupported ``__getitem__`` keys raise ``TypeError`` with a descriptive message.

    ``match`` is interpreted as a regular expression, which is why the
    parentheses in the ``List(Int64)`` pattern are escaped.
    """
    target = pl.Series([1, 2, 3])
    with pytest.raises(TypeError, match=match):
        target[input]
119
120