Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_interpolate.py
6939 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, time, timedelta
4
from typing import TYPE_CHECKING, Any
5
6
import pytest
7
8
import polars as pl
9
from polars.testing import assert_frame_equal
10
from tests.unit.conftest import NUMERIC_DTYPES
11
12
if TYPE_CHECKING:
13
from polars._typing import InterpolationMethod, PolarsDataType, PolarsTemporalType
14
15
from zoneinfo import ZoneInfo
16
17
18
@pytest.mark.parametrize(
    ("input_dtype", "output_dtype"),
    [
        # Every integer dtype below is expected to come back as Float64.
        *(
            (integer_dtype, pl.Float64)
            for integer_dtype in (
                pl.Int8,
                pl.Int16,
                pl.Int32,
                pl.Int64,
                pl.Int128,
                pl.UInt8,
                pl.UInt16,
                pl.UInt32,
                pl.UInt64,
            )
        ),
        # Float dtypes are expected to keep their own precision.
        (pl.Float32, pl.Float32),
        (pl.Float64, pl.Float64),
    ],
)
def test_interpolate_linear(
    input_dtype: PolarsDataType, output_dtype: PolarsDataType
) -> None:
    """Linear interpolation fills the nulls and reports ``output_dtype``."""
    lf = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})
    interpolated = lf.with_columns(pl.all().interpolate(method="linear"))

    # The lazily resolved schema must agree with the parametrized output dtype.
    assert interpolated.collect_schema()["a"] == output_dtype

    # Each null sits between two known values and is expected at their midpoint.
    expected = pl.DataFrame(
        {"a": [1.0, 1.5, 2.0, 2.5, 3.0]},
        schema={"a": output_dtype},
    )
    assert_frame_equal(interpolated.collect(), expected)
44
45
46
@pytest.mark.parametrize(
    ("input", "input_dtype", "output"),
    [
        # Date: the expected fill for the one-day gap is the earlier date.
        (
            [date(2020, 1, 1), None, date(2020, 1, 2)],
            pl.Date,
            [date(2020, 1, 1), date(2020, 1, 1), date(2020, 1, 2)],
        ),
        # Naive datetime: the gap is filled with noon of the first day.
        (
            [datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
            pl.Datetime("ms"),
            [
                datetime(2020, 1, 1),
                datetime(2020, 1, 1, hour=12),
                datetime(2020, 1, 2),
            ],
        ),
        # Time-zone-aware datetime: same midpoint, expressed in the same zone.
        (
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                None,
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
            pl.Datetime("us", "Asia/Kathmandu"),
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 1, hour=12, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
        ),
        # Time of day: 01:00 and 02:00 are bridged by 01:30.
        (
            [time(hour=1), None, time(hour=2)],
            pl.Time,
            [time(hour=1), time(hour=1, minute=30), time(hour=2)],
        ),
        # Duration: one day and two days are bridged by a day and a half.
        (
            [timedelta(days=1), None, timedelta(days=2)],
            pl.Duration("ms"),
            [timedelta(days=1), timedelta(days=1, hours=12), timedelta(days=2)],
        ),
    ],
)
def test_interpolate_temporal_linear(
    input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
    """Linear interpolation on temporal columns keeps the input dtype."""
    column_schema = {"a": input_dtype}
    lf = pl.LazyFrame({"a": input}, schema=column_schema)
    interpolated = lf.with_columns(pl.all().interpolate(method="linear"))

    # Temporal dtypes are expected to pass through unchanged.
    assert interpolated.collect_schema()["a"] == input_dtype

    expected = pl.DataFrame({"a": output}, schema=column_schema)
    assert_frame_equal(interpolated.collect(), expected)
88
89
90
@pytest.mark.parametrize("input_dtype", NUMERIC_DTYPES)
def test_interpolate_nearest(input_dtype: PolarsDataType) -> None:
    """Nearest interpolation fills the nulls and keeps the numeric dtype."""
    column_schema = {"a": input_dtype}
    lf = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema=column_schema)
    filled = lf.with_columns(pl.all().interpolate(method="nearest"))

    # Unlike the linear method, the dtype is expected to stay as given.
    assert filled.collect_schema()["a"] == input_dtype

    # In this fixture each null is expected to take the value that follows it.
    expected = pl.DataFrame({"a": [1, 2, 2, 3, 3]}, schema=column_schema)
    assert_frame_equal(filled.collect(), expected)
97
98
99
@pytest.mark.parametrize(
    ("input", "input_dtype", "output"),
    [
        # In every case the single null is expected to take the later value.
        # Date
        (
            [date(2020, 1, 1), None, date(2020, 1, 2)],
            pl.Date,
            [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)],
        ),
        # Naive datetime
        (
            [datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
            pl.Datetime("ms"),
            [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)],
        ),
        # Time-zone-aware datetime
        (
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                None,
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
            pl.Datetime("us", "Asia/Kathmandu"),
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
        ),
        # Time of day
        (
            [time(hour=1), None, time(hour=2)],
            pl.Time,
            [time(hour=1), time(hour=2), time(hour=2)],
        ),
        # Duration
        (
            [timedelta(days=1), None, timedelta(days=2)],
            pl.Duration("ms"),
            [timedelta(days=1), timedelta(days=2), timedelta(days=2)],
        ),
    ],
)
def test_interpolate_temporal_nearest(
    input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
    """Nearest interpolation on temporal columns keeps the input dtype."""
    column_schema = {"a": input_dtype}
    lf = pl.LazyFrame({"a": input}, schema=column_schema)
    filled = lf.with_columns(pl.all().interpolate(method="nearest"))

    # Temporal dtypes are expected to pass through unchanged.
    assert filled.collect_schema()["a"] == input_dtype

    expected = pl.DataFrame({"a": output}, schema=column_schema)
    assert_frame_equal(filled.collect(), expected)
141
142
143
@pytest.mark.parametrize(
    ("input", "scale", "method", "output"),
    # note the lack of rounding (1.66 vs 1.67)
    [
        ([1.0, None, 3.0], 2, "linear", [1.0, 2.0, 3.0]),
        ([1.0, None, None, 2.0], 2, "linear", [1.0, 1.33, 1.66, 2.0]),
        ([1.0, None, 3.0], 2, "nearest", [1.0, 3.0, 3.0]),
        ([1.0, None, None, 2.0], 2, "nearest", [1.0, 1.0, 2.0, 2.0]),
    ],
)
def test_interpolate_decimal_22475(
    input: list[Any], scale: int, method: InterpolationMethod, output: list[Any]
) -> None:
    """Interpolating a Decimal column keeps the Decimal dtype, eagerly and lazily.

    Parameters
    ----------
    input
        Float values (with nulls) that are cast to Decimal before interpolating.
    scale
        Decimal scale applied to both the input column and the expected column.
    method
        Interpolation method passed to ``Expr.interpolate``.
    output
        Expected values after interpolation, before the cast to Decimal.
    """
    # One dtype for both sides: the expected frame must follow the parametrized
    # scale rather than a hard-coded one, or a new case with another scale
    # would be compared against the wrong dtype.
    decimal_dtype = pl.Decimal(scale=scale)

    df = pl.DataFrame({"data": input})
    df_decimal = df.with_columns(pl.col("data").cast(decimal_dtype))
    out = df_decimal.with_columns(pl.col("data").interpolate(method=method))
    expected = pl.DataFrame({"data": output}).with_columns(
        pl.col("data").cast(decimal_dtype)
    )
    assert_frame_equal(out, expected)

    # The schema resolved by the lazy planner must match what is materialized.
    q = df_decimal.lazy().with_columns(pl.col("data").interpolate(method=method))
    assert q.collect_schema() == q.collect().schema
166
167