Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_clip.py
6939 views
1
from __future__ import annotations
2
3
from datetime import datetime
4
from decimal import Decimal
5
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import InvalidOperationError
10
from polars.testing import assert_frame_equal, assert_series_equal
11
12
13
@pytest.fixture
def clip_exprs() -> list[pl.Expr]:
    """Provide the clip expressions shared by the frame-based tests.

    Column ``a`` is clipped three ways: against both the ``min`` and ``max``
    columns (aliased ``clip``), against ``min`` as a lower bound only
    (``clip_min``), and against ``max`` as an upper bound only (``clip_max``).
    """
    value = pl.col("a")
    lower = pl.col("min")
    upper = pl.col("max")

    both_bounds = value.clip(lower, upper).alias("clip")
    lower_only = value.clip(lower_bound=lower).alias("clip_min")
    upper_only = value.clip(upper_bound=upper).alias("clip_max")
    return [both_bounds, lower_only, upper_only]
20
21
22
def test_clip_int(clip_exprs: list[pl.Expr]) -> None:
    """Clip an integer column against per-row integer bound columns.

    The input rows include nulls in the value column and in each bound
    column; the expected frame pins the result for every combination.
    """
    source = {
        "a": [1, 2, 3, 4, 5, None],
        "min": [0, -1, 4, None, 4, -10],
        "max": [2, 1, 8, 5, None, 10],
    }
    clipped = {
        "clip": [1, 1, 4, 4, 5, None],
        "clip_min": [1, 2, 4, 4, 5, None],
        "clip_max": [1, 1, 3, 4, 5, None],
    }

    actual = pl.LazyFrame(source).select(clip_exprs)
    wanted = pl.LazyFrame(clipped)
    assert_frame_equal(actual, wanted)
39
40
41
def test_clip_float(clip_exprs: list[pl.Expr]) -> None:
    """Clip a float column against per-row float bound columns.

    The last row is null in the value column and in both bound columns.
    """
    source = {
        "a": [1.0, 2.0, 3.0, 4.0, 5.0, None],
        "min": [0.0, -1.0, 4.0, None, 4.0, None],
        "max": [2.0, 1.0, 8.0, 5.0, None, None],
    }
    clipped = {
        "clip": [1.0, 1.0, 4.0, 4.0, 5.0, None],
        "clip_min": [1.0, 2.0, 4.0, 4.0, 5.0, None],
        "clip_max": [1.0, 1.0, 3.0, 4.0, 5.0, None],
    }

    actual = pl.LazyFrame(source).select(clip_exprs)
    wanted = pl.LazyFrame(clipped)
    assert_frame_equal(actual, wanted)
58
59
60
def test_clip_datetime(clip_exprs: list[pl.Expr]) -> None:
    """Clip a datetime column against per-row datetime bound columns.

    Each input column contains exactly one null, in a different row, so the
    expected frame covers a null value, a null lower bound and a null upper
    bound.
    """
    values = [
        datetime(1995, 6, 5, 10, 30),
        datetime(1995, 6, 5),
        datetime(2023, 10, 20, 18, 30, 6),
        None,
        datetime(2023, 9, 24),
        datetime(2000, 1, 10),
    ]
    lower_bounds = [
        datetime(1995, 6, 5, 10, 29),
        datetime(1996, 6, 5),
        datetime(2020, 9, 24),
        datetime(2020, 1, 1),
        None,
        datetime(2000, 1, 1),
    ]
    upper_bounds = [
        datetime(1995, 7, 21, 10, 30),
        datetime(2000, 1, 1),
        datetime(2023, 9, 20, 18, 30, 6),
        datetime(2000, 1, 1),
        datetime(1993, 3, 13),
        None,
    ]

    clipped_both = [
        datetime(1995, 6, 5, 10, 30),
        datetime(1996, 6, 5),
        datetime(2023, 9, 20, 18, 30, 6),
        None,
        datetime(1993, 3, 13),
        datetime(2000, 1, 10),
    ]
    clipped_lower = [
        datetime(1995, 6, 5, 10, 30),
        datetime(1996, 6, 5),
        datetime(2023, 10, 20, 18, 30, 6),
        None,
        datetime(2023, 9, 24),
        datetime(2000, 1, 10),
    ]
    clipped_upper = [
        datetime(1995, 6, 5, 10, 30),
        datetime(1995, 6, 5),
        datetime(2023, 9, 20, 18, 30, 6),
        None,
        datetime(1993, 3, 13),
        datetime(2000, 1, 10),
    ]

    source = pl.LazyFrame({"a": values, "min": lower_bounds, "max": upper_bounds})
    actual = source.select(clip_exprs)
    wanted = pl.LazyFrame(
        {
            "clip": clipped_both,
            "clip_min": clipped_lower,
            "clip_max": clipped_upper,
        }
    )
    assert_frame_equal(actual, wanted)
119
120
121
def test_clip_non_numeric_dtype_fails() -> None:
    """Clipping a string Series must raise ``InvalidOperationError``."""
    strings = pl.Series(["a", "b", "c"])
    expected_error = "`clip` only supports physical numeric types"

    with pytest.raises(InvalidOperationError, match=expected_error):
        strings.clip(pl.lit("b"), pl.lit("z"))
127
128
129
def test_clip_string_input() -> None:
    """Pass the lower bound as the string ``"min"``.

    The expected values match clipping ``a`` against the ``min`` column,
    with the row whose bound is null left unchanged.
    """
    frame = pl.DataFrame({"a": [0, 1, 2], "min": [1, None, 1]})
    clipped = pl.col("a").clip("min")

    actual = frame.select(clipped)
    wanted = pl.DataFrame({"a": [1, 1, 2]})
    assert_frame_equal(actual, wanted)
134
135
136
def test_clip_bound_invalid_for_original_dtype() -> None:
    """A negative lower bound on a ``UInt32`` Series must raise."""
    unsigned = pl.Series([1, 2, 3, 4], dtype=pl.UInt32)
    expected_error = "conversion from `i32` to `u32` failed"

    with pytest.raises(InvalidOperationError, match=expected_error):
        unsigned.clip(-1, 5)
142
143
144
def test_clip_decimal() -> None:
    """Clip a ``Decimal(21, 1)`` Series with both, lower-only and upper-only bounds."""
    dtype = pl.Decimal(21, 1)
    ser = pl.Series("a", ["1.1", "2.2", "3.3"], dtype)
    low = Decimal("1.5")
    high = Decimal("2.5")

    # Each case pairs the keyword arguments for ``clip`` with the expected values.
    cases: list[tuple[dict[str, Decimal], list[str]]] = [
        ({"lower_bound": low, "upper_bound": high}, ["1.5", "2.2", "2.5"]),
        ({"lower_bound": low}, ["1.5", "2.2", "3.3"]),
        ({"upper_bound": high}, ["1.1", "2.2", "2.5"]),
    ]
    for bounds, clipped in cases:
        result = ser.clip(**bounds)
        expected = pl.Series("a", clipped, dtype)
        assert_series_equal(result, expected)
158
159
160
def test_clip_unequal_lengths_22018() -> None:
    """Bounds whose length differs from the Series must raise ``ShapeError``.

    Covers a short lower bound and a short upper bound, each passed alone by
    keyword and each passed positionally alongside a full-length other bound.
    """
    mismatched_calls = [
        lambda s: s.clip(lower_bound=pl.Series([1, 2])),
        lambda s: s.clip(upper_bound=pl.Series([1, 2])),
        lambda s: s.clip(pl.Series([1, 2]), pl.Series([1, 2, 3])),
        lambda s: s.clip(pl.Series([1, 2, 3]), pl.Series([1, 2])),
    ]
    for call in mismatched_calls:
        # Build a fresh three-element Series inside the context for every case.
        with pytest.raises(pl.exceptions.ShapeError):
            call(pl.Series([1, 2, 3]))
169
170