# Path: blob/main/py-polars/tests/unit/dataframe/test_getitem.py
# 6939 views
from __future__ import annotations

from typing import Any

import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis import given

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric import column, dataframes
from tests.unit.conftest import INTEGER_DTYPES, SIGNED_INTEGER_DTYPES


@given(
    df=dataframes(
        max_size=10,
        cols=[
            column(
                "start",
                dtype=pl.Int8,
                allow_null=True,
                strategy=st.integers(min_value=-8, max_value=8),
            ),
            column(
                "stop",
                dtype=pl.Int8,
                allow_null=True,
                strategy=st.integers(min_value=-6, max_value=6),
            ),
            column(
                "step",
                dtype=pl.Int8,
                allow_null=True,
                # a slice step of zero is invalid, so it is filtered out
                strategy=st.integers(min_value=-4, max_value=4).filter(
                    lambda x: x != 0
                ),
            ),
            column("misc", dtype=pl.Int32),
        ],
    )
    # generated dataframe example -
    # ┌───────┬──────┬──────┬───────┐
    # │ start ┆ stop ┆ step ┆ misc  │
    # │ ---   ┆ ---  ┆ ---  ┆ ---   │
    # │ i8    ┆ i8   ┆ i8   ┆ i32   │
    # ╞═══════╪══════╪══════╪═══════╡
    # │ 2     ┆ -1   ┆ null ┆ -55   │
    # │ -3    ┆ 0    ┆ -2   ┆ 61582 │
    # │ null  ┆ 1    ┆ 2    ┆ 5865  │
    # └───────┴──────┴──────┴───────┘
)
def test_df_getitem_row_slice(df: pl.DataFrame) -> None:
    """Row-slicing a frame must match slicing its rows as a Python list."""
    # take strategy-generated integer values from the frame as slice bounds.
    # use these bounds to slice the same frame, and then validate the result
    # against a py-native slice of the same data using the same bounds.
    #
    # given the average number of rows in the frames, and the value of
    # max_examples, this will result in close to 5000 test permutations,
    # running in around ~1.5 secs (depending on hardware/etc).
    py_data = df.rows()

    for start, stop, step, _ in py_data:
        s = slice(start, stop, step)
        sliced_py_data = py_data[s]
        sliced_df_data = df[s].rows()

        assert sliced_py_data == sliced_df_data, (
            f"slice [{start}:{stop}:{step}] failed on df w/len={df.height}"
        )


def test_df_getitem_col_single_name() -> None:
    """A single column name in the column position returns a Series."""
    df = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
    result = df[:, "a"]
    expected = df.select("a").to_series()
    assert_series_equal(result, expected)


@pytest.mark.parametrize(
    ("input", "expected_cols"),
    [
        (["a"], ["a"]),
        (["a", "d"], ["a", "d"]),
        (slice("b", "d"), ["b", "c", "d"]),
        (pl.Series(["a", "b"]), ["a", "b"]),
        (np.array(["c", "d"]), ["c", "d"]),
    ],
)
def test_df_getitem_col_multiple_names(input: Any, expected_cols: list[str]) -> None:
    """Name-based column keys (list, slice, Series, array) return a DataFrame."""
    df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
    result = df[:, input]
    expected = df.select(expected_cols)
    assert_frame_equal(result, expected)


def test_df_getitem_col_single_index() -> None:
    """A single integer in the column position returns that column as a Series."""
    df = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
    result = df[:, 1]
    expected = df.select("b").to_series()
    assert_series_equal(result, expected)


def test_df_getitem_col_two_entries() -> None:
    """A two-element key of column names or booleans yields the full frame here."""
    df = pl.DataFrame({"x": [1.0], "y": [1.0]})

    assert_frame_equal(df["x", "y"], df)
    assert_frame_equal(df[True, True], df)


@pytest.mark.parametrize(
    ("input", "expected_cols"),
    [
        ([0], ["a"]),
        ([0, 3], ["a", "d"]),
        (slice(1, 4), ["b", "c", "d"]),
        (pl.Series([0, 1]), ["a", "b"]),
        (np.array([2, 3]), ["c", "d"]),
    ],
)
def test_df_getitem_col_multiple_indices(input: Any, expected_cols: list[str]) -> None:
    """Index-based column keys (list, slice, Series, array) return a DataFrame."""
    df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
    result = df[:, input]
    expected = df.select(expected_cols)
    assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    "mask",
    [
        [True, False, True],
        pl.Series([True, False, True]),
        np.array([True, False, True]),
    ],
)
def test_df_getitem_col_boolean_mask(mask: Any) -> None:
    """A boolean mask in the column position keeps the columns flagged True."""
    df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
    result = df[:, mask]
    expected = df.select("a", "c")
    assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    ("rng", "expected_cols"),
    [
        (range(2), ["a", "b"]),
        (range(1, 4), ["b", "c", "d"]),
        (range(3, 0, -2), ["d", "b"]),
    ],
)
def test_df_getitem_col_range(rng: range, expected_cols: list[str]) -> None:
    """A ``range`` in the column position selects columns by position, in order."""
    df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})
    result = df[:, rng]
    expected = df.select(expected_cols)
    assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    "input", [[], (), pl.Series(dtype=pl.Int64), np.array([], dtype=np.uint32)]
)
def test_df_getitem_col_empty_inputs(input: Any) -> None:
    """An empty column key yields an empty DataFrame."""
    df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
    result = df[:, input]
    expected = pl.DataFrame()
    assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    ("input", "match"),
    [
        (
            [0.0, 1.0],
            "cannot select columns using Sequence with elements of type 'float'",
        ),
        (
            pl.Series([[1, 2], [3, 4]]),
            "cannot select columns using Series of type List\\(Int64\\)",
        ),
        (
            np.array([0.0, 1.0]),
            "cannot select columns using NumPy array of type float64",
        ),
        (object(), "cannot select columns using key of type 'object'"),
    ],
)
def test_df_getitem_col_invalid_inputs(input: Any, match: str) -> None:
    """Unsupported column keys raise TypeError with a descriptive message."""
    df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
    with pytest.raises(TypeError, match=match):
        df[:, input]


@pytest.mark.parametrize(
    ("input", "match"),
    [
        (["a", 2], "'int' object cannot be converted to 'PyString'"),
        ([1, "c"], "'str' object cannot be interpreted as an integer"),
    ],
)
def test_df_getitem_col_mixed_inputs(input: list[Any], match: str) -> None:
    """Mixing names and integer positions in one column key raises TypeError."""
    df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
    with pytest.raises(TypeError, match=match):
        df[:, input]


@pytest.mark.parametrize(
    ("input", "match"),
    [
        ([0.0, 1.0], "unexpected value while building Series of type Int64"),
        (
            pl.Series([[1, 2], [3, 4]]),
            "cannot treat Series of type List\\(Int64\\) as indices",
        ),
        (np.array([0.0, 1.0]), "cannot treat NumPy array of type float64 as indices"),
        (object(), "cannot select rows using key of type 'object'"),
    ],
)
def test_df_getitem_row_invalid_inputs(input: Any, match: str) -> None:
    """Unsupported row keys raise TypeError with a descriptive message."""
    df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
    with pytest.raises(TypeError, match=match):
        df[input, :]


def test_df_getitem_row_range() -> None:
    """A ``range`` in the row position selects rows by position, in order."""
    df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5.0, 6.0, 7.0, 8.0]})
    result = df[range(3, 0, -2), :]
    expected = pl.DataFrame({"a": [4, 2], "b": [8.0, 6.0]})
    assert_frame_equal(result, expected)


def test_df_getitem_row_range_single_input() -> None:
    """A ``range`` used as the only key selects rows."""
    df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5.0, 6.0, 7.0, 8.0]})
    result = df[range(1, 3)]
    expected = pl.DataFrame({"a": [2, 3], "b": [6.0, 7.0]})
    assert_frame_equal(result, expected)


def test_df_getitem_row_empty_list_single_input() -> None:
    """An empty list used as the only key drops all rows but keeps the schema."""
    df = pl.DataFrame({"a": [1, 2], "b": [5.0, 6.0]})
    result = df[[]]
    expected = df.clear()
    assert_frame_equal(result, expected)


def test_df_getitem() -> None:
    """Test all the methods to use [] on a dataframe."""
    df = pl.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [3, 4, 5, 6]})

    # multiple slices.
    # The first element refers to the rows, the second element to columns
    assert_frame_equal(df[:, :], df)

    # str, always refers to a column name
    assert_series_equal(df["a"], pl.Series("a", [1.0, 2.0, 3.0, 4.0]))

    # int, always refers to a row index (zero-based): index=1 => second row
    assert_frame_equal(df[1], pl.DataFrame({"a": [2.0], "b": [4]}))

    # int, int.
    # The first element refers to the rows, the second element to columns
    assert df[2, 1] == 5
    assert df[2, -2] == 3.0

    with pytest.raises(IndexError):
        # Column index out of bounds
        df[2, 2]

    with pytest.raises(IndexError):
        # Column index out of bounds
        df[2, -3]

    # int, list[int].
    # The first element refers to the rows, the second element to columns
    assert_frame_equal(df[2, [1, 0]], pl.DataFrame({"b": [5], "a": [3.0]}))
    assert_frame_equal(df[2, [-1, -2]], pl.DataFrame({"b": [5], "a": [3.0]}))

    with pytest.raises(IndexError):
        # Column index out of bounds
        df[2, [2, 0]]

    with pytest.raises(IndexError):
        # Column index out of bounds
        df[2, [2, -3]]

    # slice. Below an example of taking every second row
    assert_frame_equal(df[1::2], pl.DataFrame({"a": [2.0, 4.0], "b": [4, 6]}))

    # slice, empty slice
    assert df[:0].columns == ["a", "b"]
    assert len(df[:0]) == 0

    # make mypy happy
    empty: list[int] = []

    # empty list with column selector drops rows but keeps columns
    assert_frame_equal(df[empty, :], df[:0])

    # sequences (lists or tuples; tuple only if length != 2)
    # if strings or list of expressions, assumed to be column names
    # if integers, assumed to be row indices
    # NOTE(review): an earlier version of this comment said bools are a row
    # mask; the assertions further down show boolean row masks raising
    # TypeError and a boolean NumPy key selecting columns.
    assert_frame_equal(df[["a", "b"]], df)
    assert_frame_equal(df.select([pl.col("a"), pl.col("b")]), df)
    assert_frame_equal(
        df[[1, -4, -1, 2, 1]],
        pl.DataFrame({"a": [2.0, 1.0, 4.0, 3.0, 2.0], "b": [4, 3, 6, 5, 4]}),
    )

    # pl.Series: strings for column selections.
    assert_frame_equal(df[pl.Series("", ["a", "b"])], df)

    # pl.Series: positive idxs or empty idxs for row selection.
    for pl_dtype in INTEGER_DTYPES:
        assert_frame_equal(
            df[pl.Series("", [1, 0, 3, 2, 3, 0], dtype=pl_dtype)],
            pl.DataFrame(
                {"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]}
            ),
        )
        assert df[pl.Series("", [], dtype=pl_dtype)].columns == ["a", "b"]

    # pl.Series: positive and negative idxs for row selection.
    for pl_dtype in SIGNED_INTEGER_DTYPES:
        assert_frame_equal(
            df[pl.Series("", [-1, 0, -3, -2, 3, -4], dtype=pl_dtype)],
            pl.DataFrame(
                {"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]}
            ),
        )

    # Boolean masks for rows not supported
    with pytest.raises(TypeError):
        df[[True, False, True], [False, True]]
    with pytest.raises(TypeError):
        df[pl.Series([True, False, True]), "b"]

    # a boolean NumPy array used as the only key selects columns
    assert_frame_equal(df[np.array([True, False])], df[:, :1])

    # wrong length boolean mask for column selection
    with pytest.raises(
        ValueError,
        match=f"expected {df.width} values when selecting columns by boolean mask",
    ):
        df[:, [True, False, True]]


def test_df_getitem_numpy() -> None:
    """NumPy arrays as keys: integer arrays pick rows, invalid arrays raise."""
    # numpy getitem: assumed to be row indices if integers, or columns if strings
    df = pl.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [3, 4, 5, 6]})

    # numpy array: positive idxs and empty idx
    for np_dtype in (
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
    ):
        assert_frame_equal(
            df[np.array([1, 0, 3, 2, 3, 0], dtype=np_dtype)],
            pl.DataFrame(
                {"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]}
            ),
        )
        assert df[np.array([], dtype=np_dtype)].columns == ["a", "b"]

    # numpy array: positive and negative idxs.
    for np_dtype in (np.int8, np.int16, np.int32, np.int64):
        assert_frame_equal(
            df[np.array([-1, 0, -3, -2, 3, -4], dtype=np_dtype)],
            pl.DataFrame(
                {"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]}
            ),
        )

    # zero-dimensional array indexing is equivalent to int row selection
    assert_frame_equal(df[np.array(0)], pl.DataFrame({"a": [1.0], "b": [3]}))
    assert_frame_equal(df[np.array(1)], pl.DataFrame({"a": [2.0], "b": [4]}))

    # note that we cannot use floats (even if they could be cast to int without loss)
    with pytest.raises(
        TypeError,
        match="cannot select columns using NumPy array of type float",
    ):
        df[np.array([1.0])]

    with pytest.raises(
        TypeError,
        match="multi-dimensional NumPy arrays not supported as index",
    ):
        df[np.array([[0], [1]])]


def test_df_getitem_extended() -> None:
    """Broader mix of one- and two-dimensional keys on a three-column frame."""
    df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0], "c": ["a", "b", "c"]})

    # row slice combined with a column slice / list of column names
    assert df[:2, :1].rows() == [(1,), (2,)]
    assert df[:2, ["a"]].rows() == [(1,), (2,)]

    # column selection by string(s) in first dimension
    assert df["a"].to_list() == [1, 2, 3]
    assert df["b"].to_list() == [1.0, 2.0, 3.0]
    assert df["c"].to_list() == ["a", "b", "c"]

    # row selection by integer(s) in first dimension
    assert_frame_equal(df[0], pl.DataFrame({"a": [1], "b": [1.0], "c": ["a"]}))
    assert_frame_equal(df[-1], pl.DataFrame({"a": [3], "b": [3.0], "c": ["c"]}))

    # row, column selection when using two dimensions
    assert df[:, "a"].to_list() == [1, 2, 3]
    assert df[:, 1].to_list() == [1.0, 2.0, 3.0]
    assert df[:2, 2].to_list() == ["a", "b"]

    assert_frame_equal(
        df[[1, 2]], pl.DataFrame({"a": [2, 3], "b": [2.0, 3.0], "c": ["b", "c"]})
    )
    assert_frame_equal(
        df[[-1, -2]], pl.DataFrame({"a": [3, 2], "b": [3.0, 2.0], "c": ["c", "b"]})
    )

    assert df[["a", "b"]].columns == ["a", "b"]
    assert_frame_equal(
        df[[1, 2], [1, 2]], pl.DataFrame({"b": [2.0, 3.0], "c": ["b", "c"]})
    )
    assert df[1, 2] == "b"
    assert df[1, 1] == 2.0
    assert df[2, 0] == 3

    assert df[[2], ["a", "b"]].rows() == [(3, 3.0)]
    assert df.to_series(0).name == "a"
    assert (df["a"] == df["a"]).sum() == 3
    assert (df["c"] == df["a"].cast(str)).sum() == 0
    assert df[:, "a":"b"].rows() == [(1, 1.0), (2, 2.0), (3, 3.0)]  # type: ignore[index, misc]
    assert df[:, "a":"c"].columns == ["a", "b", "c"]  # type: ignore[index, misc]
    assert df[:, []].shape == (0, 0)
    expect = pl.DataFrame({"c": ["b"]})
    assert_frame_equal(df[1, [2]], expect)
    expect = pl.DataFrame({"b": [1.0, 3.0]})
    assert_frame_equal(df[[0, 2], [1]], expect)
    assert df[0, "c"] == "a"
    assert df[1, "c"] == "b"
    assert df[2, "c"] == "c"
    assert df[0, "a"] == 1

    # more slicing
    expect = pl.DataFrame({"a": [3, 2, 1], "b": [3.0, 2.0, 1.0], "c": ["c", "b", "a"]})
    assert_frame_equal(df[::-1], expect)
    expect = pl.DataFrame({"a": [1, 2], "b": [1.0, 2.0], "c": ["a", "b"]})
    assert_frame_equal(df[:-1], expect)

    expect = pl.DataFrame({"a": [1, 3], "b": [1.0, 3.0], "c": ["a", "c"]})
    assert_frame_equal(df[::2], expect)

    # only allow boolean values in column position
    df = pl.DataFrame(
        {
            "a": [1, 2],
            "b": [2, 3],
            "c": [3, 4],
        }
    )

    assert df[:, [False, True, True]].columns == ["b", "c"]
    assert df[:, pl.Series([False, True, True])].columns == ["b", "c"]
    assert df[:, pl.Series([False, False, False])].columns == []


def test_df_getitem_5343() -> None:
    """Scalar and single-column lookups on a 5-row, 12-column frame."""
    # https://github.com/pola-rs/polars/issues/5343
    df = pl.DataFrame(
        {
            f"foo{col}": [n**col for n in range(5)]  # 5 rows
            for col in range(12)  # 12 columns
        }
    )
    assert df[4, 4] == 256
    assert df[4, 5] == 1024
    assert_frame_equal(df[4, [2]], pl.DataFrame({"foo2": [16]}))
    assert_frame_equal(df[4, [5]], pl.DataFrame({"foo5": [1024]}))


def test_no_deadlock_19358() -> None:
    """Positional row selection on an Object-dtype column returns the values."""
    # regression test named after issue 19358: the selection must complete
    s = pl.Series(["text"] * 100 + [1] * 100, dtype=pl.Object)
    result = s.to_frame()[[0, -1]]
    assert result[""].to_list() == ["text", 1]