Path: blob/main/py-polars/tests/unit/operations/test_gather.py
import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError, OutOfBoundsError
from polars.testing import assert_frame_equal, assert_series_equal


def test_negative_index() -> None:
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 6]})
    assert df.select(pl.col("a").gather([0, -1])).to_dict(as_series=False) == {
        "a": [1, 6]
    }
    assert_frame_equal(
        df.group_by(pl.col("a") % 2).agg(b=pl.col("a").gather([0, -1])),
        pl.DataFrame({"a": [0, 1], "b": [[2, 6], [1, 5]]}),
        check_row_order=False,
    )


def test_gather_agg_schema() -> None:
    df = pl.DataFrame(
        {
            "group": [
                "one",
                "one",
                "one",
                "two",
                "two",
                "two",
            ],
            "value": [1, 98, 2, 3, 99, 4],
        }
    )
    assert (
        df.lazy()
        .group_by("group", maintain_order=True)
        .agg(pl.col("value").get(1))
        .collect_schema()["value"]
        == pl.Int64
    )


def test_gather_lit_single_16535() -> None:
    df = pl.DataFrame({"x": [1, 2, 2, 1], "y": [1, 2, 3, 4]})

    assert df.group_by(["x"], maintain_order=True).agg(pl.all().gather([1])).to_dict(
        as_series=False
    ) == {"x": [1, 2], "y": [[4], [3]]}


def test_list_get_null_offset_17248() -> None:
    df = pl.DataFrame({"material": [["PB", "PVC", "CI"], ["CI"], ["CI"]]})

    assert df.select(
        result=pl.when(pl.col.material.list.len() == 1).then("material").list.get(0),
    )["result"].to_list() == [None, "CI", "CI"]


def test_list_get_null_oob_17252() -> None:
    df = pl.DataFrame(
        {
            "name": ["BOB-3", "BOB", None],
        }
    )

    split = df.with_columns(pl.col("name").str.split("-"))
    assert split.with_columns(pl.col("name").list.get(0))["name"].to_list() == [
        "BOB",
        "BOB",
        None,
    ]


def test_list_get_null_on_oob_false_success() -> None:
    # test Series (single offset) with nulls
    expected = pl.Series("a", [2, None, 2], dtype=pl.Int64)
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    out = s_nulls.list.get(1, null_on_oob=False)
    assert_series_equal(out, expected)

    # test Expr (multiple offsets) with nulls
    df = s_nulls.to_frame().with_columns(pl.lit(1).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)

    # test Series (single offset) with no nulls
    expected = pl.Series("a", [2, 2, 2], dtype=pl.Int64)
    s_no_nulls = pl.Series("a", [[1, 2], [1, 2], [1, 2, 3]])
    out = s_no_nulls.list.get(1, null_on_oob=False)
    assert_series_equal(out, expected)

    # test Expr (multiple offsets) with no nulls
    df = s_no_nulls.to_frame().with_columns(pl.lit(1).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)


def test_list_get_null_on_oob_false_failure() -> None:
    # test Series (single offset) with nulls
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        s_nulls.list.get(2, null_on_oob=False)

    # test Expr (multiple offsets) with nulls
    df = s_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        df.select(pl.col("a").list.get("idx", null_on_oob=False))

    # test Series (single offset) with no nulls
    s_no_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]])
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        s_no_nulls.list.get(2, null_on_oob=False)

    # test Expr (multiple offsets) with no nulls
    df = s_no_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        df.select(pl.col("a").list.get("idx", null_on_oob=False))
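

# Illustrative sketch (added here for clarity, not part of the upstream suite):
# a minimal pairing of the two `null_on_oob` modes on a single short sublist,
# assuming the same error message that is matched in the tests above.
def test_list_get_null_on_oob_sketch() -> None:
    s = pl.Series("a", [[1, 2], [1]])
    # Offset 1 is out of bounds for the second sublist: with null_on_oob=True it
    # becomes a null instead of an error.
    assert s.list.get(1, null_on_oob=True).to_list() == [2, None]
    # The strict mode raises for the same offset.
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        s.list.get(1, null_on_oob=False)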
bounds"):117df.select(pl.col("a").list.get("idx", null_on_oob=False))118119120def test_list_get_null_on_oob_true() -> None:121# test Series (single offset) with nulls122s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])123out = s_nulls.list.get(2, null_on_oob=True)124expected = pl.Series("a", [None, None, 3], dtype=pl.Int64)125assert_series_equal(out, expected)126127# test Expr (multiple offsets) with nulls128df = s_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))129out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()130assert_series_equal(out, expected)131132# test Series (single offset) with no nulls133s_no_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]])134out = s_no_nulls.list.get(2, null_on_oob=True)135expected = pl.Series("a", [None, None, 3], dtype=pl.Int64)136assert_series_equal(out, expected)137138# test Expr (multiple offsets) with no nulls139df = s_no_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))140out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()141assert_series_equal(out, expected)142143144def test_chunked_gather_phys_repr_17446() -> None:145dfa = pl.DataFrame({"replace_unique_id": range(2)})146147for dt in [pl.Date, pl.Time, pl.Duration]:148dfb = dfa.clone()149dfb = dfb.with_columns(ds_start_date_right=pl.lit(None).cast(dt))150dfb = pl.concat([dfb, dfb])151152assert dfa.join(dfb, how="left", on=pl.col("replace_unique_id")).shape == (4, 2)153154155def test_gather_str_col_18099() -> None:156df = pl.DataFrame({"foo": [1, 2, 3], "idx": [0, 0, 1]})157assert df.with_columns(pl.col("foo").gather("idx")).to_dict(as_series=False) == {158"foo": [1, 1, 2],159"idx": [0, 0, 1],160}161162163def test_gather_list_19243() -> None:164df = pl.DataFrame({"a": [[0.1, 0.2, 0.3]]})165assert df.with_columns(pl.lit([0]).alias("c")).with_columns(166gather=pl.col("a").list.gather(pl.col("c"), null_on_oob=True)167).to_dict(as_series=False) == {168"a": [[0.1, 0.2, 0.3]],169"c": [[0]],170"gather": [[0.1]],171}172173174def test_gather_array_list_null_19302() -> None:175data = pl.DataFrame(176{"data": [None]}, schema_overrides={"data": pl.List(pl.Array(pl.Float32, 1))}177)178assert data.select(pl.col("data").list.get(0)).to_dict(as_series=False) == {179"data": [None]180}181182183def test_gather_array() -> None:184a = np.arange(16).reshape(-1, 2, 2)185s = pl.Series(a)186187for idx in [[1, 2], [0, 0], [1, 0], [1, 1, 1, 1, 1, 1, 1, 1]]:188assert (s.gather(idx).to_numpy() == a[idx]).all()189190v = s[[0, 1, None, 3]] # type: ignore[list-item]191assert v[2] is None192193194def test_gather_array_outer_validity_19482() -> None:195s = (196pl.Series([[1], [1]], dtype=pl.Array(pl.Int64, 1))197.to_frame()198.select(pl.when(pl.int_range(pl.len()) == 0).then(pl.first()))199.to_series()200)201202expect = pl.Series([[1], None], dtype=pl.Array(pl.Int64, 1))203assert_series_equal(s, expect)204assert_series_equal(s.gather([0, 1]), expect)205206207def test_gather_len_19561() -> None:208N = 4209df = pl.DataFrame({"foo": ["baz"] * N, "bar": range(N)})210211idxs = (212pl.int_range(1, N)213.repeat_by(pl.int_range(1, N))214.list.explode(keep_nulls=False, empty_as_null=False)215)216gather = pl.col("bar").gather(idxs).alias("gather")217218assert df.group_by("foo").agg(gather.len()).to_dict(as_series=False) == {219"foo": ["baz"],220"gather": [6],221}222223224def test_gather_agg_group_update_scalar() -> None:225# If `gather` doesn't update groups properly, `first` will try to access226# index 2 (the original index of the first element of group `1`), but gather227# outputs 


def test_gather_agg_group_update_literal() -> None:
    # If `gather` doesn't update groups properly, `first` will try to access
    # index 2 (the original index of the first element of group `1`), but gather
    # outputs only two elements (one for each group), leading to an out of
    # bounds access.
    df = (
        pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_0=pl.col("x").gather(0).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]})
    assert_frame_equal(df, expected)


def test_gather_agg_group_update_negative() -> None:
    # If `gather` doesn't update groups properly, `first` will try to access
    # index 2 (the original index of the first element of group `1`), but gather
    # outputs only two elements (one for each group), leading to an out of
    # bounds access.
    df = (
        pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_last=pl.col("x").gather(-1).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_last": ["0:1", "1:0"]})
    assert_frame_equal(df, expected)


def test_gather_agg_group_update_multiple() -> None:
    # If `gather` doesn't update groups properly, `first` will try to access
    # index 4 (the original index of the first element of group `1`), but gather
    # outputs only four elements (two for each group), leading to an out of
    # bounds access.
    df = (
        pl.DataFrame(
            {
                "gid": [0, 0, 0, 0, 1, 1],
                "x": ["0:0", "0:1", "0:2", "0:3", "1:0", "1:1"],
            }
        )
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_0=pl.col("x").gather([0, 1]).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]})
    assert_frame_equal(df, expected)


def test_get_agg_group_update_literal_21610() -> None:
    df = (
        pl.DataFrame(
            {
                "group": [100, 100, 100, 200, 200, 200],
                "value": [1, 2, 3, 2, 3, 4],
            }
        )
        .group_by("group", maintain_order=True)
        .agg(pl.col("value") - pl.col("value").get(0))
    )

    expected = pl.DataFrame({"group": [100, 200], "value": [[0, 1, 2], [0, 1, 2]]})
    assert_frame_equal(df, expected)


def test_get_agg_group_update_scalar_21610() -> None:
    df = (
        pl.DataFrame(
            {
                "group": [100, 100, 100, 200, 200, 200],
                "value": [1, 2, 3, 2, 3, 4],
            }
        )
        .group_by("group", maintain_order=True)
        .agg(pl.col("value") - pl.col("value").get(pl.col("value").first()))
    )

    expected = pl.DataFrame({"group": [100, 200], "value": [[-1, 0, 1], [-2, -1, 0]]})
    assert_frame_equal(df, expected)


def test_get_dt_truncate_21533() -> None:
    df = pl.DataFrame(
        {
            "timestamp": pl.datetime_range(
                pl.datetime(2016, 1, 1),
                pl.datetime(2017, 12, 31),
                interval="1d",
                eager=True,
            ),
        }
    ).with_columns(
        month=pl.col.timestamp.dt.month(),
    )

    report = df.group_by("month", maintain_order=True).agg(
        trunc_ts=pl.col.timestamp.get(0).dt.truncate("1m")
    )
    assert report.shape == (12, 2)
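

# Illustrative sketch (added, not from the upstream suite): a compact version of
# the 21610 cases above, assuming `get(0)` inside `agg` broadcasts as a per-group
# scalar that can be combined with the full column expression.
def test_get_agg_broadcast_sketch() -> None:
    df = pl.DataFrame({"g": [1, 1, 2], "v": [3, 5, 7]})
    out = df.group_by("g", maintain_order=True).agg(
        rel=pl.col("v") - pl.col("v").get(0)
    )
    assert out.to_dict(as_series=False) == {"g": [1, 2], "rel": [[0, 2], [0]]}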


@pytest.mark.parametrize("maintain_order", [False, True])
def test_gather_group_by_23696(maintain_order: bool) -> None:
    df = (
        pl.DataFrame(
            {
                "a": [1, 2, 3, 4],
                "b": [0, 0, 1, 1],
                "c": [0, 0, -1, -1],
            }
        )
        .group_by(pl.col.a % 2, maintain_order=maintain_order)
        .agg(
            get_first=pl.col.a.get(pl.col.b.get(0)),
            get_last=pl.col.a.get(pl.col.b.get(1)),
            normal=pl.col.a.gather(pl.col.b),
            signed=pl.col.a.gather(pl.col.c),
            drop_nulls=pl.col.a.gather(pl.col.b.drop_nulls()),
            drop_nulls_signed=pl.col.a.gather(pl.col.c.drop_nulls()),
            literal=pl.col.a.gather([0, 1]),
            literal_signed=pl.col.a.gather([0, -1]),
        )
    )

    expected = pl.DataFrame(
        {
            "a": [1, 0],
            "get_first": [1, 2],
            "get_last": [3, 4],
            "normal": [[1, 3], [2, 4]],
            "signed": [[1, 3], [2, 4]],
            "drop_nulls": [[1, 3], [2, 4]],
            "drop_nulls_signed": [[1, 3], [2, 4]],
            "literal": [[1, 3], [2, 4]],
            "literal_signed": [[1, 3], [2, 4]],
        }
    )

    assert_frame_equal(df, expected, check_row_order=maintain_order)


def test_gather_invalid_indices_groupby_24182() -> None:
    df = pl.DataFrame({"x": [1, 2]})
    with pytest.raises(pl.exceptions.InvalidOperationError):
        df.group_by(True).agg(pl.col("x").gather(pl.lit("y")))


@pytest.mark.parametrize("maintain_order", [False, True])
def test_gather_group_by_lit(maintain_order: bool) -> None:
    assert_frame_equal(
        pl.DataFrame(
            {
                "a": [1, 2, 3],
            }
        )
        .group_by("a", maintain_order=maintain_order)
        .agg(pl.lit([1]).gather([0, 0, 0])),
        pl.DataFrame({"a": [1, 2, 3], "literal": [[[1], [1], [1]]] * 3}),
        check_row_order=maintain_order,
    )


def test_get_window_with_filtered_empty_groups_23029() -> None:
    # https://github.com/pola-rs/polars/issues/23029
    df = pl.DataFrame(
        {
            "group": [1, 1, 2, 2, 3, 3],
            "value": [10, 20, 30, 40, 50, 60],
            "filter_condition": [False, True, False, False, True, True],
        }
    )

    result = df.with_columns(
        get_first=(
            pl.col("value")
            .filter(pl.col("filter_condition"))
            .get(0, null_on_oob=True)
            .over("group")
        ),
        first_value=(
            pl.col("value").filter(pl.col("filter_condition")).first().over("group")
        ),
    )

    assert_series_equal(
        result["get_first"],
        result["first_value"],
        check_names=False,
    )

    # And the concrete expected values are:
    expected = pl.DataFrame(
        {
            "group": [1, 1, 2, 2, 3, 3],
            "value": [10, 20, 30, 40, 50, 60],
            "filter_condition": [False, True, False, False, True, True],
            "get_first": [20, 20, None, None, 50, 50],
            "first_value": [20, 20, None, None, 50, 50],
        }
    )

    assert_frame_equal(result, expected)


@pytest.mark.parametrize("idx_dtype", [pl.Int64, pl.UInt64, pl.Int128, pl.UInt128])
def test_get_typed_index_null_on_oob_true(idx_dtype: pl.DataType) -> None:
    # OOB typed index with null_on_oob=True -> null, for multiple integer dtypes.
    df = pl.DataFrame({"value": [1, 2, 10]})

    out = df.select(v=pl.col("value").get(pl.lit(5, dtype=idx_dtype), null_on_oob=True))

    assert out["v"].to_list() == [None]


@pytest.mark.parametrize("idx_dtype", [pl.Int64, pl.UInt64, pl.Int128, pl.UInt128])
def test_get_typed_index_null_on_oob_false_raises(idx_dtype: pl.DataType) -> None:
    # OOB typed index with null_on_oob=False -> OutOfBoundsError, for multiple dtypes.
    df = pl.DataFrame({"value": [10, 11]})

    with pytest.raises(OutOfBoundsError, match="gather indices are out of bounds"):
        df.select(pl.col("value").get(pl.lit(5, dtype=idx_dtype), null_on_oob=False))
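

# Illustrative sketch (added, not from the upstream suite): for contrast with the
# parametrized OOB cases, an in-bounds typed index is assumed to return the element
# regardless of the null_on_oob setting.
def test_get_typed_index_in_bounds_sketch() -> None:
    df = pl.DataFrame({"value": [10, 11]})
    out = df.select(
        v=pl.col("value").get(pl.lit(1, dtype=pl.UInt64), null_on_oob=True)
    )
    assert out["v"].to_list() == [11]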
bounds"):464df.select(pl.col("value").get(pl.lit(5, dtype=idx_dtype), null_on_oob=False))465466467@pytest.mark.parametrize("idx_dtype", [pl.Int64, pl.UInt64, pl.Int128, pl.UInt128])468def test_get_typed_index_default_raises_out_of_bounds(idx_dtype: pl.DataType) -> None:469# Default behavior (null_on_oob omitted) should behave like null_on_oob=False470df = pl.DataFrame({"value": [10, 11]})471472with pytest.raises(OutOfBoundsError, match="gather indices are out of bounds"):473df.select(pl.col("value").get(pl.lit(5, dtype=idx_dtype)))474475476