Path: blob/main/py-polars/tests/unit/operations/test_gather.py
6939 views
"""Unit tests for `gather`/`get` operations (py-polars).

NOTE(review): reconstructed from an extraction-mangled source in which all
formatting was collapsed and the original line numbers were fused into the
code text. Every statement, literal and expected value below is taken
verbatim from the decoded text; only whitespace/layout was restored.
"""

import numpy as np
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal, assert_series_equal


def test_negative_index() -> None:
    # Negative indices count from the end, both at top level and per group.
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 6]})
    assert df.select(pl.col("a").gather([0, -1])).to_dict(as_series=False) == {
        "a": [1, 6]
    }
    assert_frame_equal(
        df.group_by(pl.col("a") % 2).agg(b=pl.col("a").gather([0, -1])),
        pl.DataFrame({"a": [0, 1], "b": [[2, 6], [1, 5]]}),
        check_row_order=False,
    )


def test_gather_agg_schema() -> None:
    # `get` inside an aggregation must produce a scalar dtype, not a list.
    df = pl.DataFrame(
        {
            "group": [
                "one",
                "one",
                "one",
                "two",
                "two",
                "two",
            ],
            "value": [1, 98, 2, 3, 99, 4],
        }
    )
    assert (
        df.lazy()
        .group_by("group", maintain_order=True)
        .agg(pl.col("value").get(1))
        .collect_schema()["value"]
        == pl.Int64
    )


def test_gather_lit_single_16535() -> None:
    # Regression test: gather with a single literal index inside group_by.
    df = pl.DataFrame({"x": [1, 2, 2, 1], "y": [1, 2, 3, 4]})

    assert df.group_by(["x"], maintain_order=True).agg(pl.all().gather([1])).to_dict(
        as_series=False
    ) == {"x": [1, 2], "y": [[4], [3]]}


def test_list_get_null_offset_17248() -> None:
    # Regression test: list.get on a column containing nulls from `when/then`.
    df = pl.DataFrame({"material": [["PB", "PVC", "CI"], ["CI"], ["CI"]]})

    assert df.select(
        result=pl.when(pl.col.material.list.len() == 1).then("material").list.get(0),
    )["result"].to_list() == [None, "CI", "CI"]


def test_list_get_null_oob_17252() -> None:
    # Regression test: list.get(0) must propagate a null row, not error.
    df = pl.DataFrame(
        {
            "name": ["BOB-3", "BOB", None],
        }
    )

    split = df.with_columns(pl.col("name").str.split("-"))
    assert split.with_columns(pl.col("name").list.get(0))["name"].to_list() == [
        "BOB",
        "BOB",
        None,
    ]


def test_list_get_null_on_oob_false_success() -> None:
    # In-bounds gets succeed with null_on_oob=False; null rows stay null.
    # test Series (single offset) with nulls
    expected = pl.Series("a", [2, None, 2], dtype=pl.Int64)
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    out = s_nulls.list.get(1, null_on_oob=False)
    assert_series_equal(out, expected)

    # test Expr (multiple offsets) with nulls
    # NOTE(review): mangled source reads null_on_oob=True here despite the
    # test name; kept verbatim — confirm against upstream.
    df = s_nulls.to_frame().with_columns(pl.lit(1).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)

    # test Series (single offset) with no nulls
    expected = pl.Series("a", [2, 2, 2], dtype=pl.Int64)
    s_no_nulls = pl.Series("a", [[1, 2], [1, 2], [1, 2, 3]])
    out = s_no_nulls.list.get(1, null_on_oob=False)
    assert_series_equal(out, expected)

    # test Expr (multiple offsets) with no nulls
    df = s_no_nulls.to_frame().with_columns(pl.lit(1).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)


def test_list_get_null_on_oob_false_failure() -> None:
    # Out-of-bounds gets must raise ComputeError with null_on_oob=False.
    # test Series (single offset) with nulls
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        s_nulls.list.get(2, null_on_oob=False)

    # test Expr (multiple offsets) with nulls
    df = s_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        df.select(pl.col("a").list.get("idx", null_on_oob=False))

    # test Series (single offset) with no nulls
    s_no_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]])
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        s_no_nulls.list.get(2, null_on_oob=False)

    # test Expr (multiple offsets) with no nulls
    df = s_no_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        df.select(pl.col("a").list.get("idx", null_on_oob=False))


def test_list_get_null_on_oob_true() -> None:
    # Out-of-bounds gets yield null (instead of raising) with null_on_oob=True.
    # test Series (single offset) with nulls
    s_nulls = pl.Series("a", [[1, 2], None, [1, 2, 3]])
    out = s_nulls.list.get(2, null_on_oob=True)
    expected = pl.Series("a", [None, None, 3], dtype=pl.Int64)
    assert_series_equal(out, expected)

    # test Expr (multiple offsets) with nulls
    df = s_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)

    # test Series (single offset) with no nulls
    s_no_nulls = pl.Series("a", [[1, 2], [1], [1, 2, 3]])
    out = s_no_nulls.list.get(2, null_on_oob=True)
    expected = pl.Series("a", [None, None, 3], dtype=pl.Int64)
    assert_series_equal(out, expected)

    # test Expr (multiple offsets) with no nulls
    df = s_no_nulls.to_frame().with_columns(pl.lit(2).alias("idx"))
    out = df.select(pl.col("a").list.get("idx", null_on_oob=True)).to_series()
    assert_series_equal(out, expected)


def test_chunked_gather_phys_repr_17446() -> None:
    # Regression test: joining against a chunked frame with temporal dtypes
    # (which use a physical integer representation) must not mis-gather.
    dfa = pl.DataFrame({"replace_unique_id": range(2)})

    for dt in [pl.Date, pl.Time, pl.Duration]:
        dfb = dfa.clone()
        dfb = dfb.with_columns(ds_start_date_right=pl.lit(None).cast(dt))
        dfb = pl.concat([dfb, dfb])

        # NOTE(review): indentation lost in the mangled source; assert placed
        # inside the loop so every dtype is exercised — confirm upstream.
        assert dfa.join(dfb, how="left", on=pl.col("replace_unique_id")).shape == (4, 2)


def test_gather_str_col_18099() -> None:
    # Regression test: gather accepts a column name as the index argument.
    df = pl.DataFrame({"foo": [1, 2, 3], "idx": [0, 0, 1]})
    assert df.with_columns(pl.col("foo").gather("idx")).to_dict(as_series=False) == {
        "foo": [1, 1, 2],
        "idx": [0, 0, 1],
    }


def test_gather_list_19243() -> None:
    # Regression test: list.gather with a literal list-of-indices column.
    df = pl.DataFrame({"a": [[0.1, 0.2, 0.3]]})
    assert df.with_columns(pl.lit([0]).alias("c")).with_columns(
        gather=pl.col("a").list.gather(pl.col("c"), null_on_oob=True)
    ).to_dict(as_series=False) == {
        "a": [[0.1, 0.2, 0.3]],
        "c": [[0]],
        "gather": [[0.1]],
    }


def test_gather_array_list_null_19302() -> None:
    # Regression test: list.get(0) on a null List(Array(...)) row yields null.
    data = pl.DataFrame(
        {"data": [None]}, schema_overrides={"data": pl.List(pl.Array(pl.Float32, 1))}
    )
    assert data.select(pl.col("data").list.get(0)).to_dict(as_series=False) == {
        "data": [None]
    }


def test_gather_array() -> None:
    # Gathering an Array-dtype Series matches NumPy fancy indexing,
    # and None indices produce null rows.
    a = np.arange(16).reshape(-1, 2, 2)
    s = pl.Series(a)

    for idx in [[1, 2], [0, 0], [1, 0], [1, 1, 1, 1, 1, 1, 1, 1]]:
        assert (s.gather(idx).to_numpy() == a[idx]).all()

    v = s[[0, 1, None, 3]]  # type: ignore[list-item]
    assert v[2] is None


def test_gather_array_outer_validity_19482() -> None:
    # Regression test: outer validity (null mask) of an Array Series must
    # survive a gather.
    s = (
        pl.Series([[1], [1]], dtype=pl.Array(pl.Int64, 1))
        .to_frame()
        .select(pl.when(pl.int_range(pl.len()) == 0).then(pl.first()))
        .to_series()
    )

    expect = pl.Series([[1], None], dtype=pl.Array(pl.Int64, 1))
    assert_series_equal(s, expect)
    assert_series_equal(s.gather([0, 1]), expect)


def test_gather_len_19561() -> None:
    # Regression test: `.len()` after a gather inside an aggregation.
    N = 4
    df = pl.DataFrame({"foo": ["baz"] * N, "bar": range(N)})
    idxs = pl.int_range(1, N).repeat_by(pl.int_range(1, N)).flatten()
    gather = pl.col.bar.gather(idxs).alias("gather")

    assert df.group_by("foo").agg(gather.len()).to_dict(as_series=False) == {
        "foo": ["baz"],
        "gather": [6],
    }


def test_gather_agg_group_update_scalar() -> None:
    # If `gather` doesn't update groups properly, `first` will try to access
    # index 2 (the original index of the first element of group `1`), but gather
    # outputs only two elements (one for each group), leading to an out of
    # bounds access.
    df = (
        pl.DataFrame({"gid": [0, 0, 1, 1], "x": ["0:0", "0:1", "1:0", "1:1"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_gid=pl.col("x").gather(pl.col("gid").last()).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_gid": ["0:0", "1:1"]})
    assert_frame_equal(df, expected)


def test_gather_agg_group_update_literal() -> None:
    # If `gather` doesn't update groups properly, `first` will try to access
    # index 2 (the original index of the first element of group `1`), but gather
    # outputs only two elements (one for each group), leading to an out of
    # bounds access.
    df = (
        pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_0=pl.col("x").gather(0).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]})
    assert_frame_equal(df, expected)


def test_gather_agg_group_update_negative() -> None:
    # If `gather` doesn't update groups properly, `first` will try to access
    # index 2 (the original index of the first element of group `1`), but gather
    # outputs only two elements (one for each group), leading to an out of
    # bounds access.
    df = (
        pl.DataFrame({"gid": [0, 0, 1], "x": ["0:0", "0:1", "1:0"]})
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_last=pl.col("x").gather(-1).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_last": ["0:1", "1:0"]})
    assert_frame_equal(df, expected)


def test_gather_agg_group_update_multiple() -> None:
    # If `gather` doesn't update groups properly, `first` will try to access
    # index 4 (the original index of the first element of group `1`), but gather
    # outputs only four elements (two for each group), leading to an out of
    # bounds access.
    df = (
        pl.DataFrame(
            {
                "gid": [0, 0, 0, 0, 1, 1],
                "x": ["0:0", "0:1", "0:2", "0:3", "1:0", "1:1"],
            }
        )
        .lazy()
        .group_by("gid", maintain_order=True)
        .agg(x_at_0=pl.col("x").gather([0, 1]).first())
        .collect(optimizations=pl.QueryOptFlags.none())
    )
    expected = pl.DataFrame({"gid": [0, 1], "x_at_0": ["0:0", "1:0"]})
    assert_frame_equal(df, expected)


def test_get_agg_group_update_literal_21610() -> None:
    # Regression test: broadcasting `get(<literal>)` against the group column.
    df = (
        pl.DataFrame(
            {
                "group": [100, 100, 100, 200, 200, 200],
                "value": [1, 2, 3, 2, 3, 4],
            }
        )
        .group_by("group", maintain_order=True)
        .agg(pl.col("value") - pl.col("value").get(0))
    )

    expected = pl.DataFrame({"group": [100, 200], "value": [[0, 1, 2], [0, 1, 2]]})
    assert_frame_equal(df, expected)


def test_get_agg_group_update_scalar_21610() -> None:
    # Regression test: broadcasting `get(<scalar expr>)` against the group column.
    df = (
        pl.DataFrame(
            {
                "group": [100, 100, 100, 200, 200, 200],
                "value": [1, 2, 3, 2, 3, 4],
            }
        )
        .group_by("group", maintain_order=True)
        .agg(pl.col("value") - pl.col("value").get(pl.col("value").first()))
    )

    expected = pl.DataFrame({"group": [100, 200], "value": [[-1, 0, 1], [-2, -1, 0]]})
    assert_frame_equal(df, expected)


def test_get_dt_truncate_21533() -> None:
    # Regression test: dt.truncate applied to the scalar result of `get`
    # inside an aggregation.
    df = pl.DataFrame(
        {
            "timestamp": pl.datetime_range(
                pl.datetime(2016, 1, 1),
                pl.datetime(2017, 12, 31),
                interval="1d",
                eager=True,
            ),
        }
    ).with_columns(
        month=pl.col.timestamp.dt.month(),
    )

    report = df.group_by("month", maintain_order=True).agg(
        trunc_ts=pl.col.timestamp.get(0).dt.truncate("1m")
    )
    assert report.shape == (12, 2)


@pytest.mark.parametrize("maintain_order", [False, True])
def test_gather_group_by_23696(maintain_order: bool) -> None:
    # Regression test: gather/get variants (signed, nested, literal) inside
    # group_by all agree.
    df = (
        pl.DataFrame(
            {
                "a": [1, 2, 3, 4],
                "b": [0, 0, 1, 1],
                "c": [0, 0, -1, -1],
            }
        )
        .group_by(pl.col.a % 2, maintain_order=maintain_order)
        .agg(
            get_first=pl.col.a.get(pl.col.b.get(0)),
            get_last=pl.col.a.get(pl.col.b.get(1)),
            normal=pl.col.a.gather(pl.col.b),
            signed=pl.col.a.gather(pl.col.c),
            drop_nulls=pl.col.a.gather(pl.col.b.drop_nulls()),
            drop_nulls_signed=pl.col.a.gather(pl.col.c.drop_nulls()),
            literal=pl.col.a.gather([0, 1]),
            literal_signed=pl.col.a.gather([0, -1]),
        )
    )

    expected = pl.DataFrame(
        {
            "a": [1, 0],
            "get_first": [1, 2],
            "get_last": [3, 4],
            "normal": [[1, 3], [2, 4]],
            "signed": [[1, 3], [2, 4]],
            "drop_nulls": [[1, 3], [2, 4]],
            "drop_nulls_signed": [[1, 3], [2, 4]],
            "literal": [[1, 3], [2, 4]],
            "literal_signed": [[1, 3], [2, 4]],
        }
    )

    assert_frame_equal(df, expected, check_row_order=maintain_order)


def test_gather_invalid_indices_groupby_24182() -> None:
    # Regression test: a non-integer (string) index expression must raise.
    df = pl.DataFrame({"x": [1, 2]})
    with pytest.raises(pl.exceptions.InvalidOperationError):
        df.group_by(True).agg(pl.col("x").gather(pl.lit("y")))


@pytest.mark.parametrize("maintain_order", [False, True])
def test_gather_group_by_lit(maintain_order: bool) -> None:
    # Gathering from a literal list inside group_by broadcasts per group.
    assert_frame_equal(
        pl.DataFrame(
            {
                "a": [1, 2, 3],
            }
        )
        .group_by("a", maintain_order=maintain_order)
        .agg(pl.lit([1]).gather([0, 0, 0])),
        pl.DataFrame({"a": [1, 2, 3], "literal": [[[1], [1], [1]]] * 3}),
        check_row_order=maintain_order,
    )