GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_interpolate_by.py
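Unit tests for `interpolate_by`, which fills null values by linear interpolation against a `by` column.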
from __future__ import annotations

from datetime import date
from typing import TYPE_CHECKING

import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis import assume, given

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric import column, dataframes

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


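# test_interpolate_by covers the dtype matrix: integer, float, date, and
# datetime (with and without a time zone) `times` dtypes combined with integer
# and float `values` dtypes. It also checks that interpolating a frame sorted
# in descending order and re-sorting gives the same result.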
@pytest.mark.parametrize(
    "times_dtype",
    [
        pl.Datetime("ms"),
        pl.Datetime("us", "Asia/Kathmandu"),
        pl.Datetime("ns"),
        pl.Date,
        pl.Int64,
        pl.Int32,
        pl.UInt64,
        pl.UInt32,
        pl.Float32,
        pl.Float64,
    ],
)
@pytest.mark.parametrize(
    "values_dtype",
    [
        pl.Float64,
        pl.Float32,
        pl.Int64,
        pl.Int32,
        pl.UInt64,
        pl.UInt32,
    ],
)
def test_interpolate_by(
    values_dtype: PolarsDataType, times_dtype: PolarsDataType
) -> None:
    df = pl.DataFrame(
        {
            "times": [
                1,
                3,
                10,
                11,
                12,
                16,
                21,
                30,
            ],
            "values": [1, None, None, 5, None, None, None, 6],
        },
        schema={"times": times_dtype, "values": values_dtype},
    )
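    # The expected values below come from linear interpolation between the
    # known points (times=1, 1.0) -> (times=11, 5.0) and (times=11, 5.0) ->
    # (times=30, 6.0), e.g. at times=3: 1 + (5 - 1) * (3 - 1) / (11 - 1) = 1.8.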
    result = df.select(pl.col("values").interpolate_by("times"))
    expected = pl.DataFrame(
        {
            "values": [
                1.0,
                1.7999999999999998,
                4.6,
                5.0,
                5.052631578947368,
                5.2631578947368425,
                5.526315789473684,
                6.0,
            ]
        }
    )
    if values_dtype == pl.Float32:
        expected = expected.select(pl.col("values").cast(pl.Float32))
    assert_frame_equal(result, expected)
    result = (
        df.sort("times", descending=True)
        .with_columns(pl.col("values").interpolate_by("times"))
        .sort("times")
        .drop("times")
    )
    assert_frame_equal(result, expected)


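# Nulls before the first observed value are not extrapolated: they remain
# null, even though the leading rows here share the same `times` value.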
def test_interpolate_by_leading_nulls() -> None:
    df = pl.DataFrame(
        {
            "times": [
                date(2020, 1, 1),
                date(2020, 1, 1),
                date(2020, 1, 1),
                date(2020, 1, 1),
                date(2020, 1, 3),
                date(2020, 1, 10),
                date(2020, 1, 11),
            ],
            "values": [None, None, None, 1, None, None, 5],
        }
    )
    result = df.select(pl.col("values").interpolate_by("times"))
    expected = pl.DataFrame({"values": [None, None, None, 1.0, 1.8, 4.6, 5.0]})
    assert_frame_equal(result, expected)
    result = (
        df.sort("times", maintain_order=True, descending=True)
        .with_columns(pl.col("values").interpolate_by("times"))
        .sort("times", maintain_order=True)
        .drop("times")
    )
    assert_frame_equal(result, expected, check_exact=False)


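# Nulls after the last observed value likewise remain null. For the float
# case the known points are (0.2, 1.0) and (0.6, 5.0), so e.g. at times=0.4:
# 1 + (5 - 1) * (0.4 - 0.2) / (0.6 - 0.2) = 3.0.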
@pytest.mark.parametrize("dataset", ["floats", "dates"])
def test_interpolate_by_trailing_nulls(dataset: str) -> None:
    input_data = {
        "dates": pl.DataFrame(
            {
                "times": [
                    date(2020, 1, 1),
                    date(2020, 1, 3),
                    date(2020, 1, 10),
                    date(2020, 1, 11),
                    date(2020, 1, 12),
                    date(2020, 1, 13),
                ],
                "values": [1, None, None, 5, None, None],
            }
        ),
        "floats": pl.DataFrame(
            {
                "times": [0.2, 0.4, 0.5, 0.6, 0.9, 1.1],
                "values": [1, None, None, 5, None, None],
            }
        ),
    }

    expected_data = {
        "dates": pl.DataFrame(
            {"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]}
        ),
        "floats": pl.DataFrame({"values": [1.0, 3.0, 4.0, 5.0, None, None]}),
    }

    df = input_data[dataset]
    expected = expected_data[dataset]

    result = df.select(pl.col("values").interpolate_by("times"))

    assert_frame_equal(result, expected)
    result = (
        df.sort("times", descending=True)
        .with_columns(pl.col("values").interpolate_by("times"))
        .sort("times")
        .drop("times")
    )
    assert_frame_equal(result, expected)


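# Property-based check against np.interp. NumPy clamps to the boundary values
# outside the range of observed points, whereas Polars keeps leading and
# trailing nulls, so those positions are masked out of the NumPy reference
# before comparison.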
@given(data=st.data(), x_dtype=st.sampled_from([pl.Date, pl.Float64]))
def test_interpolate_vs_numpy(data: st.DataObject, x_dtype: pl.DataType) -> None:
    if x_dtype == pl.Float64:
        by_strategy = st.floats(
            min_value=-1e150,
            max_value=1e150,
            allow_nan=False,
            allow_infinity=False,
            allow_subnormal=False,
        )
    else:
        by_strategy = None

    dataframe = (
        data.draw(
            dataframes(
                [
                    column(
                        "ts",
                        dtype=x_dtype,
                        allow_null=False,
                        strategy=by_strategy,
                    ),
                    column(
                        "value",
                        dtype=pl.Float64,
                        allow_null=True,
                    ),
                ],
                min_size=1,
            )
        )
        .sort("ts")
        .fill_nan(None)
        .unique("ts")
    )

    if x_dtype == pl.Float64:
        assume(not dataframe["ts"].is_nan().any())
        assume(not dataframe["ts"].is_null().any())
        assume(not dataframe["ts"].is_in([float("-inf"), float("inf")]).any())

    assume(not dataframe["value"].is_null().all())
    assume(not dataframe["value"].is_in([float("-inf"), float("inf")]).any())

    dataframe = dataframe.sort("ts")

    result = dataframe.select(pl.col("value").interpolate_by("ts"))["value"]

    mask = dataframe["value"].is_not_null()

    np_dtype = "int64" if x_dtype == pl.Date else "float64"
    x = dataframe["ts"].to_numpy().astype(np_dtype)
    xp = dataframe["ts"].filter(mask).to_numpy().astype(np_dtype)
    yp = dataframe["value"].filter(mask).to_numpy().astype("float64")
    interp = np.interp(x, xp, yp)
    # Polars preserves nulls on boundaries, but NumPy doesn't.
    first_non_null = dataframe["value"].is_not_null().arg_max()
    last_non_null = len(dataframe) - dataframe["value"][::-1].is_not_null().arg_max()  # type: ignore[operator]
    interp[:first_non_null] = float("nan")
    interp[last_non_null:] = float("nan")
    expected = dataframe.with_columns(value=pl.Series(interp, nan_to_null=True))[
        "value"
    ]

    # We increase the absolute error threshold; NumPy has some instability, see #22348.
    assert_series_equal(result, expected, abs_tol=1e-4)
    result_from_unsorted = (
        dataframe.sort("ts", descending=True)
        .with_columns(pl.col("value").interpolate_by("ts"))
        .sort("ts")["value"]
    )
    assert_series_equal(result_from_unsorted, expected, abs_tol=1e-4)


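# Error handling: the `by` column must have the same length as the series and
# must not contain nulls; both violations raise InvalidOperationError.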
def test_interpolate_by_invalid() -> None:
    s = pl.Series([1, None, 3])
    by = pl.Series([1, 2])
    with pytest.raises(InvalidOperationError, match=r"\(3\), got 2"):
        s.interpolate_by(by)

    by = pl.Series([1, None, 3])
    with pytest.raises(
        InvalidOperationError,
        match="null values in `by` column are not yet supported in 'interpolate_by'",
    ):
        s.interpolate_by(by)