# Path: blob/main/py-polars/tests/unit/functions/test_repeat.py
# 6939 views
from __future__ import annotations12from datetime import date, datetime, time, timedelta3from typing import TYPE_CHECKING, Any45import pytest67import polars as pl8from polars.exceptions import ComputeError, SchemaError, ShapeError9from polars.testing import assert_frame_equal, assert_series_equal1011if TYPE_CHECKING:12from polars._typing import PolarsDataType131415@pytest.mark.parametrize(16("value", "n", "dtype", "expected_dtype"),17[18(2**31, 5, None, pl.Int64),19(2**31 - 1, 5, None, pl.Int32),20(-(2**31) - 1, 3, None, pl.Int64),21(-(2**31), 3, None, pl.Int32),22("foo", 2, None, pl.String),23(1.0, 5, None, pl.Float64),24(True, 4, None, pl.Boolean),25(None, 7, None, pl.Null),26(0, 0, None, pl.Int32),27(datetime(2023, 2, 2), 3, None, pl.Datetime),28(date(2023, 2, 2), 3, None, pl.Date),29(time(10, 15), 1, None, pl.Time),30(timedelta(hours=3), 10, None, pl.Duration),31(8, 2, pl.UInt8, pl.UInt8),32(date(2023, 2, 2), 3, pl.Datetime, pl.Datetime),33(7.5, 5, pl.UInt16, pl.UInt16),34([1, 2, 3], 2, pl.List(pl.Int64), pl.List(pl.Int64)),35(b"ab12", 3, pl.Binary, pl.Binary),36],37)38def test_repeat(39value: Any,40n: int,41dtype: PolarsDataType,42expected_dtype: PolarsDataType,43) -> None:44expected = pl.Series("repeat", [value] * n).cast(expected_dtype)4546result_eager = pl.repeat(value, n=n, dtype=dtype, eager=True)47assert_series_equal(result_eager, expected)4849result_lazy = pl.select(pl.repeat(value, n=n, dtype=dtype, eager=False)).to_series()50assert_series_equal(result_lazy, expected)515253def test_repeat_expr_input_eager() -> None:54result = pl.select(pl.repeat(1, n=pl.lit(3), eager=True)).to_series()55expected = pl.Series("repeat", [1, 1, 1], dtype=pl.Int32)56assert_series_equal(result, expected)575859def test_repeat_expr_input_lazy() -> None:60df = pl.DataFrame({"a": [3, 2, 1]})61result = df.select(pl.repeat(1, n=pl.col("a").first())).to_series()62expected = pl.Series("repeat", [1, 1, 1], dtype=pl.Int32)63assert_series_equal(result, expected)6465df = 
pl.DataFrame({"a": [3, 2, 1]})66assert df.select(pl.repeat(pl.sum("a"), n=2)).to_series().to_list() == [6, 6]676869def test_repeat_n_zero() -> None:70assert pl.repeat(1, n=0, eager=True).len() == 0717273@pytest.mark.parametrize(74"n",75[1.5, 2.0, date(1971, 1, 2), "hello"],76)77def test_repeat_n_non_integer(n: Any) -> None:78with pytest.raises(SchemaError, match="expected expression of dtype 'integer'"):79pl.repeat(1, n=pl.lit(n), eager=True)808182def test_repeat_n_empty() -> None:83df = pl.DataFrame(schema={"a": pl.Int32})84with pytest.raises(ShapeError, match="'n' must be a scalar value"):85df.select(pl.repeat(1, n=pl.col("a")))868788def test_repeat_n_negative() -> None:89with pytest.raises(ComputeError, match="could not parse value '-1' as a size"):90pl.repeat(1, n=-1, eager=True)919293@pytest.mark.parametrize(94("n", "value", "dtype"),95[96(2, 1, pl.UInt32),97(0, 1, pl.Int16),98(3, 1, pl.Float32),99(1, "1", pl.Utf8),100(2, ["1"], pl.List(pl.Utf8)),101(4, True, pl.Boolean),102(2, [True], pl.List(pl.Boolean)),103(2, [1], pl.Array(pl.Int16, shape=1)),104(2, [1, 1, 1], pl.Array(pl.Int8, shape=3)),105(1, [1], pl.List(pl.UInt32)),106],107)108def test_ones(109n: int,110value: Any,111dtype: PolarsDataType,112) -> None:113expected = pl.Series("ones", [value] * n, dtype=dtype)114115result_eager = pl.ones(n=n, dtype=dtype, eager=True)116assert_series_equal(result_eager, expected)117118result_lazy = pl.select(pl.ones(n=n, dtype=dtype, eager=False)).to_series()119assert_series_equal(result_lazy, expected)120121122@pytest.mark.parametrize(123("n", "value", "dtype"),124[125(2, 0, pl.UInt8),126(0, 0, pl.Int32),127(3, 0, pl.Float32),128(1, "0", pl.Utf8),129(2, ["0"], pl.List(pl.Utf8)),130(4, False, pl.Boolean),131(2, [False], pl.List(pl.Boolean)),132(3, [0], pl.Array(pl.UInt32, shape=1)),133(2, [0, 0, 0], pl.Array(pl.UInt32, shape=3)),134(1, [0], pl.List(pl.UInt32)),135],136)137def test_zeros(138n: int,139value: Any,140dtype: PolarsDataType,141) -> None:142expected = 
pl.Series("zeros", [value] * n, dtype=dtype)143144result_eager = pl.zeros(n=n, dtype=dtype, eager=True)145assert_series_equal(result_eager, expected)146147result_lazy = pl.select(pl.zeros(n=n, dtype=dtype, eager=False)).to_series()148assert_series_equal(result_lazy, expected)149150151def test_ones_zeros_misc() -> None:152# check we default to f64 if dtype is unspecified153s_ones = pl.ones(n=2, eager=True)154s_zeros = pl.zeros(n=2, eager=True)155156assert s_ones.dtype == s_zeros.dtype == pl.Float64157158# confirm that we raise a suitable error if dtype is invalid159with pytest.raises(TypeError, match="invalid dtype for `ones`"):160pl.ones(n=2, dtype=pl.Struct({"x": pl.Date, "y": pl.Duration}), eager=True)161162with pytest.raises(TypeError, match="invalid dtype for `zeros`"):163pl.zeros(n=2, dtype=pl.Struct({"x": pl.Date, "y": pl.Duration}), eager=True)164165166def test_repeat_by_logical_dtype() -> None:167df = pl.DataFrame(168{169"repeat": [1, 2, 3],170"date": [date(2021, 1, 1)] * 3,171"cat": ["a", "b", "c"],172},173schema={"repeat": pl.Int32, "date": pl.Date, "cat": pl.Categorical},174)175out = df.select(176pl.col("date").repeat_by("repeat"), pl.col("cat").repeat_by("repeat")177)178179expected_df = pl.DataFrame(180{181"date": [182[date(2021, 1, 1)],183[date(2021, 1, 1), date(2021, 1, 1)],184[date(2021, 1, 1), date(2021, 1, 1), date(2021, 1, 1)],185],186"cat": [["a"], ["b", "b"], ["c", "c", "c"]],187},188schema={"date": pl.List(pl.Date), "cat": pl.List(pl.Categorical)},189)190191assert_frame_equal(out, expected_df)192193194def test_repeat_by_list() -> None:195df = pl.DataFrame(196{197"repeat": [1, 2, 3, None],198"value": [None, [1, 2, 3], [4, None], [1, 2]],199},200schema={"repeat": pl.UInt32, "value": pl.List(pl.UInt8)},201)202out = df.select(pl.col("value").repeat_by("repeat"))203204expected_df = pl.DataFrame(205{206"value": [207[None],208[[1, 2, 3], [1, 2, 3]],209[[4, None], [4, None], [4, None]],210None,211],212},213schema={"value": 
pl.List(pl.List(pl.UInt8))},214)215216assert_frame_equal(out, expected_df)217218219def test_repeat_by_nested_list() -> None:220df = pl.DataFrame(221{222"repeat": [1, 2, 3],223"value": [None, [[1], [2, 2]], [[3, 3], None, [4, None]]],224},225schema={"repeat": pl.UInt32, "value": pl.List(pl.List(pl.Int16))},226)227out = df.select(pl.col("value").repeat_by("repeat"))228229expected_df = pl.DataFrame(230{231"value": [232[None],233[[[1], [2, 2]], [[1], [2, 2]]],234[235[[3, 3], None, [4, None]],236[[3, 3], None, [4, None]],237[[3, 3], None, [4, None]],238],239],240},241schema={"value": pl.List(pl.List(pl.List(pl.Int16)))},242)243244assert_frame_equal(out, expected_df)245246247def test_repeat_by_struct() -> None:248df = pl.DataFrame(249{250"repeat": [1, 2, 3],251"value": [None, {"a": 1, "b": 2}, {"a": 3, "b": None}],252},253schema={"repeat": pl.UInt32, "value": pl.Struct({"a": pl.Int8, "b": pl.Int32})},254)255out = df.select(pl.col("value").repeat_by("repeat"))256257expected_df = pl.DataFrame(258{259"value": [260[None],261[{"a": 1, "b": 2}, {"a": 1, "b": 2}],262[{"a": 3, "b": None}, {"a": 3, "b": None}, {"a": 3, "b": None}],263],264},265schema={"value": pl.List(pl.Struct({"a": pl.Int8, "b": pl.Int32}))},266)267268assert_frame_equal(out, expected_df)269270271def test_repeat_by_nested_struct() -> None:272df = pl.DataFrame(273{274"repeat": [1, 2, 3],275"value": [276None,277{"a": {"x": 1, "y": 1}, "b": 2},278{"a": {"x": None, "y": 3}, "b": None},279],280},281schema={282"repeat": pl.UInt32,283"value": pl.Struct(284{"a": pl.Struct({"x": pl.Int64, "y": pl.Int128}), "b": pl.Int32}285),286},287)288out = df.select(pl.col("value").repeat_by("repeat"))289290expected_df = pl.DataFrame(291{292"value": [293[None],294[{"a": {"x": 1, "y": 1}, "b": 2}, {"a": {"x": 1, "y": 1}, "b": 2}],295[296{"a": {"x": None, "y": 3}, "b": None},297{"a": {"x": None, "y": 3}, "b": None},298{"a": {"x": None, "y": 3}, "b": None},299],300],301},302schema={303"value": pl.List(304pl.Struct(305{"a": 
pl.Struct({"x": pl.Int64, "y": pl.Int128}), "b": pl.Int32}306)307)308},309)310311assert_frame_equal(out, expected_df)312313314def test_repeat_by_struct_in_list() -> None:315df = pl.DataFrame(316{317"repeat": [1, 2, 3],318"value": [319None,320[{"a": "foo", "b": "A"}, None],321[{"a": None, "b": "B"}, {"a": "test", "b": "B"}],322],323},324schema={325"repeat": pl.UInt32,326"value": pl.List(pl.Struct({"a": pl.String, "b": pl.Enum(["A", "B"])})),327},328)329out = df.select(pl.col("value").repeat_by("repeat"))330331expected_df = pl.DataFrame(332{333"value": [334[None],335[[{"a": "foo", "b": "A"}, None], [{"a": "foo", "b": "A"}, None]],336[337[{"a": None, "b": "B"}, {"a": "test", "b": "B"}],338[{"a": None, "b": "B"}, {"a": "test", "b": "B"}],339[{"a": None, "b": "B"}, {"a": "test", "b": "B"}],340],341],342},343schema={344"value": pl.List(345pl.List(pl.Struct({"a": pl.String, "b": pl.Enum(["A", "B"])}))346)347},348)349350assert_frame_equal(out, expected_df)351352353def test_repeat_by_list_in_struct() -> None:354df = pl.DataFrame(355{356"repeat": [1, 2, 3],357"value": [358None,359{"a": [1, 2, 3], "b": ["x", "y", None]},360{"a": [None, 5, 6], "b": None},361],362},363schema={364"repeat": pl.UInt32,365"value": pl.Struct({"a": pl.List(pl.Int8), "b": pl.List(pl.String)}),366},367)368out = df.select(pl.col("value").repeat_by("repeat"))369370expected_df = pl.DataFrame(371{372"value": [373[None],374[375{"a": [1, 2, 3], "b": ["x", "y", None]},376{"a": [1, 2, 3], "b": ["x", "y", None]},377],378[379{"a": [None, 5, 6], "b": None},380{"a": [None, 5, 6], "b": None},381{"a": [None, 5, 6], "b": None},382],383],384},385schema={386"value": pl.List(387pl.Struct({"a": pl.List(pl.Int8), "b": pl.List(pl.String)})388)389},390)391392assert_frame_equal(out, expected_df)393394395@pytest.mark.parametrize(396("data", "expected_data"),397[398(["a", "b", None], [["a", "a"], None, [None, None, None]]),399([1, 2, None], [[1, 1], None, [None, None, None]]),400([1.1, 2.2, None], [[1.1, 1.1], None, [None, 
None, None]]),401([True, False, None], [[True, True], None, [None, None, None]]),402],403)404def test_repeat_by_none_13053(data: list[Any], expected_data: list[list[Any]]) -> None:405df = pl.DataFrame({"x": data, "by": [2, None, 3]})406res = df.select(repeat=pl.col("x").repeat_by("by"))407expected = pl.Series("repeat", expected_data)408assert_series_equal(res.to_series(), expected)409410411def test_repeat_by_literal_none_20268() -> None:412df = pl.DataFrame({"x": ["a", "b"]})413expected = pl.Series("repeat", [None, None], dtype=pl.List(pl.String))414415res = df.select(repeat=pl.col("x").repeat_by(pl.lit(None)))416assert_series_equal(res.to_series(), expected)417418res = df.select(repeat=pl.col("x").repeat_by(None)) # type: ignore[arg-type]419assert_series_equal(res.to_series(), expected)420421422@pytest.mark.parametrize("value", [pl.Series([]), pl.Series([1, 2])])423def test_repeat_nonscalar_value(value: pl.Series) -> None:424with pytest.raises(ShapeError, match="'value' must be a scalar value"):425pl.select(pl.repeat(pl.Series(value), n=1))426427428@pytest.mark.parametrize("n", [[], [1, 2]])429def test_repeat_nonscalar_n(n: list[int]) -> None:430df = pl.DataFrame({"n": n})431with pytest.raises(ShapeError, match="'n' must be a scalar value"):432df.select(pl.repeat("a", pl.col("n")))433434435def test_repeat_value_first() -> None:436df = pl.DataFrame({"a": ["a", "b", "c"], "n": [4, 5, 6]})437result = df.select(rep=pl.repeat(pl.col("a").first(), n=pl.col("n").first()))438expected = pl.DataFrame({"rep": ["a", "a", "a", "a"]})439assert_frame_equal(result, expected)440441442def test_repeat_by_arr() -> None:443assert_series_equal(444pl.Series([["a", "b"], ["a", "c"]], dtype=pl.Array(pl.String, 2)).repeat_by(2),445pl.Series(446[[["a", "b"], ["a", "b"]], [["a", "c"], ["a", "c"]]],447dtype=pl.List(pl.Array(pl.String, 2)),448),449)450451452def test_repeat_by_null() -> None:453assert_series_equal(454pl.Series([None, None], dtype=pl.Null).repeat_by(2),455pl.Series([[None, 
None], [None, None]], dtype=pl.List(pl.Null)),456)457458459