Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_structs.py
6939 views
1
from __future__ import annotations
2
3
import pytest
4
5
import polars as pl
6
from polars.exceptions import (
7
SQLInterfaceError,
8
SQLSyntaxError,
9
StructFieldNotFoundError,
10
)
11
from polars.testing import assert_frame_equal
12
13
14
@pytest.fixture
def df_struct() -> pl.DataFrame:
    """Return a one-column frame whose ``json_msg`` struct packs all source columns.

    The ``other`` field is itself a struct (``{"n": ...}``, with one null
    value), so tests can exercise multi-level dot-notation access.
    """
    records = {
        "id": [200, 300, 400],
        "name": ["Bob", "David", "Zoe"],
        "age": [45, 19, 45],
        "other": [{"n": 1.5}, {"n": None}, {"n": -0.5}],
    }
    flat = pl.DataFrame(records)
    packed = pl.struct(pl.all()).alias("json_msg")
    return flat.select(packed)
24
25
26
def test_struct_field_nested_dot_notation_22107() -> None:
    """Dot-notation must resolve a name at the right level of struct nesting.

    ``author.id`` / ``author.name`` refer to fields of the ``author`` struct,
    while bare ``name`` / ``self.name`` refer to the top-level column that
    happens to share the field's name.
    """
    books = pl.DataFrame(
        {
            "id": ["012345", "987654"],
            "name": ["A Book", "Another Book"],
            "author": [
                {"id": "888888", "name": "Iain M. Banks"},
                {"id": "444444", "name": "Dan Abnett"},
            ],
        }
    )

    # struct field selected alongside the same-named top-level column
    out = books.sql("SELECT id, author.id AS author_id FROM self ORDER BY id")
    assert out.to_dict(as_series=False) == {
        "id": ["012345", "987654"],
        "author_id": ["888888", "444444"],
    }

    # nested struct field, with and without the table qualifier
    for ref in ["author.name", "self.author.name"]:
        out = books.sql(f"SELECT {ref} FROM self ORDER BY id")
        assert out.to_dict(as_series=False) == {"name": ["Iain M. Banks", "Dan Abnett"]}

    # top-level column, with and without the table qualifier
    for ref in ["name", "self.name"]:
        out = books.sql(f"SELECT {ref} FROM self ORDER BY self.id DESC")
        assert out.to_dict(as_series=False) == {"name": ["Another Book", "A Book"]}

    # unknown qualifiers / columns must raise a descriptive interface error
    bad_queries = [
        (
            "SELECT foo.id FROM self ORDER BY id",
            "no table or struct column named 'foo' found",
        ),
        (
            "SELECT self.foo FROM self ORDER BY id",
            "no column named 'foo' found",
        ),
    ]
    for query, expected_msg in bad_queries:
        with pytest.raises(SQLInterfaceError, match=expected_msg):
            books.sql(query)
65
66
67
@pytest.mark.parametrize(
    "order_by",
    [
        "ORDER BY json_msg.id DESC",
        "ORDER BY 2 DESC",
        "",
    ],
)
def test_struct_field_selection(order_by: str, df_struct: pl.DataFrame) -> None:
    """Select struct fields via a table alias, ``self``, or the bare column name.

    Rows are filtered on both a direct field and a nested one
    (``json_msg.other.n``). When the parametrized ORDER BY clause is empty,
    the result is sorted after the query so every case compares equal.
    """
    query = f"""
        SELECT
          -- validate table alias resolution
          frame.json_msg.id AS ID,
          self.json_msg.name AS NAME,
          json_msg.age AS AGE
        FROM
          self AS frame
        WHERE
          json_msg.age > 20 AND
          json_msg.other.n IS NOT NULL -- note: nested struct field
        {order_by}
        """
    result = df_struct.sql(query)
    if not order_by:
        # no ORDER BY in the query itself: impose the same ordering here
        result = result.sort(by="ID", descending=True)

    expected = pl.DataFrame({"ID": [400, 200], "NAME": ["Zoe", "Bob"], "AGE": [45, 45]})
    assert_frame_equal(expected, result)
96
97
98
def test_struct_field_group_by(df_struct: pl.DataFrame) -> None:
    """GROUP BY a struct field while aggregating sibling fields of the same struct."""
    # "df_struct" in the FROM clause names the fixture argument of this test;
    # presumably `pl.sql` discovers it from the calling scope (it is never
    # registered explicitly here)
    query = """
        SELECT
          COUNT(json_msg.age) AS n,
          ARRAY_AGG(json_msg.name) AS names
        FROM df_struct
        GROUP BY json_msg.age
        ORDER BY 1 DESC
        """
    grouped = pl.sql(query).collect()

    expected = pl.DataFrame(
        data={"n": [2, 1], "names": [["Bob", "Zoe"], ["David"]]},
        schema_overrides={"n": pl.UInt32},
    )
    assert_frame_equal(expected, grouped)
115
116
117
def test_struct_field_group_by_errors(df_struct: pl.DataFrame) -> None:
    """A struct field that is neither grouped nor aggregated must be rejected."""
    query = """
        SELECT
          json_msg.name,
          SUM(json_msg.age) AS sum_age
        FROM df_struct
        GROUP BY json_msg.age
        """
    expected_msg = (
        "'name' should participate in the GROUP BY clause or an aggregate function"
    )
    # the error may surface while planning or while collecting, so both
    # steps stay inside the `raises` context
    with pytest.raises(SQLSyntaxError, match=expected_msg):
        pl.sql(query).collect()
131
132
133
@pytest.mark.parametrize(
    ("expr", "expected"),
    [
        ("nested #> '{c,1}'", 2),
        ("nested #> '{c,-1}'", 1),
        ("nested #>> '{c,0}'", "3"),
        ("nested -> '0' -> 0", "baz"),
        ("nested -> 'c' -> -1", 1),
        ("nested -> 'c' ->> 2", "1"),
    ],
)
def test_struct_field_operator_access(expr: str, expected: int | str) -> None:
    """Apply JSON-style path operators (``->``, ``->>``, ``#>``, ``#>>``) to a struct.

    Per the parametrized expectations, the ``->>`` / ``#>>`` forms return
    strings, while ``->`` / ``#>`` keep the element's native type. Negative
    indices count from the end of a list field.
    """
    # `.item()` below requires a 1x1 result, so this is a single-row struct
    frame = pl.DataFrame(
        {
            "nested": {
                "0": ["baz"],
                "b": ["foo", "bar"],
                "c": [3, 2, 1],
            },
        },
    )
    value = frame.sql(f"SELECT {expr} FROM self").item()
    assert value == expected
155
156
157
@pytest.mark.parametrize(
    ("fields", "excluding", "rename"),
    [
        ("json_msg.*", "age", {}),
        ("json_msg.*", "name", {"other": "misc"}),
        ("self.json_msg.*", "(age,other)", {"name": "ident"}),
        ("json_msg.other.*", "", {"n": "num"}),
        ("self.json_msg.other.*", "", {}),
        ("self.json_msg.other.*", "n", {}),
    ],
)
def test_struct_field_selection_wildcards(
    fields: str,
    excluding: str,
    rename: dict[str, str],
    df_struct: pl.DataFrame,
) -> None:
    """Expand ``<struct>.*`` wildcards, with optional EXCLUDE / RENAME modifiers.

    The SQL result is compared against the same frame built directly with
    ``unnest`` / ``drop`` / ``rename``.
    """
    # optional wildcard modifiers; an empty string leaves the SELECT unchanged
    exclude_cols = ""
    if excluding:
        exclude_cols = f"EXCLUDE {excluding}"
    rename_cols = ""
    if rename:
        aliases = ",".join(f"{old} AS {new}" for old, new in rename.items())
        rename_cols = f"RENAME ({aliases})"

    res = df_struct.sql(
        f"""
        SELECT {fields} {exclude_cols} {rename_cols}
        FROM self ORDER BY json_msg.id
        """
    )

    # build the expected frame by hand from the unnested struct(s)
    expected = df_struct.unnest("json_msg")
    if fields.endswith(".other.*"):
        expected = expected["other"].struct.unnest()
    if excluding:
        # "(a,b)" -> ["a", "b"]; "a" -> ["a"]
        dropped = excluding.strip(")(").split(",")
        expected = expected.drop(dropped)
    if rename:
        expected = expected.rename(rename)

    assert_frame_equal(expected, res)
196
197
198
@pytest.mark.parametrize(
    ("invalid_column", "error_type"),
    [
        ("json_msg.invalid_column", StructFieldNotFoundError),
        ("json_msg.other.invalid_column", StructFieldNotFoundError),
        ("self.json_msg.other.invalid_column", StructFieldNotFoundError),
        ("json_msg.other -> invalid_column", SQLSyntaxError),
        ("json_msg -> DATE '2020-09-11'", SQLSyntaxError),
    ],
)
def test_struct_field_selection_errors(
    invalid_column: str,
    error_type: type[Exception],
    df_struct: pl.DataFrame,
) -> None:
    """Reject selection of unknown struct fields and malformed ``->`` path extracts."""
    # path-operator cases report a path-extract error; plain dot-notation
    # cases name the missing field in the message
    if "->" in invalid_column:
        error_msg = "invalid json/struct path-extract"
    else:
        error_msg = "invalid_column"

    query = f"SELECT {invalid_column} FROM self"
    with pytest.raises(error_type, match=error_msg):
        df_struct.sql(query)
220
221