# Path: blob/main/py-polars/tests/unit/operations/test_extend_constant.py
# 6939 views
from __future__ import annotations

from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.exceptions import ShapeError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


# (constant, dtype) pairs shared by the scalar and list-element tests below, so
# both tests always cover exactly the same cases.
# The temporal constants are fixed instead of `date.today()` / `datetime.now()`:
# those were evaluated at collection time, so every run (and every xdist worker)
# tested different values and a failure could not be reproduced.
_CONST_DTYPE_CASES = [
    (1, pl.Int8),
    (4, pl.UInt32),
    (4.5, pl.Float32),
    (None, pl.Float64),
    ("白鵬翔", pl.String),
    (date(2024, 1, 15), pl.Date),
    (datetime(2024, 1, 15, 12, 34, 56, 789012), pl.Datetime("ns")),
    (time(23, 59, 59), pl.Time),
    (timedelta(hours=7, seconds=123), pl.Duration("ms")),
]


@pytest.mark.parametrize(("const", "dtype"), _CONST_DTYPE_CASES)
def test_extend_constant(const: Any, dtype: PolarsDataType) -> None:
    """
    Test extend_constant on a single-null column and Series.

    Covers the expression API inside ``DataFrame.select`` and the Series API,
    with ``n`` given as an int or as ``pl.lit``, and ``value`` given as a plain
    Python scalar or as a typed ``pl.lit``.
    """
    df = pl.DataFrame({"a": pl.Series("s", [None], dtype=dtype)})

    expected_df = pl.DataFrame(
        {"a": pl.Series("s", [None, const, const, const], dtype=dtype)}
    )

    assert_frame_equal(df.select(pl.col("a").extend_constant(const, 3)), expected_df)

    s = pl.Series("s", [None], dtype=dtype)
    expected = pl.Series("s", [None, const, const, const], dtype=dtype)
    assert_series_equal(s.extend_constant(const, 3), expected)

    # test n expr
    expected = pl.Series("s", [None, const, const], dtype=dtype)
    assert_series_equal(s.extend_constant(const, pl.lit(2)), expected)

    # test value expr
    expected = pl.Series("s", [None, const, const, const], dtype=dtype)
    assert_series_equal(s.extend_constant(pl.lit(const, dtype=dtype), 3), expected)


@pytest.mark.parametrize(("const", "dtype"), _CONST_DTYPE_CASES)
def test_extend_constant_arr(const: Any, dtype: PolarsDataType) -> None:
    """
    Test extend_constant in pl.List array.

    NOTE: This function currently fails when the Series is a list with a single
    [None] value. Hence, this function does not begin with [[None]], but
    [[const]].
    """
    s = pl.Series("s", [[const]], dtype=pl.List(dtype))

    expected = pl.Series("s", [[const, const, const, const]], dtype=pl.List(dtype))

    assert_series_equal(s.list.eval(pl.element().extend_constant(const, 3)), expected)


def test_extend_by_not_uint_expr() -> None:
    """
    Non-scalar ``value`` or ``n`` arguments must raise ShapeError.

    Despite the name, this checks that passing a multi-element Series as either
    argument of ``Series.extend_constant`` is rejected.
    """
    s = pl.Series("s", [1])
    with pytest.raises(ShapeError, match="'value' must be a scalar value"):
        s.extend_constant(pl.Series([2, 3]), 3)
    with pytest.raises(ShapeError, match="'n' must be a scalar value"):
        s.extend_constant(2, pl.Series([3, 4]))