Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_cut.py
6939 views
1
from __future__ import annotations
2
3
import pytest
4
5
import polars as pl
6
from polars.testing import assert_frame_equal, assert_series_equal
7
8
# Positive infinity, used as the open upper breakpoint in expected results.
inf = float("infinity")
9
10
11
def test_cut() -> None:
    """``Series.cut`` labels each value with its right-closed bin interval."""
    values = pl.Series("a", [-2, -1, 0, 1, 2])
    labels = ["(-inf, -1]"] * 2 + ["(-1, 1]"] * 2 + ["(1, inf]"]

    assert_series_equal(
        values.cut([-1, 1]),
        pl.Series("a", labels, dtype=pl.Categorical),
        categorical_as_str=True,
    )
28
29
30
def test_cut_lazy_schema() -> None:
    """``Expr.cut`` in a lazy query yields a Categorical column of bin labels."""
    labels = ["(-inf, -1]"] * 2 + ["(-1, 1]"] * 2 + ["(1, inf]"]
    expected = pl.LazyFrame({"a": labels}, schema={"a": pl.Categorical})

    lf = pl.LazyFrame({"a": [-2, -1, 0, 1, 2]})
    assert_frame_equal(
        lf.select(pl.col("a").cut([-1, 1])),
        expected,
        categorical_as_str=True,
    )
40
41
42
def test_cut_include_breaks() -> None:
    """With ``include_breaks=True``, ``cut`` returns a struct of break and label."""
    breaks = [-1.5, 0.25, 1.0]
    result = pl.Series("a", [-2, -1, 0, 1, 2]).cut(
        breaks, labels=["a", "b", "c", "d"], include_breaks=True
    )

    expected_frame = pl.DataFrame(
        {
            "breakpoint": [-1.5, 0.25, 0.25, 1.0, inf],
            "category": list("abbcd"),
        },
        schema_overrides={"category": pl.Categorical},
    )
    assert_series_equal(
        result,
        expected_frame.to_struct("a"),
        categorical_as_str=True,
    )
55
56
57
# https://github.com/pola-rs/polars/issues/11255
def test_cut_include_breaks_lazy_schema() -> None:
    """Unnesting a lazy ``cut(..., include_breaks=True)`` gives the expected frame."""
    cut_expr = pl.col("a").cut([-1, 1], include_breaks=True).alias("cut")
    result = pl.LazyFrame({"a": [-2, -1, 0, 1, 2]}).select(cut_expr).unnest("cut")

    labels = ["(-inf, -1]"] * 2 + ["(-1, 1]"] * 2 + ["(1, inf]"]
    expected = pl.LazyFrame(
        {"breakpoint": [-1.0, -1.0, 1.0, 1.0, inf], "category": labels},
        schema_overrides={"category": pl.Categorical},
    )
    assert_frame_equal(result, expected, categorical_as_str=True)
73
74
75
def test_cut_null_values() -> None:
    """Null inputs stay null after ``cut``; other values get their bin label."""
    data = [-1.0, None, 1.0, 2.0, None, 8.0, 4.0]
    expected_labels = ["a", None, "a", "b", None, "c", "b"]

    result = pl.Series(data).cut([1.5, 5.0], labels=["a", "b", "c"])
    assert_series_equal(
        result,
        pl.Series(expected_labels, dtype=pl.Categorical),
        categorical_as_str=True,
    )
82
83
84
def test_cut_bin_name_in_agg_context() -> None:
    """``cut``/``qcut`` with ``include_breaks`` inside ``over`` keep the struct schema."""
    a = pl.col("a")
    df = pl.DataFrame({"a": [1]}).select(
        cut=a.cut([1, 2], include_breaks=True).over(1),
        qcut=a.qcut([1], include_breaks=True).over(1),
        qcut_uniform=a.qcut(1, include_breaks=True).over(1),
    )

    expected_dtype = pl.Struct(
        {"breakpoint": pl.Float64, "category": pl.Categorical()}
    )
    assert df.schema == dict.fromkeys(
        ("cut", "qcut", "qcut_uniform"), expected_dtype
    )
92
93
94
@pytest.mark.parametrize(
    ("breaks", "expected_labels", "expected_unique"),
    [
        (
            [2, 4],
            pl.Series("x", ["(-inf, 2]", "(-inf, 2]", "(2, 4]", "(2, 4]", "(4, inf]"]),
            3,
        ),
        (
            [99, 101],
            pl.Series("x", 5 * ["(-inf, 99]"]),
            1,
        ),
    ],
)
def test_cut_fast_unique_15981(
    breaks: list[int],
    expected_labels: pl.Series,
    expected_unique: int,
) -> None:
    """Regression test for #15981: unique counts of ``cut`` output must be correct.

    Checks both the plain Categorical result (``include_breaks=False``) and the
    ``category`` field of the struct result (``include_breaks=True``).

    Parameters
    ----------
    breaks
        Breakpoints passed to ``Series.cut``.
    expected_labels
        Expected bin labels, as a String series named ``"x"``.
    expected_unique
        Number of distinct labels the cut output should contain.
    """
    s = pl.Series("x", [1, 2, 3, 4, 5])

    plain = s.cut(breaks, include_breaks=False)
    from_struct = (
        s.cut(breaks, include_breaks=True).struct.field("category").alias("x")
    )

    for s_cut in (plain, from_struct):
        assert_series_equal(s_cut.cast(pl.String), expected_labels)
        # Logical and physical (category index) unique counts must agree.
        assert s_cut.n_unique() == s_cut.to_physical().n_unique() == expected_unique
        # Grouping by the categorical column must produce exactly one group per
        # distinct label. Previously the result was discarded, so only a crash
        # would have been caught.
        assert s_cut.to_frame().group_by(s.name).len().height == expected_unique
131
132