# Source: blob/main/py-polars/tests/unit/series/test_getitem.py
# (captured from a web view; 6939 views)
"""Tests for indexing a ``pl.Series`` via ``__getitem__``.

Covers slices, ``range`` objects, integer index sequences (list, Series,
NumPy array), empty index containers, and rejected key types.
"""

from __future__ import annotations

from typing import Any

import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis import given

import polars as pl
from polars.testing import assert_series_equal
from polars.testing.parametric import series


@given(
    srs=series(max_size=10, dtype=pl.Int64),
    start=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]),
    stop=st.sampled_from([-5, -4, -3, -2, -1, None, 0, 1, 2, 3, 4, 5]),
    # 0 is omitted from the step choices: a zero step is invalid for Python
    # list slicing (ValueError), so there is no reference result to compare to.
    step=st.sampled_from([-5, -4, -3, -2, -1, None, 1, 2, 3, 4, 5]),
)
def test_series_getitem(
    srs: pl.Series,
    start: int | None,
    stop: int | None,
    step: int | None,
) -> None:
    """Slicing a Series must match slicing the equivalent Python list.

    Also checks that taking the slice leaves the source Series unchanged.
    """
    py_data = srs.to_list()
    # Snapshot taken before slicing; comparing ``srs`` against itself (as a
    # plain self-comparison would) can never fail and so verifies nothing.
    original = srs.clone()

    s = slice(start, stop, step)
    sliced_py_data = py_data[s]
    sliced_pl_data = srs[s].to_list()

    assert sliced_py_data == sliced_pl_data, f"slice [{start}:{stop}:{step}] failed"
    assert_series_equal(srs, original, check_exact=True)


@pytest.mark.parametrize(
    ("rng", "expected_values"),
    [
        (range(2), [1, 2]),
        (range(1, 4), [2, 3, 4]),
        (range(3, 0, -2), [4, 2]),
    ],
)
def test_series_getitem_range(rng: range, expected_values: list[int]) -> None:
    """Indexing with a ``range`` selects the rows at the positions it yields."""
    s = pl.Series([1, 2, 3, 4])
    result = s[rng]
    expected = pl.Series(expected_values)
    assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "mask",
    [
        [True, False, True],
        pl.Series([True, False, True]),
        np.array([True, False, True]),
    ],
)
def test_series_getitem_boolean_mask(mask: Any) -> None:
    """Boolean masks (list, Series or NumPy array) are rejected with TypeError."""
    s = pl.Series([1, 2, 3])
    with pytest.raises(
        TypeError,
        match="selecting rows by passing a boolean mask to `__getitem__` is not supported",
    ):
        s[mask]


@pytest.mark.parametrize(
    "key", [[], (), pl.Series(dtype=pl.Int64), np.array([], dtype=np.uint32)]
)
def test_series_getitem_empty_inputs(key: Any) -> None:
    """Empty index containers yield an empty Series keeping name and dtype."""
    s = pl.Series("a", ["x", "y", "z"], dtype=pl.String)
    result = s[key]
    expected = pl.Series("a", dtype=pl.String)
    assert_series_equal(result, expected)


@pytest.mark.parametrize("indices", [[0, 2], pl.Series([0, 2]), np.array([0, 2])])
def test_series_getitem_multiple_indices(indices: Any) -> None:
    """A list, Series or NumPy array of integer positions gathers those rows."""
    s = pl.Series(["x", "y", "z"])
    result = s[indices]
    expected = pl.Series(["x", "z"])
    assert_series_equal(result, expected)


def test_series_getitem_numpy() -> None:
    """NumPy integer index arrays, including negative and 0-d ones, are accepted."""
    s = pl.Series([9, 8, 7])

    assert s[np.array([0, 2])].to_list() == [9, 7]
    # Negative positions count from the end, as with Python sequences.
    assert s[np.array([-1, -3])].to_list() == [7, 9]
    # A 0-d array still produces a (one-element) Series, not a scalar.
    assert s[np.array(-2)].to_list() == [8]


@pytest.mark.parametrize(
    ("key", "match"),
    [
        (
            [0.0, 1.0],
            "cannot select elements using Sequence with elements of type 'float'",
        ),
        (
            "foobar",
            "cannot select elements using Sequence with elements of type 'str'",
        ),
        (
            pl.Series([[1, 2], [3, 4]]),
            r"cannot treat Series of type List\(Int64\) as indices",
        ),
        (np.array([0.0, 1.0]), "cannot treat NumPy array of type float64 as indices"),
        (object(), "cannot select elements using key of type 'object'"),
    ],
)
def test_series_getitem_col_invalid_inputs(key: Any, match: str) -> None:
    """Unsupported key types raise a TypeError with a descriptive message."""
    s = pl.Series([1, 2, 3])
    with pytest.raises(TypeError, match=match):
        s[key]