Path: blob/main/py-polars/tests/unit/operations/test_index_of.py
6939 views
from __future__ import annotations12from datetime import date, datetime, time, timedelta3from decimal import Decimal4from typing import TYPE_CHECKING, Any56import numpy as np7import pytest8from hypothesis import example, given9from hypothesis import strategies as st1011import polars as pl12from polars.exceptions import InvalidOperationError13from polars.testing import assert_frame_equal14from polars.testing.parametric import series1516if TYPE_CHECKING:17from polars._typing import IntoExpr181920def isnan(value: object) -> bool:21if isinstance(value, int):22return False23if not isinstance(value, (np.number, float)):24return False25return np.isnan(value) # type: ignore[no-any-return]262728def assert_index_of(29series: pl.Series,30value: IntoExpr,31convert_to_literal: bool = False,32) -> None:33"""``Series.index_of()`` returns the index, or ``None`` if it can't be found."""34if isnan(value):35expected_index = None36for i, o in enumerate(series.to_list()):37if o is not None and np.isnan(o):38expected_index = i39break40else:41try:42expected_index = series.to_list().index(value)43except ValueError:44expected_index = None45if expected_index == -1:46expected_index = None4748if convert_to_literal:49value = pl.lit(value, dtype=series.dtype)5051# Eager API:52assert series.index_of(value) == expected_index53# Lazy API:54assert pl.LazyFrame({"series": series}).select(55pl.col("series").index_of(value)56).collect().get_column("series").to_list() == [expected_index]575859@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])60def test_float(dtype: pl.DataType) -> None:61values = [1.5, np.nan, np.inf, 3.0, None, -np.inf, 0.0, -0.0, -np.nan]62series = pl.Series(values, dtype=dtype)63sorted_series_asc = series.sort(descending=False)64sorted_series_desc = series.sort(descending=True)65chunked_series = pl.concat([pl.Series([1, 7], dtype=dtype), series], rechunk=False)6667extra_values = [68np.int8(3),69np.int64(2**42),70np.float64(1.5),71np.float32(1.5),72np.float32(2**37),73np.float64(2**100),74]75for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]:76for value in values:77assert_index_of(s, value, convert_to_literal=True)78assert_index_of(s, value, convert_to_literal=False)79for value in extra_values: # type: ignore[assignment]80assert_index_of(s, value)8182# Explicitly check some extra-tricky edge cases:83assert series.index_of(-np.nan) == 1 # -np.nan should match np.nan84assert series.index_of(-0.0) == 6 # -0.0 should match 0.0858687def test_null() -> None:88series = pl.Series([None, None], dtype=pl.Null)89assert_index_of(series, None)909192def test_empty() -> None:93series = pl.Series([], dtype=pl.Null)94assert_index_of(series, None)95series = pl.Series([], dtype=pl.Int64)96assert_index_of(series, None)97assert_index_of(series, 12)98assert_index_of(series.sort(descending=True), 12)99assert_index_of(series.sort(descending=False), 12)100101102@pytest.mark.parametrize(103"dtype",104[105pl.Int8,106pl.Int16,107pl.Int32,108pl.Int64,109pl.UInt8,110pl.UInt16,111pl.UInt32,112pl.UInt64,113pl.Int128,114],115)116def test_integer(dtype: pl.DataType) -> None:117values = [11851,1193,120None,1214,122pl.select(dtype.max()).item(), # type: ignore[attr-defined]123pl.select(dtype.min()).item(), # type: ignore[attr-defined]124]125series = pl.Series(values, dtype=dtype)126sorted_series_asc = series.sort(descending=False)127sorted_series_desc = series.sort(descending=True)128chunked_series = pl.concat(129[pl.Series([100, 7], dtype=dtype), series], rechunk=False130)131132extra_values = [pl.select(v).item() for v in [dtype.max() - 1, dtype.min() + 1]] # type: ignore[attr-defined]133for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]:134value: IntoExpr135for value in values:136assert_index_of(s, value, convert_to_literal=True)137assert_index_of(s, value, convert_to_literal=False)138for value in extra_values:139assert_index_of(s, value, convert_to_literal=True)140assert_index_of(s, value, convert_to_literal=False)141142# Can't cast floats:143for f in [np.float32(3.1), np.float64(3.1), 50.9]:144with pytest.raises(InvalidOperationError, match="cannot cast lossless"):145s.index_of(f) # type: ignore[arg-type]146147148def test_groupby() -> None:149df = pl.DataFrame(150{"label": ["a", "b", "a", "b", "a", "b"], "value": [10, 3, 20, 2, 40, 20]}151)152expected = pl.DataFrame(153{"label": ["a", "b"], "value": [1, 2]},154schema={"label": pl.String, "value": pl.UInt32},155)156assert_frame_equal(157df.group_by("label", maintain_order=True).agg(pl.col("value").index_of(20)),158expected,159)160assert_frame_equal(161df.lazy()162.group_by("label", maintain_order=True)163.agg(pl.col("value").index_of(20))164.collect(),165expected,166)167168169LISTS_STRATEGY = st.lists(170st.one_of(st.none(), st.integers(min_value=10, max_value=50)), max_size=10171)172173174@given(175list1=LISTS_STRATEGY,176list2=LISTS_STRATEGY,177list3=LISTS_STRATEGY,178)179# The examples are cases where this test previously caught bugs:180@example([], [], [None])181@pytest.mark.slow182def test_randomized(183list1: list[int | None], list2: list[int | None], list3: list[int | None]184) -> None:185series = pl.concat(186[pl.Series(values, dtype=pl.Int8) for values in [list1, list2, list3]],187rechunk=False,188)189sorted_series = series.sort(descending=False)190sorted_series2 = series.sort(descending=True)191192# Values are between 10 and 50, plus add None and max/min range values:193for i in set(range(10, 51)) | {-128, 127, None}:194assert_index_of(series, i)195assert_index_of(sorted_series, i)196assert_index_of(sorted_series2, i)197198199ENUM = pl.Enum(["a", "b", "c"])200201202@pytest.mark.parametrize(203("series", "extra_values", "sortable"),204[205(pl.Series(["abc", None, "bb"]), ["", "🚲"], True),206(pl.Series([True, None, False, True, False]), [], True),207(208pl.Series([datetime(1997, 12, 31), datetime(1996, 1, 1)]),209[datetime(2023, 12, 12, 16, 12, 39)],210True,211),212(213pl.Series([date(1997, 12, 31), None, date(1996, 1, 1)]),214[date(2023, 12, 12)],215True,216),217(218pl.Series([time(16, 12, 31), None, time(11, 10, 53)]),219[time(11, 12, 16)],220True,221),222(223pl.Series(224[timedelta(hours=12), None, timedelta(minutes=3)],225),226[timedelta(minutes=17)],227True,228),229(pl.Series([[1, 2], None, [4, 5], [6], [None, 3, 5]]), [[5, 7], []], True),230(231pl.Series([[[1, 2]], None, [[4, 5]], [[6]], [[None, 3, 5]], [None]]),232[[[5, 7]], []],233True,234),235(236pl.Series([[1, 2], None, [4, 5], [None, 3]], dtype=pl.Array(pl.Int64(), 2)),237[[5, 7], [None, None]],238True,239),240(241pl.Series(242[[[1, 2]], [None], [[4, 5]], None, [[None, 3]]],243dtype=pl.Array(pl.Array(pl.Int64(), 2), 1),244),245[[[5, 7]], [[None, None]]],246True,247),248(249pl.Series(250[{"a": 1, "b": 2}, None, {"a": 3, "b": 4}, {"a": None, "b": 2}],251dtype=pl.Struct({"a": pl.Int64(), "b": pl.Int64()}),252),253[{"a": 7, "b": None}, {"a": 6, "b": 4}],254False,255),256(pl.Series([b"abc", None, b"xxx"]), [b"\x0025"], True),257(pl.Series([Decimal(12), None, Decimal(3)]), [Decimal(4)], True),258],259)260def test_other_types(261series: pl.Series, extra_values: list[Any], sortable: bool262) -> None:263expected_values = series.to_list()264series_variants = [series, series.drop_nulls()]265if sortable:266series_variants.extend(267[268series.sort(descending=False),269series.sort(descending=True),270]271)272for s in series_variants:273for value in expected_values:274assert_index_of(s, value, convert_to_literal=True)275assert_index_of(s, value, convert_to_literal=False)276# Extra values may not be expressible as literal of correct dtype, so277# don't try:278for value in extra_values:279assert_index_of(s, value)280281282# Before the output type would be list[idx-type] when no item was found283def test_non_found_correct_type() -> None:284df = pl.DataFrame(285[286pl.Series("a", [0, 1], pl.Int32),287pl.Series("b", [1, 2], pl.Int32),288]289)290291assert_frame_equal(292df.group_by("a", maintain_order=True).agg(pl.col.b.index_of(1)),293pl.DataFrame({"a": [0, 1], "b": [0, None]}),294check_dtypes=False,295)296297298def test_error_on_multiple_values() -> None:299with pytest.raises(300pl.exceptions.InvalidOperationError,301match="needle of `index_of` can only contain",302):303pl.Series("a", [1, 2, 3]).index_of(pl.Series([2, 3]))304305306@pytest.mark.parametrize(307"convert_to_literal",308[309True,310False,311],312)313def test_enum(convert_to_literal: bool) -> None:314series = pl.Series(["a", "c", None, "b"], dtype=pl.Enum(["c", "b", "a"]))315expected_values = series.to_list()316for s in [317series,318series.drop_nulls(),319series.sort(descending=False),320series.sort(descending=True),321]:322for value in expected_values:323assert_index_of(s, value, convert_to_literal=convert_to_literal)324325326@pytest.mark.parametrize(327"convert_to_literal",328[True, False],329)330def test_categorical(convert_to_literal: bool) -> None:331series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)332expected_values = series.to_list()333for s in [334series,335series.drop_nulls(),336series.sort(descending=False),337series.sort(descending=True),338]:339for value in expected_values:340assert_index_of(s, value, convert_to_literal=convert_to_literal)341342343@pytest.mark.parametrize("value", [0, 0.1])344def test_categorical_wrong_type_keys_dont_work(value: int | float) -> None:345series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)346msg = "cannot cast lossless"347with pytest.raises(InvalidOperationError, match=msg):348series.index_of(value)349df = pl.DataFrame({"s": series})350with pytest.raises(InvalidOperationError, match=msg):351df.select(pl.col("s").index_of(value))352353354@given(s=series(name="s", allow_chunks=True, max_size=10))355def test_index_of_null_parametric(s: pl.Series) -> None:356idx_null = s.index_of(None)357if s.len() == 0:358assert idx_null is None359elif s.null_count() == 0:360assert idx_null is None361elif s.null_count() == len(s):362assert idx_null == 0363364365