# Path: blob/main/py-polars/tests/unit/operations/test_index_of.py
# 8433 views
from __future__ import annotations

from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest
from hypothesis import example, given
from hypothesis import strategies as st

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal
from polars.testing.parametric import series

if TYPE_CHECKING:
    from polars._typing import IntoExpr, PolarsDataType
    from polars.datatypes import IntegerType


def isnan(value: object) -> bool:
    """Return whether ``value`` is a NaN Python float or NumPy number.

    Integers (including ``bool``) and anything that is not a float/NumPy
    number (e.g. ``None``, strings, expressions) are never considered NaN.
    """
    if isinstance(value, int):
        return False
    if not isinstance(value, (np.number, float)):
        return False
    return np.isnan(value)  # type: ignore[no-any-return]


def assert_index_of(
    series: pl.Series,
    value: IntoExpr,
    convert_to_literal: bool = False,
) -> None:
    """``Series.index_of()`` returns the index, or ``None`` if it can't be found.

    The expected index is computed in pure Python from ``series.to_list()``
    and compared against both the eager (``Series.index_of``) and the lazy
    (``Expr.index_of``) APIs.

    Parameters
    ----------
    series
        The series to search.
    value
        The needle to look for.
    convert_to_literal
        If True, wrap ``value`` in ``pl.lit(value, dtype=series.dtype)``
        before searching, so the needle already has the series' dtype.
    """
    items = series.to_list()
    if isnan(value):
        # NaN != NaN in Python, so ``list.index`` can't find it; look for the
        # first NaN explicitly instead.
        expected_index = next(
            (i for i, o in enumerate(items) if o is not None and np.isnan(o)),
            None,
        )
    else:
        try:
            expected_index = items.index(value)
        except ValueError:
            expected_index = None

    if convert_to_literal:
        value = pl.lit(value, dtype=series.dtype)

    # Eager API:
    assert series.index_of(value) == expected_index
    # Lazy API:
    assert pl.LazyFrame({"series": series}).select(
        pl.col("series").index_of(value)
    ).collect().get_column("series").to_list() == [expected_index]


@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_float(dtype: pl.DataType) -> None:
    values = [1.5, np.nan, np.inf, 3.0, None, -np.inf, 0.0, -0.0, -np.nan]
    if dtype == pl.Float32:
        # Can't pass Python literals to index_of() for Float32
        values = [(None if v is None else np.float32(v)) for v in values]  # type: ignore[misc]

    series = pl.Series(values, dtype=dtype)
    sorted_series_asc = series.sort(descending=False)
    sorted_series_desc = series.sort(descending=True)
    chunked_series = pl.concat([pl.Series([1, 7], dtype=dtype), series], rechunk=False)

    extra_values = [
        np.int8(3),
        np.float32(1.5),
        np.float32(2**10),
    ]
    if dtype == pl.Float64:
        extra_values.extend([np.int32(2**10), np.float64(2**10), np.float64(1.5)])
    for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]:
        for value in values:
            assert_index_of(s, value, convert_to_literal=True)
            assert_index_of(s, value, convert_to_literal=False)
        for value in extra_values:  # type: ignore[assignment]
            assert_index_of(s, value)

    # -np.nan should match np.nan:
    assert series.index_of(-np.float32("nan")) == 1  # type: ignore[arg-type]
    # -0.0 should match 0.0:
    assert series.index_of(-np.float32(0.0)) == 6  # type: ignore[arg-type]


def test_null() -> None:
    series = pl.Series([None, None], dtype=pl.Null)
    assert_index_of(series, None)


def test_empty() -> None:
    series = pl.Series([], dtype=pl.Null)
    assert_index_of(series, None)
    series = pl.Series([], dtype=pl.Int64)
    assert_index_of(series, None)
    assert_index_of(series, 12)
    assert_index_of(series.sort(descending=True), 12)
    assert_index_of(series.sort(descending=False), 12)


@pytest.mark.parametrize(
    "dtype",
    [
        pl.Int8,
        pl.Int16,
        pl.Int32,
        pl.Int64,
        pl.Int128,
        pl.UInt8,
        pl.UInt16,
        pl.UInt32,
        pl.UInt64,
        pl.UInt128,
    ],
)
def test_integer(dtype: IntegerType) -> None:
    dtype_min = dtype.min()
    # NOTE(review): UInt128 deliberately uses Int128's max; presumably the full
    # UInt128 max isn't usable in this test -- confirm.
    dtype_max = pl.Int128.max() if dtype == pl.UInt128 else dtype.max()

    values = [
        51,
        3,
        None,
        4,
        pl.select(dtype_max).item(),
        pl.select(dtype_min).item(),
    ]
    series = pl.Series(values, dtype=dtype)
    sorted_series_asc = series.sort(descending=False)
    sorted_series_desc = series.sort(descending=True)
    chunked_series = pl.concat(
        [pl.Series([100, 7], dtype=dtype), series], rechunk=False
    )

    extra_values = [pl.select(v).item() for v in [dtype_max - 1, dtype_min + 1]]
    for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]:
        value: IntoExpr
        for value in values:
            assert_index_of(s, value, convert_to_literal=True)
            assert_index_of(s, value, convert_to_literal=False)
        for value in extra_values:
            assert_index_of(s, value, convert_to_literal=True)
            assert_index_of(s, value, convert_to_literal=False)

        # Can't cast floats (checked against every series variant):
        for f in [np.float32(3.1), np.float64(3.1), 50.9]:
            with pytest.raises(InvalidOperationError, match=r"cannot cast.*"):
                s.index_of(f)  # type: ignore[arg-type]


def test_integer_upcast() -> None:
    series = pl.Series([0, 123, 456, 789], dtype=pl.Int64)
    for should_work in [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16, pl.Int32, pl.UInt32]:
        assert series.index_of(pl.lit(123, dtype=should_work)) == 1


def test_groupby() -> None:
    df = pl.DataFrame(
        {"label": ["a", "b", "a", "b", "a", "b"], "value": [10, 3, 20, 2, 40, 20]}
    )
    expected = pl.DataFrame(
        {"label": ["a", "b"], "value": [1, 2]},
        schema={"label": pl.String, "value": pl.get_index_type()},
    )
    assert_frame_equal(
        df.group_by("label", maintain_order=True).agg(pl.col("value").index_of(20)),
        expected,
    )
    assert_frame_equal(
        df.lazy()
        .group_by("label", maintain_order=True)
        .agg(pl.col("value").index_of(20))
        .collect(),
        expected,
    )


LISTS_STRATEGY = st.lists(
    st.one_of(st.none(), st.integers(min_value=10, max_value=50)), max_size=10
)


@given(
    list1=LISTS_STRATEGY,
    list2=LISTS_STRATEGY,
    list3=LISTS_STRATEGY,
)
# The examples are cases where this test previously caught bugs:
@example([], [], [None])
@pytest.mark.slow
def test_randomized(
    list1: list[int | None], list2: list[int | None], list3: list[int | None]
) -> None:
    series = pl.concat(
        [pl.Series(values, dtype=pl.Int8) for values in [list1, list2, list3]],
        rechunk=False,
    )
    sorted_series = series.sort(descending=False)
    sorted_series2 = series.sort(descending=True)

    # Values are between 10 and 50, plus add None and max/min range values:
    for i in set(range(10, 51)) | {-128, 127, None}:
        assert_index_of(series, i)
        assert_index_of(sorted_series, i)
        assert_index_of(sorted_series2, i)


# NOTE(review): not referenced anywhere in this file; kept in case it is
# imported elsewhere.
ENUM = pl.Enum(["a", "b", "c"])


@pytest.mark.parametrize(
    ("series", "extra_values", "sortable"),
    [
        (pl.Series(["abc", None, "bb"]), ["", "🚲"], True),
        (pl.Series([True, None, False, True, False]), [], True),
        (
            pl.Series([datetime(1997, 12, 31), datetime(1996, 1, 1)]),
            [datetime(2023, 12, 12, 16, 12, 39)],
            True,
        ),
        (
            pl.Series([date(1997, 12, 31), None, date(1996, 1, 1)]),
            [date(2023, 12, 12)],
            True,
        ),
        (
            pl.Series([time(16, 12, 31), None, time(11, 10, 53)]),
            [time(11, 12, 16)],
            True,
        ),
        (
            pl.Series(
                [timedelta(hours=12), None, timedelta(minutes=3)],
            ),
            [timedelta(minutes=17)],
            True,
        ),
        (pl.Series([[1, 2], None, [4, 5], [6], [None, 3, 5]]), [[5, 7], []], True),
        (
            pl.Series([[[1, 2]], None, [[4, 5]], [[6]], [[None, 3, 5]], [None]]),
            [[[5, 7]], []],
            True,
        ),
        (
            pl.Series([[1, 2], None, [4, 5], [None, 3]], dtype=pl.Array(pl.Int64(), 2)),
            [[5, 7], [None, None]],
            True,
        ),
        (
            pl.Series(
                [[[1, 2]], [None], [[4, 5]], None, [[None, 3]]],
                dtype=pl.Array(pl.Array(pl.Int64(), 2), 1),
            ),
            [[[5, 7]], [[None, None]]],
            True,
        ),
        (
            pl.Series(
                [{"a": 1, "b": 2}, None, {"a": 3, "b": 4}, {"a": None, "b": 2}],
                dtype=pl.Struct({"a": pl.Int64(), "b": pl.Int64()}),
            ),
            [{"a": 7, "b": None}, {"a": 6, "b": 4}],
            False,
        ),
        (pl.Series([b"abc", None, b"xxx"]), [b"\x0025"], True),
        (
            pl.Series(
                [Decimal(12), None, Decimal(3), Decimal(-12), Decimal(1) / Decimal(10)],
                dtype=pl.Decimal(20, 4),
            ),
            [Decimal(4), Decimal(-2), Decimal(1) / Decimal(4), Decimal(1) / Decimal(8)],
            True,
        ),
    ],
)
def test_other_types(
    series: pl.Series, extra_values: list[Any], sortable: bool
) -> None:
    expected_values = series.to_list()
    series_variants = [series, series.drop_nulls()]
    if sortable:
        series_variants.extend(
            [
                series.sort(descending=False),
                series.sort(descending=True),
            ]
        )
    for s in series_variants:
        for value in expected_values:
            assert_index_of(s, value, convert_to_literal=True)
            assert_index_of(s, value, convert_to_literal=False)
        # Extra values may not be expressible as a literal of the correct
        # dtype, so don't try:
        for value in extra_values:
            assert_index_of(s, value)


# Regression test: previously the output type was list[idx-type] when no item
# was found.
def test_non_found_correct_type() -> None:
    df = pl.DataFrame(
        [
            pl.Series("a", [0, 1], pl.Int32),
            pl.Series("b", [1, 2], pl.Int32),
        ]
    )

    assert_frame_equal(
        df.group_by("a", maintain_order=True).agg(pl.col.b.index_of(1)),
        pl.DataFrame({"a": [0, 1], "b": [0, None]}),
        check_dtypes=False,
    )


def test_error_on_multiple_values() -> None:
    with pytest.raises(
        pl.exceptions.InvalidOperationError,
        match="needle of `index_of` can only contain",
    ):
        pl.Series("a", [1, 2, 3]).index_of(pl.Series([2, 3]))


@pytest.mark.parametrize(
    "convert_to_literal",
    [True, False],
)
def test_enum(convert_to_literal: bool) -> None:
    series = pl.Series(["a", "c", None, "b"], dtype=pl.Enum(["c", "b", "a"]))
    expected_values = series.to_list()
    for s in [
        series,
        series.drop_nulls(),
        series.sort(descending=False),
        series.sort(descending=True),
    ]:
        for value in expected_values:
            assert_index_of(s, value, convert_to_literal=convert_to_literal)


@pytest.mark.parametrize(
    "convert_to_literal",
    [True, False],
)
def test_categorical(convert_to_literal: bool) -> None:
    series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)
    expected_values = series.to_list()
    for s in [
        series,
        series.drop_nulls(),
        series.sort(descending=False),
        series.sort(descending=True),
    ]:
        for value in expected_values:
            assert_index_of(s, value, convert_to_literal=convert_to_literal)


@pytest.mark.parametrize("value", [0, 0.1])
def test_categorical_wrong_type_keys_dont_work(value: int | float) -> None:
    series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)
    msg = "cannot cast.*losslessly.*"
    with pytest.raises(InvalidOperationError, match=msg):
        series.index_of(value)
    df = pl.DataFrame({"s": series})
    with pytest.raises(InvalidOperationError, match=msg):
        df.select(pl.col("s").index_of(value))


@given(s=series(name="s", allow_chunks=True, max_size=10))
def test_index_of_null_parametric(s: pl.Series) -> None:
    idx_null = s.index_of(None)
    if s.null_count() == 0:
        # Also covers the empty series: there is no null to find.
        assert idx_null is None
    elif s.null_count() == s.len():
        assert idx_null == 0
    # NOTE(review): series mixing nulls and non-nulls are not checked here.


def test_out_of_range_integers() -> None:
    series = pl.Series([0, 100, None, 1, 2], dtype=pl.Int8)
    with pytest.raises(InvalidOperationError, match="cannot cast 128 losslessly to i8"):
        series.index_of(128)
    with pytest.raises(
        InvalidOperationError, match="cannot cast -200 losslessly to i8"
    ):
        series.index_of(-200)


def test_out_of_range_decimal() -> None:
    # Decimal(36, 2) leaves room for up to 34 integer digits:
    series = pl.Series([1, None], dtype=pl.Decimal(36, 2))
    assert series.index_of(10**34 - 1) is None
    assert series.index_of(-(10**34 - 1)) is None
    out_of_range = 10**34
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {out_of_range} losslessly"
    ):
        series.index_of(out_of_range)
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {-out_of_range} losslessly"
    ):
        series.index_of(-out_of_range)


def test_out_of_range_float64() -> None:
    series = pl.Series([0, 255, None], dtype=pl.Float64)
    # Small numbers are fine:
    assert series.index_of(1_000_000) is None
    assert series.index_of(-1_000_000) is None
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {2**53} losslessly to f64"
    ):
        series.index_of(2**53)
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {-(2**53)} losslessly to f64"
    ):
        series.index_of(-(2**53))


def test_out_of_range_float32() -> None:
    series = pl.Series([0, 255, None], dtype=pl.Float32)
    # Small numbers are fine:
    assert series.index_of(1_000_000) is None
    assert series.index_of(-1_000_000) is None
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {2**24} losslessly to f32"
    ):
        series.index_of(2**24)
    with pytest.raises(
        InvalidOperationError, match=f"cannot cast {-(2**24)} losslessly to f32"
    ):
        series.index_of(-(2**24))


def assert_lossy_cast_rejected(
    series_dtype: PolarsDataType, value: Any, value_dtype: PolarsDataType
) -> None:
    """Assert that searching for a needle that needs a lossy cast raises.

    ``value`` is wrapped in a literal of ``value_dtype`` and searched for in a
    single-null series of ``series_dtype``; ``InvalidOperationError`` with
    "cannot cast losslessly" is expected.
    """
    # We create a Series with a null because previously lossy casts would
    # sometimes get turned into nulls and you'd get an answer.
    series = pl.Series([None], dtype=series_dtype)
    with pytest.raises(InvalidOperationError, match="cannot cast losslessly"):
        series.index_of(pl.lit(value, dtype=value_dtype))


@pytest.mark.parametrize(
    ("series_dtype", "value", "value_dtype"),
    [
        # Completely incompatible:
        (pl.String, 1, pl.UInt8),
        (pl.UInt8, "1", pl.String),
        # Larger integer doesn't fit in smaller integer:
        (pl.UInt8, 17, pl.UInt16),
        # Can't find negative numbers in unsigned integers:
        (pl.UInt16, -1, pl.Int8),
        # Values after the decimal point that can't be represented:
        (pl.Decimal(3, 1), 1, pl.Decimal(4, 2)),
        # Can't fit in Decimal:
        (pl.Decimal(3, 0), 1, pl.Decimal(4, 0)),
        (pl.Decimal(5, 2), 1, pl.Decimal(5, 1)),
        (pl.Decimal(5, 2), 1, pl.UInt16),
        # Can't fit nanoseconds in milliseconds:
        (pl.Duration("ms"), 1, pl.Duration("ns")),
        # Arrays that are the wrong length:
        (pl.Array(pl.Int64, 2), [1], pl.Array(pl.Int64, 1)),
        # Struct with wrong number of fields:
        (
            pl.Struct({"a": pl.Int64, "b": pl.Int64}),
            {"a": 1},
            pl.Struct({"a": pl.Int64}),
        ),
        # Struct with different field name:
        (pl.Struct({"a": pl.Int64}), {"x": 1}, pl.Struct({"x": pl.Int64})),
    ],
)
def test_lossy_casts_are_rejected(
    series_dtype: PolarsDataType, value: Any, value_dtype: PolarsDataType
) -> None:
    assert_lossy_cast_rejected(series_dtype, value, value_dtype)


def test_lossy_casts_are_rejected_nested_dtypes() -> None:
    # Make sure casting rules are applied recursively for Lists, Arrays,
    # Struct:
    series_dtype, value, value_dtype = pl.UInt8, 17, pl.UInt16
    assert_lossy_cast_rejected(pl.List(series_dtype), [value], pl.List(value_dtype))
    assert_lossy_cast_rejected(
        pl.Array(series_dtype, 1), [value], pl.Array(value_dtype, 1)
    )
    assert_lossy_cast_rejected(
        pl.Struct({"key": series_dtype}),
        {"key": value},
        pl.Struct({"key": value_dtype}),
    )


def test_decimal_search_for_int() -> None:
    values = [Decimal(-12), Decimal(12), Decimal(30)]
    series = pl.Series(values, dtype=pl.Decimal(4, 1))
    for i, value in enumerate(values):
        assert series.index_of(value) == i
        assert series.index_of(int(value)) == i
        assert series.index_of(np.int8(value)) == i  # type: ignore[arg-type]
    # Decimal's integer range is 3 digits (3 == 4 - 1), so int8 fits:
    assert series.index_of(np.int8(127)) is None  # type: ignore[arg-type]
    assert series.index_of(np.int8(-128)) is None  # type: ignore[arg-type]