Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/test_api.py
6939 views
1
from __future__ import annotations
2
3
from typing import Any
4
5
import pytest
6
7
import polars as pl
8
from polars.testing import assert_frame_equal
9
10
11
def test_custom_df_namespace() -> None:
    """Register a custom ``DataFrame`` namespace and exercise its methods.

    The ``split`` namespace returns lists of sub-frames, partitioned either by
    the first letter of the column names or by the first letter of the values
    in a given string column.
    """

    @pl.api.register_dataframe_namespace("split")
    class SplitFrame:
        def __init__(self, df: pl.DataFrame) -> None:
            self._df = df

        def by_first_letter_of_column_names(self) -> list[pl.DataFrame]:
            """Return one frame per distinct leading letter of the column names."""
            return [
                self._df.select([col for col in self._df.columns if col[0] == f])
                for f in sorted({col[0] for col in self._df.columns})
            ]

        def by_first_letter_of_column_values(self, col: str) -> list[pl.DataFrame]:
            """Return one frame per distinct leading letter of the values in ``col``."""
            # Read the wrapped frame (``self._df``), not the enclosing test's
            # ``df``: the namespace must partition whichever frame it is
            # accessed from.
            first_letters = self._df.select(pl.col(col).str.slice(0, 1)).to_series()
            return [
                self._df.filter(pl.col(col).str.starts_with(c))
                for c in sorted(set(first_letters))
            ]

    df = pl.DataFrame(
        data=[["xx", 2, 3, 4], ["xy", 4, 5, 6], ["yy", 5, 6, 7], ["yz", 6, 7, 8]],
        schema=["a1", "a2", "b1", "b2"],
        orient="row",
    )

    dfs = df.split.by_first_letter_of_column_names()  # type: ignore[attr-defined]
    assert [d.rows() for d in dfs] == [
        [("xx", 2), ("xy", 4), ("yy", 5), ("yz", 6)],
        [(3, 4), (5, 6), (6, 7), (7, 8)],
    ]
    dfs = df.split.by_first_letter_of_column_values("a1")  # type: ignore[attr-defined]
    assert [d.rows() for d in dfs] == [
        [("xx", 2, 3, 4), ("xy", 4, 5, 6)],
        [("yy", 5, 6, 7), ("yz", 6, 7, 8)],
    ]

    # Accessed from a different frame, the namespace must partition *that*
    # frame only (no empty partition for letters absent from it).
    head = df.head(2)
    dfs = head.split.by_first_letter_of_column_values("a1")  # type: ignore[attr-defined]
    assert [d.rows() for d in dfs] == [
        [("xx", 2, 3, 4), ("xy", 4, 5, 6)],
    ]
45
46
47
def test_custom_expr_namespace() -> None:
    """Register a custom ``Expr`` namespace that maps values onto powers of ``p``."""

    @pl.api.register_expr_namespace("power")
    class PowersOfN:
        def __init__(self, expr: pl.Expr) -> None:
            self._expr = expr

        def _raise_to(self, p: int, exponent: pl.Expr) -> pl.Expr:
            # ``p`` to the Int64-cast exponent, with the result cast to Int64.
            return (p ** exponent.cast(pl.Int64)).cast(pl.Int64)

        def next(self, p: int) -> pl.Expr:
            """Return ``p`` raised to ``ceil(log_p(value))``."""
            return self._raise_to(p, self._expr.log(p).ceil())

        def previous(self, p: int) -> pl.Expr:
            """Return ``p`` raised to ``floor(log_p(value))``."""
            return self._raise_to(p, self._expr.log(p).floor())

        def nearest(self, p: int) -> pl.Expr:
            """Return ``p`` raised to ``log_p(value)`` rounded to 0 decimals."""
            return self._raise_to(p, self._expr.log(p).round(0))

    df = pl.DataFrame([1.4, 24.3, 55.0, 64.001], schema=["n"])
    result = df.select(
        pl.col("n"),
        pl.col("n").power.next(p=2).alias("next_pow2"),  # type: ignore[attr-defined]
        pl.col("n").power.previous(p=2).alias("prev_pow2"),  # type: ignore[attr-defined]
        pl.col("n").power.nearest(p=2).alias("nearest_pow2"),  # type: ignore[attr-defined]
    )
    expected = [
        (1.4, 2, 1, 1),
        (24.3, 32, 16, 32),
        (55.0, 64, 32, 64),
        (64.001, 128, 64, 64),
    ]
    assert result.rows() == expected
74
75
76
def test_custom_lazy_namespace() -> None:
    """Register a custom ``LazyFrame`` namespace that splits columns by dtype."""

    @pl.api.register_lazyframe_namespace("split")
    class SplitFrame:
        def __init__(self, lf: pl.LazyFrame) -> None:
            self._lf = lf

        def by_column_dtypes(self) -> list[pl.LazyFrame]:
            """Return one lazy frame per distinct dtype, in schema order."""
            # ``dict.fromkeys`` drops duplicate dtypes but keeps first-seen order.
            unique_dtypes = dict.fromkeys(self._lf.collect_schema().dtypes())
            return [self._lf.select(pl.col(dtype)) for dtype in unique_dtypes]

    ldf = pl.DataFrame(
        data=[["xx", 2, 3, 4], ["xy", 4, 5, 6], ["yy", 5, 6, 7], ["yz", 6, 7, 8]],
        schema=["a1", "a2", "b1", "b2"],
        orient="row",
    ).lazy()

    parts = ldf.split.by_column_dtypes()  # type: ignore[attr-defined]
    df1, df2 = [part.collect() for part in parts]

    expected_strings = pl.DataFrame(
        [("xx",), ("xy",), ("yy",), ("yz",)], schema=["a1"], orient="row"
    )
    expected_ints = pl.DataFrame(
        [(2, 3, 4), (4, 5, 6), (5, 6, 7), (6, 7, 8)],
        schema=["a2", "b1", "b2"],
        orient="row",
    )
    assert_frame_equal(df1, expected_strings)
    assert_frame_equal(df2, expected_ints)
107
108
109
def test_custom_series_namespace() -> None:
    """Register a custom ``Series`` namespace and check its ``square`` method."""

    @pl.api.register_series_namespace("math")
    class CustomMath:
        def __init__(self, s: pl.Series) -> None:
            self._s = s

        def square(self) -> pl.Series:
            """Return the element-wise product of the wrapped series with itself."""
            values = self._s
            return values * values

    s = pl.Series("n", [1.5, 31.0, 42.0, 64.5])
    squared = s.math.square()  # type: ignore[attr-defined]
    assert squared.to_list() == [2.25, 961.0, 1764.0, 4160.25]
125
126
127
@pytest.mark.slow
@pytest.mark.parametrize("pcls", [pl.Expr, pl.DataFrame, pl.LazyFrame, pl.Series])
def test_class_namespaces_are_registered(pcls: Any) -> None:
    """Every public ``*NameSpace`` property on ``pcls`` is listed in ``_accessors``."""
    registered: set[str] = getattr(pcls, "_accessors", set())
    for attr_name in dir(pcls):
        if attr_name.startswith("_"):
            continue
        member = getattr(pcls, attr_name)
        if not isinstance(member, property):
            continue
        # The getter is invoked with the class itself in place of an instance;
        # any property that cannot be evaluated this way is skipped.
        try:
            namespace_obj = member.fget(pcls)  # type: ignore[misc]
        except Exception:
            continue
        if not namespace_obj.__class__.__name__.endswith("NameSpace"):
            continue
        ns = namespace_obj._accessor
        assert ns in registered, (
            f"{ns!r} should be registered in {pcls.__name__}._accessors"
        )
147
148
149
def test_namespace_cannot_override_builtin() -> None:
    """Registering a custom DataFrame namespace named "dt" raises AttributeError."""

    class CustomDt:
        def __init__(self, df: pl.DataFrame) -> None:
            self._df = df

    # Both the decorator-factory call and its application stay inside the
    # ``raises`` block, matching the decorator-syntax form.
    with pytest.raises(AttributeError):
        pl.api.register_dataframe_namespace("dt")(CustomDt)
156
157
158
def test_namespace_warning_on_override() -> None:
    """Re-registering an existing custom namespace name emits a ``UserWarning``."""

    @pl.api.register_dataframe_namespace("math")
    class CustomMath:
        def __init__(self, df: pl.DataFrame) -> None:
            self._df = df

    # ``pytest.warns`` asserts the warning under any warning-filter config;
    # ``pytest.raises(UserWarning)`` only passes when warnings are escalated
    # to errors (e.g. by a ``filterwarnings = error`` setting).
    with pytest.warns(UserWarning):

        @pl.api.register_dataframe_namespace("math")
        class CustomMath2:
            def __init__(self, df: pl.DataFrame) -> None:
                self._df = df
170
171