# Path: blob/main/py-polars/tests/unit/operations/test_interpolate.py
# (scraped page metadata: 6939 views)
from __future__ import annotations12from datetime import date, datetime, time, timedelta3from typing import TYPE_CHECKING, Any45import pytest67import polars as pl8from polars.testing import assert_frame_equal9from tests.unit.conftest import NUMERIC_DTYPES1011if TYPE_CHECKING:12from polars._typing import InterpolationMethod, PolarsDataType, PolarsTemporalType1314from zoneinfo import ZoneInfo151617@pytest.mark.parametrize(18("input_dtype", "output_dtype"),19[20(pl.Int8, pl.Float64),21(pl.Int16, pl.Float64),22(pl.Int32, pl.Float64),23(pl.Int64, pl.Float64),24(pl.Int128, pl.Float64),25(pl.UInt8, pl.Float64),26(pl.UInt16, pl.Float64),27(pl.UInt32, pl.Float64),28(pl.UInt64, pl.Float64),29(pl.Float32, pl.Float32),30(pl.Float64, pl.Float64),31],32)33def test_interpolate_linear(34input_dtype: PolarsDataType, output_dtype: PolarsDataType35) -> None:36df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})37result = df.with_columns(pl.all().interpolate(method="linear"))38assert result.collect_schema()["a"] == output_dtype39expected = pl.DataFrame(40{"a": [1.0, 1.5, 2.0, 2.5, 3.0]}, schema={"a": output_dtype}41)42assert_frame_equal(result.collect(), expected)434445@pytest.mark.parametrize(46("input", "input_dtype", "output"),47[48(49[date(2020, 1, 1), None, date(2020, 1, 2)],50pl.Date,51[date(2020, 1, 1), date(2020, 1, 1), date(2020, 1, 2)],52),53(54[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],55pl.Datetime("ms"),56[datetime(2020, 1, 1), datetime(2020, 1, 1, 12), datetime(2020, 1, 2)],57),58(59[60datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),61None,62datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),63],64pl.Datetime("us", "Asia/Kathmandu"),65[66datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),67datetime(2020, 1, 1, 12, tzinfo=ZoneInfo("Asia/Kathmandu")),68datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),69],70),71([time(1), None, time(2)], pl.Time, [time(1), time(1, 30), time(2)]),72(73[timedelta(1), None, 
timedelta(2)],74pl.Duration("ms"),75[timedelta(1), timedelta(1, hours=12), timedelta(2)],76),77],78)79def test_interpolate_temporal_linear(80input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]81) -> None:82df = pl.LazyFrame({"a": input}, schema={"a": input_dtype})83result = df.with_columns(pl.all().interpolate(method="linear"))84assert result.collect_schema()["a"] == input_dtype85expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})86assert_frame_equal(result.collect(), expected)878889@pytest.mark.parametrize("input_dtype", NUMERIC_DTYPES)90def test_interpolate_nearest(input_dtype: PolarsDataType) -> None:91df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})92result = df.with_columns(pl.all().interpolate(method="nearest"))93assert result.collect_schema()["a"] == input_dtype94expected = pl.DataFrame({"a": [1, 2, 2, 3, 3]}, schema={"a": input_dtype})95assert_frame_equal(result.collect(), expected)969798@pytest.mark.parametrize(99("input", "input_dtype", "output"),100[101(102[date(2020, 1, 1), None, date(2020, 1, 2)],103pl.Date,104[date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)],105),106(107[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],108pl.Datetime("ms"),109[datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)],110),111(112[113datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),114None,115datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),116],117pl.Datetime("us", "Asia/Kathmandu"),118[119datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),120datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),121datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),122],123),124([time(1), None, time(2)], pl.Time, [time(1), time(2), time(2)]),125(126[timedelta(1), None, timedelta(2)],127pl.Duration("ms"),128[timedelta(1), timedelta(2), timedelta(2)],129),130],131)132def test_interpolate_temporal_nearest(133input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]134) -> 
None:135df = pl.LazyFrame({"a": input}, schema={"a": input_dtype})136result = df.with_columns(pl.all().interpolate(method="nearest"))137assert result.collect_schema()["a"] == input_dtype138expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})139assert_frame_equal(result.collect(), expected)140141142@pytest.mark.parametrize(143("input", "scale", "method", "output"),144# note the lack of rounding (1.66 vs 1.67)145[146([1.0, None, 3.0], 2, "linear", [1.0, 2.0, 3.0]),147([1.0, None, None, 2.0], 2, "linear", [1.0, 1.33, 1.66, 2.0]),148([1.0, None, 3.0], 2, "nearest", [1.0, 3.0, 3.0]),149([1.0, None, None, 2.0], 2, "nearest", [1.0, 1.0, 2.0, 2.0]),150],151)152def test_interpolate_decimal_22475(153input: list[Any], scale: int, method: InterpolationMethod, output: list[Any]154) -> None:155df = pl.DataFrame({"data": input})156df_decimal = df.with_columns(pl.col("data").cast(pl.Decimal(scale=scale)))157out = df_decimal.with_columns(pl.col("data").interpolate(method=method))158expected = pl.DataFrame({"data": output}).with_columns(159pl.col("data").cast(pl.Decimal(scale=2))160)161assert_frame_equal(out, expected)162163q = df_decimal.lazy().with_columns(pl.col("data").interpolate(method=method))164assert q.collect_schema() == q.collect().schema165166167