Path: blob/main/py-polars/tests/unit/operations/namespaces/array/test_array.py
6940 views
from __future__ import annotations

import datetime
from typing import Any

import pytest

import polars as pl
from polars.exceptions import ComputeError, InvalidOperationError, ShapeError
from polars.testing import assert_frame_equal, assert_series_equal


def test_arr_min_max() -> None:
    """Check `arr.min` / `arr.max` per row, with and without nulls."""
    s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2))
    assert s.arr.max().to_list() == [2, 4]
    assert s.arr.min().to_list() == [1, 3]

    s_with_null = pl.Series("a", [[None, 2], None, [3, 4]], dtype=pl.Array(pl.Int64, 2))
    assert s_with_null.arr.max().to_list() == [2, None, 4]
    assert s_with_null.arr.min().to_list() == [2, None, 3]


def test_array_min_max_dtype_12123() -> None:
    """Check that lazy `arr.min` / `arr.max` report the inner dtype in the schema."""
    df = pl.LazyFrame(
        [pl.Series("a", [[1.0, 3.0], [2.0, 5.0]]), pl.Series("b", [1.0, 2.0])],
        schema_overrides={
            "a": pl.Array(pl.Float64, 2),
        },
    )

    df = df.with_columns(
        max=pl.col("a").arr.max().alias("max"),
        min=pl.col("a").arr.min().alias("min"),
    )

    assert df.collect_schema() == {
        "a": pl.Array(pl.Float64, 2),
        "b": pl.Float64,
        "max": pl.Float64,
        "min": pl.Float64,
    }

    # The resolved dtypes must also be usable in follow-up arithmetic.
    out = df.select(pl.col("max") * pl.col("b"), pl.col("min") * pl.col("b")).collect()

    assert_frame_equal(out, pl.DataFrame({"max": [3.0, 10.0], "min": [1.0, 4.0]}))


@pytest.mark.parametrize(
    ("data", "expected_sum", "dtype"),
    [
        ([[1, 2], [4, 3]], [3, 7], pl.Int64),
        ([[1, None], [None, 3], [None, None]], [1, 3, 0], pl.Int64),
        ([[1.0, 2.0], [4.0, 3.0]], [3.0, 7.0], pl.Float32),
        ([[1.0, None], [None, 3.0], [None, None]], [1.0, 3.0, 0], pl.Float32),
        ([[True, False], [True, True], [False, False]], [1, 2, 0], pl.Boolean),
        ([[True, None], [None, False], [None, None]], [1, 0, 0], pl.Boolean),
    ],
)
def test_arr_sum(
    data: list[list[Any]], expected_sum: list[Any], dtype: pl.DataType
) -> None:
    """Check `arr.sum` for int, float and boolean arrays, including null elements."""
    s = pl.Series("a", data, dtype=pl.Array(dtype, 2))
    assert s.arr.sum().to_list() == expected_sum


@pytest.mark.may_fail_cloud
def test_array_lengths_zwa() -> None:
    """Check `arr.len` on zero-width arrays (`pl.Array(pl.Null, 0)`)."""
    assert pl.Series("a", [[], []], pl.Array(pl.Null, 0)).arr.len().to_list() == [0, 0]
    assert pl.Series("a", [None, []], pl.Array(pl.Null, 0)).arr.len().to_list() == [
        None,
        0,
    ]
    assert pl.Series("a", [None], pl.Array(pl.Null, 0)).arr.len().to_list() == [None]

    assert pl.Series("a", [], pl.Array(pl.Null, 0)).arr.len().to_list() == []


def test_array_lengths() -> None:
    """Check `arr.len` values and its UInt32 output dtype."""
    df = pl.DataFrame(
        [
            pl.Series("a", [[1, 2, 3]], dtype=pl.Array(pl.Int64, 3)),
            pl.Series("b", [[4, 5]], dtype=pl.Array(pl.Int64, 2)),
        ]
    )
    out = df.select(pl.col("a").arr.len(), pl.col("b").arr.len())
    expected_df = pl.DataFrame(
        {"a": [3], "b": [2]}, schema={"a": pl.UInt32, "b": pl.UInt32}
    )
    assert_frame_equal(out, expected_df)

    assert pl.Series("a", [], pl.Array(pl.Null, 1)).arr.len().to_list() == []
    assert pl.Series(
        "a", [[1, 2, 3], None, [7, 8, 9]], pl.Array(pl.Int32, 3)
    ).arr.len().to_list() == [3, None, 3]


@pytest.mark.parametrize(
    "as_array",
    [True, False],
)
def test_arr_slice(as_array: bool) -> None:
    """Check `arr.slice` with positive and negative offsets for both `as_array` modes."""
    df = pl.DataFrame(
        {
            "arr": [[1, 2, 3], [10, 2, 1]],
        },
        schema={"arr": pl.Array(pl.Int64, 3)},
    )

    assert df.select([pl.col("arr").arr.slice(0, 1, as_array=as_array)]).to_dict(
        as_series=False
    ) == {"arr": [[1], [10]]}
    assert df.select([pl.col("arr").arr.slice(1, 1, as_array=as_array)]).to_dict(
        as_series=False
    ) == {"arr": [[2], [2]]}
    assert df.select([pl.col("arr").arr.slice(-1, 1, as_array=as_array)]).to_dict(
        as_series=False
    ) == {"arr": [[3], [1]]}
    assert df.select([pl.col("arr").arr.slice(-2, 1, as_array=as_array)]).to_dict(
        as_series=False
    ) == {"arr": [[2], [2]]}
    assert df.select([pl.col("arr").arr.slice(-2, 2, as_array=as_array)]).to_dict(
        as_series=False
    ) == {"arr": [[2, 3], [2, 1]]}


@pytest.mark.parametrize(
    "as_array",
    [True, False],
)
def test_arr_slice_on_series(as_array: bool) -> None:
    """Check `arr.head`, `arr.tail` and `arr.slice` on a Series."""
    vals = [[1, 2, 3, 4], [10, 2, 1, 2]]
    s = pl.Series("a", vals, dtype=pl.Array(pl.Int64, 4))
    assert s.arr.head(2, as_array=as_array).to_list() == [[1, 2], [10, 2]]
    assert s.arr.tail(2, as_array=as_array).to_list() == [[3, 4], [1, 2]]
    # Requesting more elements than the width returns the full rows.
    assert s.arr.tail(10, as_array=as_array).to_list() == vals
    assert s.arr.head(10, as_array=as_array).to_list() == vals
    assert s.arr.slice(1, 2, as_array=as_array).to_list() == [[2, 3], [2, 1]]
    assert s.arr.slice(-5, 2, as_array=as_array).to_list() == [[1], [10]]
    # TODO: there is a bug in list.slice that does not allow negative values for head
    # NOTE(review): all negative-length assertions are assumed to sit under this
    # guard; confirm against upstream.
    if as_array:
        assert s.arr.tail(-1, as_array=as_array).to_list() == [[2, 3, 4], [2, 1, 2]]
        assert s.arr.tail(-2, as_array=as_array).to_list() == [[3, 4], [1, 2]]
        assert s.arr.tail(-3, as_array=as_array).to_list() == [[4], [2]]
        assert s.arr.head(-1, as_array=as_array).to_list() == [[1, 2, 3], [10, 2, 1]]
        assert s.arr.head(-2, as_array=as_array).to_list() == [[1, 2], [10, 2]]
        assert s.arr.head(-3, as_array=as_array).to_list() == [[1], [10]]


def test_arr_unique() -> None:
    """Check `arr.unique(maintain_order=True)` against a frame built from plain lists."""
    df = pl.DataFrame(
        {"a": pl.Series("a", [[1, 1], [4, 3]], dtype=pl.Array(pl.Int64, 2))}
    )

    out = df.select(pl.col("a").arr.unique(maintain_order=True))
    expected = pl.DataFrame({"a": [[1], [4, 3]]})
    assert_frame_equal(out, expected)


def test_array_any_all() -> None:
    """Check `arr.any` / `arr.all` on boolean arrays and the error on non-boolean ones."""
    s = pl.Series(
        [[True, True], [False, True], [False, False], [None, None], None],
        dtype=pl.Array(pl.Boolean, 2),
    )

    expected_any = pl.Series([True, True, False, False, None])
    assert_series_equal(s.arr.any(), expected_any)

    expected_all = pl.Series([True, False, False, True, None])
    assert_series_equal(s.arr.all(), expected_all)

    s = pl.Series([[1, 2], [3, 4], [5, 6]], dtype=pl.Array(pl.Int64, 2))
    with pytest.raises(ComputeError, match="expected boolean elements in array"):
        s.arr.any()
    with pytest.raises(ComputeError, match="expected boolean elements in array"):
        s.arr.all()


def test_array_sort() -> None:
    """Check `arr.sort` for both directions and both `nulls_last` settings."""
    s = pl.Series([[2, None, 1], [1, 3, 2]], dtype=pl.Array(pl.UInt32, 3))

    desc = s.arr.sort(descending=True)
    expected = pl.Series([[None, 2, 1], [3, 2, 1]], dtype=pl.Array(pl.UInt32, 3))
    assert_series_equal(desc, expected)

    asc = s.arr.sort(descending=False)
    expected = pl.Series([[None, 1, 2], [1, 2, 3]], dtype=pl.Array(pl.UInt32, 3))
    assert_series_equal(asc, expected)

    # test nulls_last
    s = pl.Series([[None, 1, 2], [-1, None, 9]], dtype=pl.Array(pl.Int8, 3))
    assert_series_equal(
        s.arr.sort(nulls_last=True),
        pl.Series([[1, 2, None], [-1, 9, None]], dtype=pl.Array(pl.Int8, 3)),
    )
    assert_series_equal(
        s.arr.sort(nulls_last=False),
        pl.Series([[None, 1, 2], [None, -1, 9]], dtype=pl.Array(pl.Int8, 3)),
    )


def test_array_reverse() -> None:
    """Check that `arr.reverse` reverses each row and keeps nulls in place."""
    s = pl.Series([[2, None, 1], [1, None, 2]], dtype=pl.Array(pl.UInt32, 3))

    s = s.arr.reverse()
    expected = pl.Series([[1, None, 2], [2, None, 1]], dtype=pl.Array(pl.UInt32, 3))
    assert_series_equal(s, expected)


def test_array_arg_min_max() -> None:
    """Check `arr.arg_min` / `arr.arg_max` indices and their UInt32 dtype."""
    s = pl.Series("a", [[1, 2, 4], [3, 2, 1]], dtype=pl.Array(pl.UInt32, 3))
    expected = pl.Series("a", [0, 2], dtype=pl.UInt32)
    assert_series_equal(s.arr.arg_min(), expected)
    expected = pl.Series("a", [2, 0], dtype=pl.UInt32)
    assert_series_equal(s.arr.arg_max(), expected)


def test_array_get() -> None:
    """Check `arr.get(..., null_on_oob=False)`, including out-of-bounds errors."""
    s = pl.Series(
        "a",
        [[1, 2, 3, 4], [5, 6, None, None], [7, 8, 9, 10]],
        dtype=pl.Array(pl.Int64, 4),
    )

    # Test index literal.
    out = s.arr.get(1, null_on_oob=False)
    expected = pl.Series("a", [2, 6, 8], dtype=pl.Int64)
    assert_series_equal(out, expected)

    # Null index literal.
    out_df = s.to_frame().select(pl.col.a.arr.get(pl.lit(None), null_on_oob=False))
    expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame()
    assert_frame_equal(out_df, expected_df)

    # Out-of-bounds index literal.
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        s.arr.get(100, null_on_oob=False)

    # Negative index literal.
    out = s.arr.get(-2, null_on_oob=False)
    expected = pl.Series("a", [3, None, 9], dtype=pl.Int64)
    assert_series_equal(out, expected)

    # Test index expr.
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        s.arr.get(pl.Series([1, -2, 100]), null_on_oob=False)

    out = s.arr.get(pl.Series([1, -2, 0]), null_on_oob=False)
    expected = pl.Series("a", [2, None, 7], dtype=pl.Int64)
    assert_series_equal(out, expected)

    # Test logical type.
    s = pl.Series(
        "a",
        [
            [datetime.date(1999, 1, 1), datetime.date(2000, 1, 1)],
            [datetime.date(2001, 10, 1), None],
            [None, None],
        ],
        dtype=pl.Array(pl.Date, 2),
    )
    with pytest.raises(ComputeError, match="get index is out of bounds"):
        s.arr.get(pl.Series([1, -2, 4]), null_on_oob=False)


def test_array_get_null_on_oob() -> None:
    """Check that `arr.get(..., null_on_oob=True)` yields null for out-of-bounds indices."""
    s = pl.Series(
        "a",
        [[1, 2, 3, 4], [5, 6, None, None], [7, 8, 9, 10]],
        dtype=pl.Array(pl.Int64, 4),
    )

    # Test index literal.
    out = s.arr.get(1, null_on_oob=True)
    expected = pl.Series("a", [2, 6, 8], dtype=pl.Int64)
    assert_series_equal(out, expected)

    # Null index literal.
    out_df = s.to_frame().select(pl.col.a.arr.get(pl.lit(None), null_on_oob=True))
    expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame()
    assert_frame_equal(out_df, expected_df)

    # Out-of-bounds index literal.
    out = s.arr.get(100, null_on_oob=True)
    expected = pl.Series("a", [None, None, None], dtype=pl.Int64)
    assert_series_equal(out, expected)

    # Negative index literal.
    out = s.arr.get(-2, null_on_oob=True)
    expected = pl.Series("a", [3, None, 9], dtype=pl.Int64)
    assert_series_equal(out, expected)

    # Test index expr.
    out = s.arr.get(pl.Series([1, -2, 100]), null_on_oob=True)
    expected = pl.Series("a", [2, None, None], dtype=pl.Int64)
    assert_series_equal(out, expected)

    # Test logical type.
    s = pl.Series(
        "a",
        [
            [datetime.date(1999, 1, 1), datetime.date(2000, 1, 1)],
            [datetime.date(2001, 10, 1), None],
            [None, None],
        ],
        dtype=pl.Array(pl.Date, 2),
    )
    out = s.arr.get(pl.Series([1, -2, 4]), null_on_oob=True)
    expected = pl.Series(
        "a",
        [datetime.date(2000, 1, 1), datetime.date(2001, 10, 1), None],
        dtype=pl.Date,
    )
    assert_series_equal(out, expected)


def test_arr_first_last() -> None:
    """Check `arr.first` / `arr.last`, including rows with leading or all nulls."""
    s = pl.Series(
        "a",
        [[1, 2, 3], [None, 5, 6], [None, None, None]],
        dtype=pl.Array(pl.Int64, 3),
    )

    first = s.arr.first()
    expected_first = pl.Series(
        "a",
        [1, None, None],
        dtype=pl.Int64,
    )
    assert_series_equal(first, expected_first)

    last = s.arr.last()
    expected_last = pl.Series(
        "a",
        [3, 6, None],
        dtype=pl.Int64,
    )
    assert_series_equal(last, expected_last)


@pytest.mark.parametrize(
    ("data", "candidates", "dtype"),
    [
        ([1, 2], [[1, 2], [3, 4]], pl.Int64),
        ([True, False], [[True, False], [True, True]], pl.Boolean),
        (["a", "b"], [["a", "b"], ["c", "d"]], pl.String),
        ([b"a", b"b"], [[b"a", b"b"], [b"c", b"d"]], pl.Binary),
        (
            [{"a": 1}, {"a": 2}],
            [[{"a": 1}, {"a": 2}], [{"b": 1}, {"a": 3}]],
            pl.Struct([pl.Field("a", pl.Int64)]),
        ),
    ],
)
def test_is_in_array(
    data: list[Any], candidates: list[list[Any]], dtype: pl.DataType
) -> None:
    """Check `is_in` of a scalar column against an Array column for several dtypes."""
    df = pl.DataFrame(
        {"a": data, "arr": candidates},
        schema={"a": dtype, "arr": pl.Array(dtype, 2)},
    )
    out = df.select(is_in=pl.col("a").is_in(pl.col("arr"))).to_series()
    expected = pl.Series("is_in", [True, False])
    assert_series_equal(out, expected)


def test_array_join() -> None:
    """Check `arr.join` with literal and column separators and both `ignore_nulls` modes."""
    df = pl.DataFrame(
        {
            "a": [["ab", "c", "d"], ["e", "f", "g"], [None, None, None], None],
            "separator": ["&", None, "*", "_"],
        },
        schema={
            "a": pl.Array(pl.String, 3),
            "separator": pl.String,
        },
    )
    out = df.select(pl.col("a").arr.join("-"))
    assert out.to_dict(as_series=False) == {"a": ["ab-c-d", "e-f-g", "", None]}
    out = df.select(pl.col("a").arr.join(pl.col("separator")))
    assert out.to_dict(as_series=False) == {"a": ["ab&c&d", None, "", None]}

    # test ignore_nulls argument
    df = pl.DataFrame(
        {
            "a": [
                ["a", None, "b", None],
                None,
                [None, None, None, None],
                ["c", "d", "e", "f"],
            ],
            "separator": ["-", "&", " ", "@"],
        },
        schema={
            "a": pl.Array(pl.String, 4),
            "separator": pl.String,
        },
    )
    # ignore nulls
    out = df.select(pl.col("a").arr.join("-", ignore_nulls=True))
    assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c-d-e-f"]}
    out = df.select(pl.col("a").arr.join(pl.col("separator"), ignore_nulls=True))
    assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c@d@e@f"]}
    # propagate nulls
    out = df.select(pl.col("a").arr.join("-", ignore_nulls=False))
    assert out.to_dict(as_series=False) == {"a": [None, None, None, "c-d-e-f"]}
    out = df.select(pl.col("a").arr.join(pl.col("separator"), ignore_nulls=False))
    assert out.to_dict(as_series=False) == {"a": [None, None, None, "c@d@e@f"]}


def test_array_explode() -> None:
    """Check `arr.explode` for string, nested-list and date arrays."""
    df = pl.DataFrame(
        {
            "str": [["a", "b"], ["c", None], None],
            "nested": [[[1, 2], [3]], [[], [4, None]], None],
            "logical": [
                [datetime.date(1998, 1, 1), datetime.date(2000, 10, 1)],
                [datetime.date(2024, 1, 1), None],
                None,
            ],
        },
        schema={
            "str": pl.Array(pl.String, 2),
            "nested": pl.Array(pl.List(pl.Int64), 2),
            "logical": pl.Array(pl.Date, 2),
        },
    )
    out = df.select(pl.all().arr.explode())
    expected = pl.DataFrame(
        {
            "str": ["a", "b", "c", None, None],
            "nested": [[1, 2], [3], [], [4, None], None],
            "logical": [
                datetime.date(1998, 1, 1),
                datetime.date(2000, 10, 1),
                datetime.date(2024, 1, 1),
                None,
                None,
            ],
        }
    )
    assert_frame_equal(out, expected)

    # test no-null fast path
    s = pl.Series(
        [
            [datetime.date(1998, 1, 1), datetime.date(1999, 1, 3)],
            [datetime.date(2000, 1, 1), datetime.date(2023, 10, 1)],
        ],
        dtype=pl.Array(pl.Date, 2),
    )
    out_s = s.arr.explode()
    expected_s = pl.Series(
        [
            datetime.date(1998, 1, 1),
            datetime.date(1999, 1, 3),
            datetime.date(2000, 1, 1),
            datetime.date(2023, 10, 1),
        ],
        dtype=pl.Date,
    )
    assert_series_equal(out_s, expected_s)


@pytest.mark.parametrize(
    ("arr", "data", "expected", "dtype"),
    [
        ([[1, 2], [3, None], None], 1, [1, 0, None], pl.Int64),
        ([[True, False], [True, None], None], True, [1, 1, None], pl.Boolean),
        ([["a", "b"], ["c", None], None], "a", [1, 0, None], pl.String),
        ([[b"a", b"b"], [b"c", None], None], b"a", [1, 0, None], pl.Binary),
    ],
)
def test_array_count_matches(
    arr: list[list[Any] | None], data: Any, expected: list[Any], dtype: pl.DataType
) -> None:
    """Check `arr.count_matches` for several dtypes, with null elements and null rows."""
    df = pl.DataFrame({"arr": arr}, schema={"arr": pl.Array(dtype, 2)})
    out = df.select(count_matches=pl.col("arr").arr.count_matches(data))
    assert out.to_dict(as_series=False) == {"count_matches": expected}


def test_array_count_matches_wildcard_expansion() -> None:
    """Check `arr.count_matches` applied through `pl.all()` to several columns."""
    df = pl.DataFrame(
        {"a": [[1, 2]], "b": [[3, 4]]},
        schema={"a": pl.Array(pl.Int64, 2), "b": pl.Array(pl.Int64, 2)},
    )
    assert df.select(pl.all().arr.count_matches(3)).to_dict(as_series=False) == {
        "a": [0],
        "b": [1],
    }


def test_array_to_struct() -> None:
    """Check `arr.to_struct` with default and callable field names, eager and lazy."""
    df = pl.DataFrame(
        {"a": [[1, 2, 3], [4, 5, None]]}, schema={"a": pl.Array(pl.Int8, 3)}
    )
    assert df.select([pl.col("a").arr.to_struct()]).to_series().to_list() == [
        {"field_0": 1, "field_1": 2, "field_2": 3},
        {"field_0": 4, "field_1": 5, "field_2": None},
    ]

    df = pl.DataFrame(
        {"a": [[1, 2, None], [1, 2, 3]]}, schema={"a": pl.Array(pl.Int8, 3)}
    )
    assert df.select(
        pl.col("a").arr.to_struct(fields=lambda idx: f"col_name_{idx}")
    ).to_series().to_list() == [
        {"col_name_0": 1, "col_name_1": 2, "col_name_2": None},
        {"col_name_0": 1, "col_name_1": 2, "col_name_2": 3},
    ]

    assert df.lazy().select(pl.col("a").arr.to_struct()).unnest(
        "a"
    ).sum().collect().columns == ["field_0", "field_1", "field_2"]


def test_array_shift() -> None:
    """Check `arr.shift` with a literal offset and a per-row offset column."""
    df = pl.DataFrame(
        {"a": [[1, 2, 3], None, [4, 5, 6], [7, 8, 9]], "n": [None, 1, 1, -2]},
        schema={"a": pl.Array(pl.Int64, 3), "n": pl.Int64},
    )

    out = df.select(
        lit=pl.col("a").arr.shift(1), expr=pl.col("a").arr.shift(pl.col("n"))
    )
    expected = pl.DataFrame(
        {
            "lit": [[None, 1, 2], None, [None, 4, 5], [None, 7, 8]],
            "expr": [None, None, [None, 4, 5], [9, None, None]],
        },
        schema={"lit": pl.Array(pl.Int64, 3), "expr": pl.Array(pl.Int64, 3)},
    )
    assert_frame_equal(out, expected)


def test_array_n_unique() -> None:
    """Check `arr.n_unique` counts and its UInt32 output dtype."""
    df = pl.DataFrame(
        {
            "a": [[1, 1, 2], [3, 3, 3], [None, None, None], None],
        },
        schema={"a": pl.Array(pl.Int64, 3)},
    )

    out = df.select(n_unique=pl.col("a").arr.n_unique())
    expected = pl.DataFrame(
        {"n_unique": [2, 1, 1, None]}, schema={"n_unique": pl.UInt32}
    )
    assert_frame_equal(out, expected)


def test_explode_19049() -> None:
    """Check `arr.explode` via `pl.col.a` and the error raised on a non-Array column."""
    df = pl.DataFrame({"a": [[1, 2, 3]]}, schema={"a": pl.Array(pl.Int64, 3)})
    result_df = df.select(pl.col.a.arr.explode())
    expected_df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64})
    assert_frame_equal(result_df, expected_df)

    df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64})
    with pytest.raises(InvalidOperationError, match="expected Array type, got: i64"):
        df.select(pl.col.a.arr.explode())


def test_array_join_unequal_lengths_22018() -> None:
    """Check that `arr.join` raises ShapeError for a separator Series of mismatched length."""
    df = pl.DataFrame(
        [
            pl.Series(
                "a",
                [
                    ["a", "b", "d"],
                    ["ya", "x", "y"],
                    ["ya", "x", "y"],
                ],
                pl.Array(pl.String, 3),
            ),
        ]
    )
    with pytest.raises(ShapeError):
        df.select(pl.col.a.arr.join(pl.Series([",", "-"])))


def test_array_shift_unequal_lengths_22018() -> None:
    """Check that `arr.shift` raises ShapeError for an offset Series of mismatched length."""
    with pytest.raises(ShapeError):
        pl.Series(
            "a",
            [
                ["a", "b", "d"],
                ["a", "b", "d"],
                ["a", "b", "d"],
            ],
            pl.Array(pl.String, 3),
        ).arr.shift(pl.Series([1, 2]))


def test_array_shift_self_broadcast_22124() -> None:
    """Check that a single-row array Series broadcasts against a longer offset Series."""
    assert_series_equal(
        pl.Series(
            "a",
            [
                ["a", "b", "d"],
            ],
            pl.Array(pl.String, 3),
        ).arr.shift(pl.Series([1, 2])),
        pl.Series(
            "a",
            [
                [None, "a", "b"],
                [None, None, "a"],
            ],
            pl.Array(pl.String, 3),
        ),
    )


def test_arr_contains() -> None:
    """Check `arr.contains` for null and non-null needles under both `nulls_equal` modes."""
    s = pl.Series([[1, 2, None], [None, None, None], None], dtype=pl.Array(pl.Int64, 3))

    assert_series_equal(
        s.arr.contains(None, nulls_equal=False),
        pl.Series([None, None, None], dtype=pl.Boolean),
    )
    assert_series_equal(
        s.arr.contains(None, nulls_equal=True),
        pl.Series([True, True, None], dtype=pl.Boolean),
    )
    assert_series_equal(
        s.arr.contains(1, nulls_equal=False),
        pl.Series([True, False, None], dtype=pl.Boolean),
    )
    assert_series_equal(
        s.arr.contains(1, nulls_equal=True),
        pl.Series([True, False, None], dtype=pl.Boolean),
    )