Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/namespaces/test_categorical.py
6940 views
1
from __future__ import annotations
2
3
from io import BytesIO
4
from typing import TYPE_CHECKING
5
6
import pytest
7
8
import polars as pl
9
from polars.testing import assert_frame_equal, assert_series_equal
10
11
if TYPE_CHECKING:
12
from polars._typing import PolarsDataType
13
14
15
def test_categorical_lexical_sort() -> None:
    """Sort by a lexical categorical column on its own and combined with ints.

    Each sorted frame is cast back to String before comparison so the expected
    frames can be written with plain string columns.
    """
    raw = pl.DataFrame({"cats": ["z", "z", "k", "a", "b"], "vals": [3, 1, 2, 2, 3]})
    frame = raw.with_columns(pl.col("cats").cast(pl.Categorical("lexical")))
    cats_as_string = pl.col("cats").cast(pl.String)

    # Single key: the dtype must survive the sort; the expected frame keeps the
    # two "z" rows in their input order (vals 3, then 1).
    by_cats = frame.sort(["cats"])
    assert by_cats["cats"].dtype == pl.Categorical
    assert_frame_equal(
        by_cats.with_columns(cats_as_string),
        pl.DataFrame({"cats": ["a", "b", "k", "z", "z"], "vals": [2, 3, 2, 3, 1]}),
    )

    # Two keys, with the categorical first and then second in the key list.
    multi_key_cases = [
        (
            ["cats", "vals"],
            {"cats": ["a", "b", "k", "z", "z"], "vals": [2, 3, 2, 1, 3]},
        ),
        (
            ["vals", "cats"],
            {"cats": ["z", "a", "k", "b", "z"], "vals": [1, 2, 2, 3, 3]},
        ),
    ]
    for sort_keys, wanted in multi_key_cases:
        got = frame.sort(sort_keys).with_columns(cats_as_string)
        assert_frame_equal(got, pl.DataFrame(wanted))

    # Same check on a bare Series constructed directly with the lexical dtype.
    series = pl.Series(["a", "c", "a", "b", "a"], dtype=pl.Categorical("lexical"))
    assert series.sort().cast(pl.String).to_list() == ["a", "a", "a", "b", "c"]
48
49
50
def test_categorical_lexical_ordering_after_concat() -> None:
    """Sort a lexical categorical column built by concatenating two LazyFrames."""
    to_lexical = pl.col("key2").cast(pl.Categorical("lexical"))

    first_source = pl.DataFrame(
        [pl.Series("key1", [8, 5]), pl.Series("key2", ["fox", "baz"])]
    )
    second_source = pl.DataFrame(
        [pl.Series("key1", [6, 8, 6]), pl.Series("key2", ["fox", "foo", "bar"])]
    )
    # Each frame is cast separately, before the concat.
    first = first_source.lazy().with_columns(to_lexical)
    second = second_source.lazy().with_columns(to_lexical)

    combined = pl.concat([first, second]).select(pl.col("key2")).collect()

    sorted_keys = combined.sort("key2").to_dict(as_series=False)
    assert sorted_keys == {"key2": ["bar", "baz", "foo", "fox", "fox"]}
68
69
70
def test_sort_categoricals_6014_lexical() -> None:
    """Sorting a lexical categorical column yields alphabetical order.

    NOTE(review): "6014" in the name presumably refers to an issue number.
    """
    strings = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]})
    # Turn the string column into a lexically-ordered categorical.
    frame = strings.with_columns(pl.col("key").cast(pl.Categorical("lexical")))

    result = frame.sort("key").to_dict(as_series=False)
    assert result == {"key": ["aaa", "bbb", "ccc"]}
78
79
80
def test_categorical_get_categories() -> None:
    """``cat.get_categories`` contains every distinct value of the series."""
    series = pl.Series(
        "cats",
        ["foo", "bar", "foo", "foo", "ham"],
        dtype=pl.Categorical,
    )
    categories = series.cat.get_categories().to_list()
    # Subset check rather than equality: extra categories are tolerated.
    assert {"foo", "bar", "ham"} <= set(categories)
83
84
85
def test_cat_to_local() -> None:
    """``cat.to_local`` returns a series that compares equal to its input."""
    original = pl.Series(["a", "b", "a"], dtype=pl.Categorical)
    localized = original.cat.to_local()
    assert_series_equal(original, localized)
88
89
90
def test_cat_uses_lexical_ordering() -> None:
    """``cat.uses_lexical_ordering`` is truthy for every Categorical flavour."""
    # Default Categorical, no explicit ordering argument.
    default_cat = pl.Series(["a", "b", None, "b"]).cast(pl.Categorical)
    assert default_cat.cat.uses_lexical_ordering()

    # Explicit "lexical" ordering.
    lexical_cat = default_cat.cast(pl.Categorical("lexical"))
    assert lexical_cat.cat.uses_lexical_ordering()

    # "physical" ordering is deprecated: the cast must warn, and the result
    # is still expected to report lexical ordering.
    with pytest.warns(DeprecationWarning):
        physical_cat = lexical_cat.cast(pl.Categorical("physical"))  # Deprecated.
    assert physical_cat.cat.uses_lexical_ordering()
100
101
102
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])
def test_cat_len_bytes(dtype: PolarsDataType) -> None:
    """Check ``cat.len_bytes`` on a Series, eager/lazy selects and a group-by.

    The expected lengths are byte counts, not character counts: "Café" is
    expected to be 5 and "東京" to be 6. Nulls stay null.
    """
    values = ["Café", None, "Café", "345", "東京"]
    if dtype == pl.Enum:
        # An Enum needs its categories up front; use the distinct non-null values.
        distinct = {v for v in values if v is not None}
        dtype = pl.Enum(list(distinct))

    # Series namespace
    series = pl.Series("a", values, dtype=dtype)
    expected_series = pl.Series("a", [5, None, 5, 3, 6], dtype=pl.UInt32)
    assert_series_equal(series.cat.len_bytes(), expected_series)

    # Expression on an eager DataFrame
    frame = pl.DataFrame(series)
    expected_frame = pl.DataFrame(expected_series)
    len_bytes_expr = pl.col("a").cat.len_bytes()
    assert_frame_equal(frame.select(len_bytes_expr), expected_frame)

    # Same expression through a LazyFrame
    lazy_result = frame.lazy().select(pl.col("a").cat.len_bytes()).collect()
    assert_frame_equal(lazy_result, expected_frame)

    # Group-by aggregation: the series is extended with itself so each of the
    # two keys receives one full copy of the values.
    keys = [1] * 5 + [2] * 5
    grouped = (
        pl.LazyFrame({"key": keys, "value": series.extend(series)})
        .group_by("key", maintain_order=True)
        .agg(pl.col("value").cat.len_bytes().alias("len_bytes"))
        .explode("len_bytes")
        .collect()
    )
    # NOTE: this comparison uses the index dtype, unlike the UInt32 used above.
    expected_grouped = pl.DataFrame(
        {
            "key": keys,
            "len_bytes": pl.Series(
                [5, None, 5, 3, 6] * 2, dtype=pl.get_index_type()
            ),
        }
    )
    assert_frame_equal(grouped, expected_grouped)
140
141
142
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])
def test_cat_len_chars(dtype: PolarsDataType) -> None:
    """Check ``cat.len_chars`` on a Series, eager/lazy selects and a group-by.

    The expected lengths are character counts: "Café" is expected to be 4 and
    "東京" to be 2. Nulls stay null.
    """
    values = ["Café", None, "Café", "345", "東京"]
    if dtype == pl.Enum:
        # An Enum needs its categories up front; use the distinct non-null values.
        dtype = pl.Enum(list({x for x in values if x is not None}))

    # test Series
    s = pl.Series("a", values, dtype=dtype)
    result = s.cat.len_chars()
    expected = pl.Series("a", [4, None, 4, 3, 2], dtype=pl.UInt32)
    assert_series_equal(result, expected)

    # test DataFrame expr
    df = pl.DataFrame(s)
    result_df = df.select(pl.col("a").cat.len_chars())
    expected_df = pl.DataFrame(expected)
    assert_frame_equal(result_df, expected_df)

    # test LazyFrame expr
    result_lf = df.lazy().select(pl.col("a").cat.len_chars()).collect()
    assert_frame_equal(result_lf, expected_df)

    # test GroupBy
    # The output column is named "len_chars" so it matches what is measured
    # (it was previously aliased "len_bytes", a leftover from the bytes test).
    result_df = (
        pl.LazyFrame({"key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], "value": s.extend(s)})
        .group_by("key", maintain_order=True)
        .agg(pl.col("value").cat.len_chars().alias("len_chars"))
        .explode("len_chars")
        .collect()
    )
    # NOTE: this comparison uses the index dtype, unlike the UInt32 used above.
    expected_df = pl.DataFrame(
        {
            "key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
            "len_chars": pl.Series(
                [4, None, 4, 3, 2, 4, None, 4, 3, 2], dtype=pl.get_index_type()
            ),
        }
    )
    assert_frame_equal(result_df, expected_df)
180
181
182
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])
def test_starts_ends_with(dtype: PolarsDataType) -> None:
    """Check ``cat.starts_with`` / ``cat.ends_with`` on Series and expressions.

    Covers matching results (nulls stay null) and the TypeError raised when
    the prefix or suffix is not a string.
    """
    values = ["hamburger_with_tomatoes", "nuts", "nuts", "lollypop", None]
    if dtype == pl.Enum:
        # An Enum needs its categories up front; use the distinct non-null values.
        distinct = {v for v in values if v is not None}
        dtype = pl.Enum(list(distinct))

    # --- Series namespace ---
    series = pl.Series("a", values, dtype=dtype)

    ends_with_pop = series.cat.ends_with("pop")
    assert_series_equal(
        ends_with_pop, pl.Series("a", [False, False, False, True, None])
    )
    starts_with_nu = series.cat.starts_with("nu")
    assert_series_equal(
        starts_with_nu, pl.Series("a", [False, True, True, False, None])
    )

    with pytest.raises(TypeError, match="'prefix' must be a string; found"):
        series.cat.starts_with(None)  # type: ignore[arg-type]

    with pytest.raises(TypeError, match="'suffix' must be a string; found"):
        series.cat.ends_with(None)  # type: ignore[arg-type]

    # --- Expression namespace ---
    frame = pl.DataFrame({"a": pl.Series(values, dtype=dtype)})

    selected = frame.select(
        pl.col("a").cat.ends_with("pop").alias("ends_pop"),
        pl.col("a").cat.starts_with("ham").alias("starts_ham"),
    )
    assert selected.to_dict(as_series=False) == {
        "ends_pop": [False, False, False, True, None],
        "starts_ham": [True, False, False, False, None],
    }

    with pytest.raises(TypeError, match="'prefix' must be a string; found"):
        frame.select(pl.col("a").cat.starts_with(None))  # type: ignore[arg-type]

    with pytest.raises(TypeError, match="'suffix' must be a string; found"):
        frame.select(pl.col("a").cat.ends_with(None))  # type: ignore[arg-type]
221
222
223
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])
def test_cat_slice(dtype: PolarsDataType) -> None:
    """Check ``cat.slice`` with a negative offset and with offset plus length."""
    values = ["foobar", "barfoo", "foobar", "x", None]
    if dtype == pl.Enum:
        # An Enum needs its categories up front; use the distinct non-null values.
        distinct = {v for v in values if v is not None}
        dtype = pl.Enum(list(distinct))
    frame = pl.DataFrame({"a": pl.Series(values, dtype=dtype)})

    # Negative offset on the Series namespace: "x" is shorter than three
    # characters and is expected back unchanged.
    last_three = frame["a"].cat.slice(-3).to_list()
    assert last_three == ["bar", "foo", "bar", "x", None]

    # Offset and length through an expression: "x" has nothing at offset 2,
    # so an empty string is expected.
    from_two = frame.select([pl.col("a").cat.slice(2, 4)])["a"].to_list()
    assert from_two == ["obar", "rfoo", "obar", "", None]
237
238
239
def test_cat_order_flag_csv_read_23823() -> None:
    """Sort a Categorical column scanned from a CSV that contains a comment line.

    NOTE(review): "23823" in the name presumably refers to an issue number.
    """
    # The third line starts with the comment prefix and must not become a row.
    csv_bytes = BytesIO(b"colx,coly\nabc,123\n#not_a_row\nxyz,456")
    categorical_colx = {"colx": pl.Categorical}

    scanned = pl.scan_csv(
        source=csv_bytes,
        comment_prefix="#",
        schema_overrides=categorical_colx,
    )
    wanted = pl.DataFrame(
        {"colx": ["abc", "xyz"], "coly": [123, 456]},
        schema_overrides=categorical_colx,
    )

    ascending = scanned.sort("colx", descending=False).collect()
    assert_frame_equal(wanted, ascending)
251
252