Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_interpolate.py
8422 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, time, timedelta
4
from typing import TYPE_CHECKING, Any
5
6
import pytest
7
8
import polars as pl
9
from polars.testing import assert_frame_equal
10
from tests.unit.conftest import NUMERIC_DTYPES
11
12
if TYPE_CHECKING:
13
from polars._typing import InterpolationMethod, PolarsDataType, PolarsTemporalType
14
15
from zoneinfo import ZoneInfo
16
17
18
@pytest.mark.parametrize(
    ("input_dtype", "output_dtype"),
    [
        # Integer inputs are expected to come back as Float64.
        *(
            (int_dtype, pl.Float64)
            for int_dtype in (
                pl.Int8,
                pl.Int16,
                pl.Int32,
                pl.Int64,
                pl.Int128,
                pl.UInt8,
                pl.UInt16,
                pl.UInt32,
                pl.UInt64,
                pl.UInt128,
            )
        ),
        # Float inputs are expected to keep their own width.
        (pl.Float32, pl.Float32),
        (pl.Float64, pl.Float64),
    ],
)
def test_interpolate_linear(
    input_dtype: PolarsDataType, output_dtype: PolarsDataType
) -> None:
    """Check linear interpolation of a numeric column with interior nulls.

    The column ``[1, None, 2, None, 3]`` is built with ``input_dtype`` and
    interpolated lazily; both the resolved schema and the collected values
    are compared against ``output_dtype``.
    """
    lf = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})
    interpolated = lf.with_columns(pl.all().interpolate(method="linear"))

    # The schema is checked before collecting so a wrong lazy dtype is
    # reported on its own rather than as a frame mismatch.
    assert interpolated.collect_schema()["a"] == output_dtype

    expected = pl.DataFrame(
        {"a": [1.0, 1.5, 2.0, 2.5, 3.0]}, schema={"a": output_dtype}
    )
    assert_frame_equal(interpolated.collect(), expected)
45
46
47
@pytest.mark.parametrize(
    ("input", "input_dtype", "output"),
    [
        # Date: the expected fill for the gap is the earlier day.
        (
            [date(2020, 1, 1), None, date(2020, 1, 2)],
            pl.Date,
            [date(2020, 1, 1), date(2020, 1, 1), date(2020, 1, 2)],
        ),
        # Naive datetime: the expected fill is noon of the first day.
        (
            [datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
            pl.Datetime("ms"),
            [
                datetime(2020, 1, 1),
                datetime(2020, 1, 1, hour=12),
                datetime(2020, 1, 2),
            ],
        ),
        # Time-zone-aware datetime.
        (
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                None,
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
            pl.Datetime("us", "Asia/Kathmandu"),
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 1, hour=12, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
        ),
        # Time of day.
        (
            [time(hour=1), None, time(hour=2)],
            pl.Time,
            [time(hour=1), time(hour=1, minute=30), time(hour=2)],
        ),
        # Duration.
        (
            [timedelta(days=1), None, timedelta(days=2)],
            pl.Duration("ms"),
            [timedelta(days=1), timedelta(days=1, hours=12), timedelta(days=2)],
        ),
    ],
)
def test_interpolate_temporal_linear(
    input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
    """Check linear interpolation of temporal columns.

    Each case has a single null between two values. The interpolated column
    must keep ``input_dtype`` and match ``output``.
    """
    lf = pl.LazyFrame({"a": input}, schema={"a": input_dtype})
    interpolated = lf.with_columns(pl.all().interpolate(method="linear"))

    # The lazily resolved schema must report the unchanged temporal dtype.
    assert interpolated.collect_schema()["a"] == input_dtype

    expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})
    assert_frame_equal(interpolated.collect(), expected)
89
90
91
@pytest.mark.parametrize("input_dtype", NUMERIC_DTYPES)
def test_interpolate_nearest(input_dtype: PolarsDataType) -> None:
    """Check nearest-neighbour interpolation on a numeric column.

    The column ``[1, None, 2, None, 3]`` is built with ``input_dtype``. The
    result must keep that dtype, and each null is expected to take the value
    that follows it.
    """
    schema = {"a": input_dtype}
    lf = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema=schema)
    interpolated = lf.with_columns(pl.all().interpolate(method="nearest"))

    # The lazily resolved schema must report the unchanged dtype.
    assert interpolated.collect_schema()["a"] == input_dtype

    expected = pl.DataFrame({"a": [1, 2, 2, 3, 3]}, schema=schema)
    assert_frame_equal(interpolated.collect(), expected)
98
99
100
@pytest.mark.parametrize(
    ("input", "input_dtype", "output"),
    [
        # In every case the single null is expected to take the later value.
        # Date.
        (
            [date(2020, 1, 1), None, date(2020, 1, 2)],
            pl.Date,
            [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)],
        ),
        # Naive datetime.
        (
            [datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
            pl.Datetime("ms"),
            [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)],
        ),
        # Time-zone-aware datetime.
        (
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                None,
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
            pl.Datetime("us", "Asia/Kathmandu"),
            [
                datetime(2020, 1, 1, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
                datetime(2020, 1, 2, tzinfo=ZoneInfo("Asia/Kathmandu")),
            ],
        ),
        # Time of day.
        (
            [time(hour=1), None, time(hour=2)],
            pl.Time,
            [time(hour=1), time(hour=2), time(hour=2)],
        ),
        # Duration.
        (
            [timedelta(days=1), None, timedelta(days=2)],
            pl.Duration("ms"),
            [timedelta(days=1), timedelta(days=2), timedelta(days=2)],
        ),
    ],
)
def test_interpolate_temporal_nearest(
    input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
    """Check nearest-neighbour interpolation of temporal columns.

    Each case has a single null between two values. The interpolated column
    must keep ``input_dtype`` and match ``output``.
    """
    lf = pl.LazyFrame({"a": input}, schema={"a": input_dtype})
    interpolated = lf.with_columns(pl.all().interpolate(method="nearest"))

    # The lazily resolved schema must report the unchanged temporal dtype.
    assert interpolated.collect_schema()["a"] == input_dtype

    expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})
    assert_frame_equal(interpolated.collect(), expected)
142
143
144
@pytest.mark.parametrize(
    ("input", "scale", "method", "output"),
    # note the lack of rounding (1.66 vs 1.67)
    [
        ([1.0, None, 3.0], 2, "linear", [1.0, 2.0, 3.0]),
        ([1.0, None, None, 2.0], 2, "linear", [1.0, 1.33, 1.66, 2.0]),
        ([1.0, None, 3.0], 2, "nearest", [1.0, 3.0, 3.0]),
        ([1.0, None, None, 2.0], 2, "nearest", [1.0, 1.0, 2.0, 2.0]),
    ],
)
def test_interpolate_decimal_22475(
    input: list[Any], scale: int, method: InterpolationMethod, output: list[Any]
) -> None:
    """Check interpolation of a Decimal column (presumably issue #22475).

    ``input`` is cast to ``pl.Decimal(scale=scale)`` and interpolated with
    ``method``. The result is compared against ``output`` cast to the same
    Decimal dtype, and the lazy schema is checked against the collected one.
    """
    # Build the dtype once from the parametrized scale so that the input and
    # the expected frame cannot drift apart when a case uses another scale.
    decimal_dtype = pl.Decimal(scale=scale)

    df = pl.DataFrame({"data": input})
    df_decimal = df.with_columns(pl.col("data").cast(decimal_dtype))
    out = df_decimal.with_columns(pl.col("data").interpolate(method=method))

    expected = pl.DataFrame({"data": output}).with_columns(
        pl.col("data").cast(decimal_dtype)
    )
    assert_frame_equal(out, expected)

    # The schema resolved by the lazy planner must agree with the schema of
    # the frame that is actually produced.
    q = df_decimal.lazy().with_columns(pl.col("data").interpolate(method=method))
    assert q.collect_schema() == q.collect().schema
167
168