Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/test_categories.py
6939 views
1
from __future__ import annotations
2
3
import io
4
import subprocess
5
import sys
6
7
import pytest
8
9
import polars as pl
10
from polars.exceptions import ComputeError, SchemaError
11
from polars.testing import assert_frame_equal, assert_series_equal
12
13
# Shared ``pl.Categories`` fixtures used by the parametrized tests below.
# The entries differ from one another in at least one of the positional
# arguments or the ``physical`` dtype; the final entry comes from
# ``pl.Categories.random()``.
CATS = [
    pl.Categories(),  # default-constructed
    pl.Categories("foo"),  # one positional argument
    pl.Categories("foo", physical=pl.UInt16),  # same, with UInt16 physical
    pl.Categories("boo"),  # different first argument
    pl.Categories("foo", "bar"),  # two positional arguments
    pl.Categories("foo", "baz"),  # different second argument
    pl.Categories("foo", "bar", physical=pl.UInt8),  # UInt8 physical
    pl.Categories.random(),  # randomly generated instance
]
23
24
25
def test_categories_eq_hash() -> None:
    """Check equality and hashing of ``pl.Categories`` rebuilt from their parts.

    Each entry of ``CATS`` is reconstructed from its ``name()``,
    ``namespace()`` and ``physical()``. The copy must compare equal to, and
    hash the same as, its original, and must compare unequal to every other
    entry.
    """
    originals = CATS
    rebuilt = [
        pl.Categories(cat.name(), cat.namespace(), cat.physical()) for cat in CATS
    ]

    # Same position: identical hash and equality.
    for original, clone in zip(originals, rebuilt):
        assert hash(original) == hash(clone)
        assert original == clone

    # Different positions: never equal.
    for i, original in enumerate(originals):
        for j, clone in enumerate(rebuilt):
            if i == j:
                continue
            assert original != clone
37
38
39
@pytest.mark.parametrize("cats", CATS)
def test_cat_parquet_roundtrip(cats: pl.Categories) -> None:
    """Write a categorical frame to in-memory Parquet and read it back.

    The frame that was written is deleted before the read, and the expected
    frame is only rebuilt after the scan has been collected.
    """

    def make_frame() -> pl.DataFrame:
        # Always build a fresh frame so no earlier instance is kept around.
        return pl.DataFrame(
            {"x": ["foo", "bar", "moo"]},
            schema={"x": pl.Categorical(cats)},
        )

    buffer = io.BytesIO()
    written = make_frame()
    written.write_parquet(buffer)
    del written  # Delete frame holding reference.

    buffer.seek(0)
    roundtripped = pl.scan_parquet(buffer).collect()

    assert_frame_equal(make_frame(), roundtripped)
49
50
51
@pytest.mark.may_fail_cloud  # reason: these are not seen locally
def test_local_categories_gc() -> None:
    """Check that randomly created categories are shared, then collected.

    Two frames built with the same dtype must report the union of their
    categories. After both frames are deleted, a new frame with that dtype
    must report only its own category.
    """

    def seen(frame: pl.DataFrame) -> set[str]:
        # Only the resulting set outlives this call; no series is retained.
        return set(frame["x"].cat.get_categories())

    dtype = pl.Categorical(pl.Categories.random())
    base = {"foo", "bar", "moo"}
    extended = {"foo", "bar", "moo", "zoinks"}

    first = pl.DataFrame({"x": ["foo", "bar", "moo"]}, schema={"x": dtype})
    assert seen(first) == base

    second = pl.DataFrame({"x": ["zoinks"]}, schema={"x": dtype})
    assert seen(first) == extended
    assert seen(second) == extended

    # Drop every frame that references the categories.
    del first, second

    fresh = pl.DataFrame({"x": ["a"]}, schema={"x": dtype})
    assert fresh["x"].cat.get_categories().to_list() == ["a"]
63
64
65
@pytest.mark.parametrize("cats", CATS)
def test_categories_lookup(cats: pl.Categories) -> None:
    """Check that indexing ``cats`` agrees with casting a categorical column.

    Indexing with the input values must match the column cast to the
    physical dtype, and indexing with those physical values must match the
    column cast to ``pl.String``.
    """
    values = ["foo", "bar", None, "moo", "bar", "moo", "foo", None]

    # Build the frame first, before any lookup on ``cats`` is performed.
    frame = pl.DataFrame({"x": values}, schema={"x": pl.Categorical(cats)})
    physical = cats.physical()

    # Value -> physical lookup.
    as_physical = pl.Series("x", [cats[value] for value in values], dtype=physical)
    assert_series_equal(as_physical, frame["x"].cast(physical))

    # Physical -> string lookup.
    as_strings = pl.Series("x", [cats[code] for code in as_physical])
    assert_series_equal(as_strings, frame["x"].cast(pl.String))
73
74
75
def test_concat_cat_mismatch() -> None:
    """Check that ``pl.concat`` rejects frames with different categories.

    Frames sharing a categorical dtype must concatenate to the expected
    frame. Frames whose dtypes come from different ``pl.Categories`` must
    raise ``SchemaError``, including empty frames built from any two
    distinct entries of ``CATS``.
    """

    def frame(values: list[str], dtype: pl.Categorical) -> pl.DataFrame:
        return pl.DataFrame({"x": values}, schema={"x": dtype})

    first_dtype = pl.Categorical(pl.Categories.random())
    second_dtype = pl.Categorical(pl.Categories.random())

    abc = frame(["a", "b", "c"], first_dtype)
    def_ = frame(["d", "e", "f"], first_dtype)
    abcdef = frame(["a", "b", "c", "d", "e", "f"], first_dtype)
    ghi = frame(["g", "h", "i"], second_dtype)
    jkl = frame(["j", "k", "l"], second_dtype)
    ghijkl = frame(["g", "h", "i", "j", "k", "l"], second_dtype)

    # Matching dtypes concatenate.
    assert_frame_equal(pl.concat([abc, def_]), abcdef)
    assert_frame_equal(pl.concat([ghi, jkl]), ghijkl)

    # Mixing the two random categories is rejected.
    for lhs in (abc, def_):
        for rhs in (ghi, jkl):
            with pytest.raises(SchemaError):
                pl.concat([lhs, rhs])

    # Every ordered pair of distinct CATS entries is rejected, even when empty.
    for li, left_cats in enumerate(CATS):
        for ri, right_cats in enumerate(CATS):
            if li != ri:
                empty_left = pl.DataFrame(
                    {"x": []}, schema={"x": pl.Categorical(left_cats)}
                )
                empty_right = pl.DataFrame(
                    {"x": []}, schema={"x": pl.Categorical(right_cats)}
                )
                with pytest.raises(SchemaError):
                    pl.concat([empty_left, empty_right])
102
103
104
def test_cat_overflow() -> None:
    """Check that adding too many categories to UInt8 categories fails.

    255 distinct strings must round-trip through the categorical dtype, and
    adding one more string must raise ``ComputeError``.
    """
    dtype = pl.Categorical(pl.Categories.random(physical=pl.UInt8))

    strings = pl.Series(range(255)).cast(pl.String)
    # Keep this series bound to a name until the end of the test.
    categorical = strings.cast(dtype)
    assert_series_equal(strings, categorical.cast(pl.String))

    with pytest.raises(ComputeError):
        pl.Series(["boom"], dtype=dtype)
111
112
113
def test_global_categories_gc() -> None:
    """Check how the default ``pl.Categorical`` categories are collected.

    The script first asserts that the categories are gone once every frame
    is deleted. It then repeats the steps while an empty categorical frame
    is held, and asserts that the earlier categories are still reported.
    """
    # We run this in a subprocess to ensure no other test is or has been messing
    # with the global categories.
    script = """\
import polars as pl

df = pl.DataFrame({"x": ["a", "b", "c"]}, schema={"x": pl.Categorical})
assert set(df["x"].cat.get_categories().to_list()) == {"a", "b", "c"}
df2 = pl.DataFrame({"x": ["d", "e", "f"]}, schema={"x": pl.Categorical})
assert set(df["x"].cat.get_categories().to_list()) == {"a", "b", "c", "d", "e", "f"}
del df
del df2
df3 = pl.DataFrame({"x": ["x"]}, schema={"x": pl.Categorical})
assert set(df3["x"].cat.get_categories().to_list()) == {"x"}
del df3

keep_alive = pl.DataFrame({"x": []}, schema={"x": pl.Categorical})
df = pl.DataFrame({"x": ["a", "b", "c"]}, schema={"x": pl.Categorical})
assert set(df["x"].cat.get_categories().to_list()) == {"a", "b", "c"}
df2 = pl.DataFrame({"x": ["d", "e", "f"]}, schema={"x": pl.Categorical})
assert set(df["x"].cat.get_categories().to_list()) == {"a", "b", "c", "d", "e", "f"}
del df
del df2
df3 = pl.DataFrame({"x": ["x"]}, schema={"x": pl.Categorical})
assert set(df3["x"].cat.get_categories().to_list()) == {"a", "b", "c", "d", "e", "f", "x"}

print("OK", end="")
"""
    command = [sys.executable, "-c", script]
    stdout = subprocess.check_output(command)

    # The script prints this marker only after all of its assertions pass.
    assert stdout == b"OK"
149
150