# Path: blob/main/py-polars/tests/unit/operations/test_interpolate_by.py
# (source page metadata: 6939 views)
from __future__ import annotations12from datetime import date3from typing import TYPE_CHECKING45import hypothesis.strategies as st6import numpy as np7import pytest8from hypothesis import assume, given910import polars as pl11from polars.exceptions import InvalidOperationError12from polars.testing import assert_frame_equal, assert_series_equal13from polars.testing.parametric import column, dataframes1415if TYPE_CHECKING:16from polars._typing import PolarsDataType171819@pytest.mark.parametrize(20"times_dtype",21[22pl.Datetime("ms"),23pl.Datetime("us", "Asia/Kathmandu"),24pl.Datetime("ns"),25pl.Date,26pl.Int64,27pl.Int32,28pl.UInt64,29pl.UInt32,30pl.Float32,31pl.Float64,32],33)34@pytest.mark.parametrize(35"values_dtype",36[37pl.Float64,38pl.Float32,39pl.Int64,40pl.Int32,41pl.UInt64,42pl.UInt32,43],44)45def test_interpolate_by(46values_dtype: PolarsDataType, times_dtype: PolarsDataType47) -> None:48df = pl.DataFrame(49{50"times": [511,523,5310,5411,5512,5616,5721,5830,59],60"values": [1, None, None, 5, None, None, None, 6],61},62schema={"times": times_dtype, "values": values_dtype},63)64result = df.select(pl.col("values").interpolate_by("times"))65expected = pl.DataFrame(66{67"values": [681.0,691.7999999999999998,704.6,715.0,725.052631578947368,735.2631578947368425,745.526315789473684,756.0,76]77}78)79if values_dtype == pl.Float32:80expected = expected.select(pl.col("values").cast(pl.Float32))81assert_frame_equal(result, expected)82result = (83df.sort("times", descending=True)84.with_columns(pl.col("values").interpolate_by("times"))85.sort("times")86.drop("times")87)88assert_frame_equal(result, expected)899091def test_interpolate_by_leading_nulls() -> None:92df = pl.DataFrame(93{94"times": [95date(2020, 1, 1),96date(2020, 1, 1),97date(2020, 1, 1),98date(2020, 1, 1),99date(2020, 1, 3),100date(2020, 1, 10),101date(2020, 1, 11),102],103"values": [None, None, None, 1, None, None, 5],104}105)106result = df.select(pl.col("values").interpolate_by("times"))107expected = 
pl.DataFrame({"values": [None, None, None, 1.0, 1.8, 4.6, 5.0]})108assert_frame_equal(result, expected)109result = (110df.sort("times", maintain_order=True, descending=True)111.with_columns(pl.col("values").interpolate_by("times"))112.sort("times", maintain_order=True)113.drop("times")114)115assert_frame_equal(result, expected, check_exact=False)116117118@pytest.mark.parametrize("dataset", ["floats", "dates"])119def test_interpolate_by_trailing_nulls(dataset: str) -> None:120input_data = {121"dates": pl.DataFrame(122{123"times": [124date(2020, 1, 1),125date(2020, 1, 3),126date(2020, 1, 10),127date(2020, 1, 11),128date(2020, 1, 12),129date(2020, 1, 13),130],131"values": [1, None, None, 5, None, None],132}133),134"floats": pl.DataFrame(135{136"times": [0.2, 0.4, 0.5, 0.6, 0.9, 1.1],137"values": [1, None, None, 5, None, None],138}139),140}141142expected_data = {143"dates": pl.DataFrame(144{"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]}145),146"floats": pl.DataFrame({"values": [1.0, 3.0, 4.0, 5.0, None, None]}),147}148149df = input_data[dataset]150expected = expected_data[dataset]151152result = df.select(pl.col("values").interpolate_by("times"))153154assert_frame_equal(result, expected)155result = (156df.sort("times", descending=True)157.with_columns(pl.col("values").interpolate_by("times"))158.sort("times")159.drop("times")160)161assert_frame_equal(result, expected)162163164@given(data=st.data(), x_dtype=st.sampled_from([pl.Date, pl.Float64]))165def test_interpolate_vs_numpy(data: st.DataObject, x_dtype: pl.DataType) -> None:166if x_dtype == pl.Float64:167by_strategy = st.floats(168min_value=-1e150,169max_value=1e150,170allow_nan=False,171allow_infinity=False,172allow_subnormal=False,173)174else:175by_strategy = None176177dataframe = 
(178data.draw(179dataframes(180[181column(182"ts",183dtype=x_dtype,184allow_null=False,185strategy=by_strategy,186),187column(188"value",189dtype=pl.Float64,190allow_null=True,191),192],193min_size=1,194)195)196.sort("ts")197.fill_nan(None)198.unique("ts")199)200201if x_dtype == pl.Float64:202assume(not dataframe["ts"].is_nan().any())203assume(not dataframe["ts"].is_null().any())204assume(not dataframe["ts"].is_in([float("-inf"), float("inf")]).any())205206assume(not dataframe["value"].is_null().all())207assume(not dataframe["value"].is_in([float("-inf"), float("inf")]).any())208209dataframe = dataframe.sort("ts")210211result = dataframe.select(pl.col("value").interpolate_by("ts"))["value"]212213mask = dataframe["value"].is_not_null()214215np_dtype = "int64" if x_dtype == pl.Date else "float64"216x = dataframe["ts"].to_numpy().astype(np_dtype)217xp = dataframe["ts"].filter(mask).to_numpy().astype(np_dtype)218yp = dataframe["value"].filter(mask).to_numpy().astype("float64")219interp = np.interp(x, xp, yp)220# Polars preserves nulls on boundaries, but NumPy doesn't.221first_non_null = dataframe["value"].is_not_null().arg_max()222last_non_null = len(dataframe) - dataframe["value"][::-1].is_not_null().arg_max() # type: ignore[operator]223interp[:first_non_null] = float("nan")224interp[last_non_null:] = float("nan")225expected = dataframe.with_columns(value=pl.Series(interp, nan_to_null=True))[226"value"227]228229# We increase the absolute error threshold, numpy has some instability, see #22348.230assert_series_equal(result, expected, abs_tol=1e-4)231result_from_unsorted = (232dataframe.sort("ts", descending=True)233.with_columns(pl.col("value").interpolate_by("ts"))234.sort("ts")["value"]235)236assert_series_equal(result_from_unsorted, expected, abs_tol=1e-4)237238239def test_interpolate_by_invalid() -> None:240s = pl.Series([1, None, 3])241by = pl.Series([1, 2])242with pytest.raises(InvalidOperationError, match=r"\(3\), got 2"):243s.interpolate_by(by)244245by = 
pl.Series([1, None, 3])246with pytest.raises(247InvalidOperationError,248match="null values in `by` column are not yet supported in 'interpolate_by'",249):250s.interpolate_by(by)251252253