Path: blob/main/py-polars/tests/unit/operations/test_interpolate.py
8422 views
"""Tests for ``Expr.interpolate`` on numeric, temporal and decimal columns."""

from __future__ import annotations

from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any
from zoneinfo import ZoneInfo

import pytest

import polars as pl
from polars.testing import assert_frame_equal
from tests.unit.conftest import NUMERIC_DTYPES

if TYPE_CHECKING:
    from polars._typing import InterpolationMethod, PolarsDataType, PolarsTemporalType


@pytest.mark.parametrize(
    ("input_dtype", "output_dtype"),
    [
        (pl.Int8, pl.Float64),
        (pl.Int16, pl.Float64),
        (pl.Int32, pl.Float64),
        (pl.Int64, pl.Float64),
        (pl.Int128, pl.Float64),
        (pl.UInt8, pl.Float64),
        (pl.UInt16, pl.Float64),
        (pl.UInt32, pl.Float64),
        (pl.UInt64, pl.Float64),
        (pl.UInt128, pl.Float64),
        (pl.Float32, pl.Float32),
        (pl.Float64, pl.Float64),
    ],
)
def test_interpolate_linear(
    input_dtype: PolarsDataType, output_dtype: PolarsDataType
) -> None:
    """Linear interpolation yields Float64 for integer inputs, keeps float dtypes."""
    df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})
    result = df.with_columns(pl.all().interpolate(method="linear"))
    # The lazily resolved schema must already report the output dtype.
    assert result.collect_schema()["a"] == output_dtype
    expected = pl.DataFrame(
        {"a": [1.0, 1.5, 2.0, 2.5, 3.0]}, schema={"a": output_dtype}
    )
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize(
    ("input", "input_dtype", "output"),
    [
        (
            [date(2020, 1, 1), None, date(2020, 1, 2)],
            pl.Date,
            [date(2020, 1, 1), date(2020, 1, 1), date(2020, 1, 2)],
        ),
        (
            [datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
            pl.Datetime("ms"),
            [datetime(2020, 1, 1), datetime(2020, 1, 1, 12), datetime(2020, 1, 2)],
        ),
        (
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                None,
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
            pl.Datetime("us", "Asia/Kathmandu"),
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 1, 12, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
        ),
        ([time(1), None, time(2)], pl.Time, [time(1), time(1, 30), time(2)]),
        (
            [timedelta(1), None, timedelta(2)],
            pl.Duration("ms"),
            [timedelta(1), timedelta(1, hours=12), timedelta(2)],
        ),
    ],
)
def test_interpolate_temporal_linear(
    input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
    """Linear interpolation keeps temporal dtypes (including time zone and unit).

    The Date case expects the midpoint between two consecutive days to land on
    the earlier day rather than being rounded up.
    """
    df = pl.LazyFrame({"a": input}, schema={"a": input_dtype})
    result = df.with_columns(pl.all().interpolate(method="linear"))
    assert result.collect_schema()["a"] == input_dtype
    expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize("input_dtype", NUMERIC_DTYPES)
def test_interpolate_nearest(input_dtype: PolarsDataType) -> None:
    """Nearest interpolation preserves each dtype in ``NUMERIC_DTYPES``.

    Each gap sits exactly halfway between its neighbours; the expected output
    fills it with the later (right-hand) value.
    """
    df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})
    result = df.with_columns(pl.all().interpolate(method="nearest"))
    assert result.collect_schema()["a"] == input_dtype
    expected = pl.DataFrame({"a": [1, 2, 2, 3, 3]}, schema={"a": input_dtype})
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize(
    ("input", "input_dtype", "output"),
    [
        (
            [date(2020, 1, 1), None, date(2020, 1, 2)],
            pl.Date,
            [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)],
        ),
        (
            [datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
            pl.Datetime("ms"),
            [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)],
        ),
        (
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                None,
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
            pl.Datetime("us", "Asia/Kathmandu"),
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
        ),
        ([time(1), None, time(2)], pl.Time, [time(1), time(2), time(2)]),
        (
            [timedelta(1), None, timedelta(2)],
            pl.Duration("ms"),
            [timedelta(1), timedelta(2), timedelta(2)],
        ),
    ],
)
def test_interpolate_temporal_nearest(
    input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
    """Nearest interpolation keeps temporal dtypes (including time zone and unit).

    As in the numeric case, a gap exactly halfway between two values is
    expected to take the later one.
    """
    df = pl.LazyFrame({"a": input}, schema={"a": input_dtype})
    result = df.with_columns(pl.all().interpolate(method="nearest"))
    assert result.collect_schema()["a"] == input_dtype
    expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})
    assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize(
    ("input", "scale", "method", "output"),
    # note the lack of rounding (1.66 vs 1.67)
    [
        ([1.0, None, 3.0], 2, "linear", [1.0, 2.0, 3.0]),
        ([1.0, None, None, 2.0], 2, "linear", [1.0, 1.33, 1.66, 2.0]),
        ([1.0, None, 3.0], 2, "nearest", [1.0, 3.0, 3.0]),
        ([1.0, None, None, 2.0], 2, "nearest", [1.0, 1.0, 2.0, 2.0]),
    ],
)
def test_interpolate_decimal_22475(
    input: list[Any], scale: int, method: InterpolationMethod, output: list[Any]
) -> None:
    """Interpolating a Decimal column keeps the Decimal dtype at ``scale``.

    Regression test, presumably for polars issue #22475. Linear results are
    expected to be truncated to ``scale`` digits rather than rounded, and the
    lazy engine must report the same schema it actually produces.
    """
    df = pl.DataFrame({"data": input})
    df_decimal = df.with_columns(pl.col("data").cast(pl.Decimal(scale=scale)))
    out = df_decimal.with_columns(pl.col("data").interpolate(method=method))
    # Cast the expectation with the parametrized scale (not a hard-coded 2) so
    # the expected dtype always follows the ``scale`` parameter.
    expected = pl.DataFrame({"data": output}).with_columns(
        pl.col("data").cast(pl.Decimal(scale=scale))
    )
    assert_frame_equal(out, expected)

    q = df_decimal.lazy().with_columns(pl.col("data").interpolate(method=method))
    result = q.collect()
    # The lazily resolved schema must match what collection actually yields,
    # and the lazy result must agree with the eager one.
    assert q.collect_schema() == result.schema
    assert_frame_equal(result, expected)