# Tests for SQL GROUP BY / HAVING behavior.
# Upstream origin: pola-rs/polars — py-polars/tests/unit/sql/test_group_by.py
from __future__ import annotations

from datetime import date
from pathlib import Path

import pytest

import polars as pl
from polars.exceptions import SQLSyntaxError
from polars.testing import assert_frame_equal

@pytest.fixture
def foods_ipc_path() -> Path:
    """Return the path of the 'foods1' IPC fixture file under the IO test data."""
    io_files_dir = Path(__file__).parent.parent / "io" / "files"
    return io_files_dir / "foods1.ipc"


def test_group_by(foods_ipc_path: Path) -> None:
    """SQL GROUP BY with aggregates, HAVING, ORDER BY, and COUNT(DISTINCT)."""
    foods = pl.scan_ipc(foods_ipc_path)

    ctx = pl.SQLContext(eager=True)
    ctx.register("foods", foods)

    # multiple aggregates over a single group key, filtered via HAVING
    res = ctx.execute(
        """
        SELECT
          count(category) as n,
          category,
          max(calories) as max_cal,
          median(calories) as median_cal,
          min(fats_g) as min_fats
        FROM foods
        GROUP BY category
        HAVING n > 5
        ORDER BY n, category DESC
        """
    )
    expected = {
        "n": [7, 7, 8],
        "category": ["vegetables", "fruit", "seafood"],
        "max_cal": [45, 130, 200],
        "median_cal": [25.0, 50.0, 145.0],
        "min_fats": [0.0, 0.0, 1.5],
    }
    assert res.to_dict(as_series=False) == expected

    attrs = pl.LazyFrame(
        {
            "grp": ["a", "b", "c", "c", "b"],
            "att": ["x", "y", "x", "y", "y"],
        }
    )
    # only the originally-registered table is present so far
    assert ctx.tables() == ["foods"]

    ctx.register("test", attrs)
    assert ctx.tables() == ["foods", "test"]

    # COUNT(DISTINCT ...) with a HAVING filter on the aggregate alias
    res = ctx.execute(
        """
        SELECT
          grp,
          COUNT(DISTINCT att) AS n_dist_attr
        FROM test
        GROUP BY grp
        HAVING n_dist_attr > 1
        """
    )
    assert res.to_dict(as_series=False) == {"grp": ["c"], "n_dist_attr": [2]}


def test_group_by_all() -> None:
    """GROUP BY ALL / ORDER BY ALL infer group keys from the projection."""
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx", "yy", "xx", "zz"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [99, 99, 66, 66, 66, 66],
        }
    )

    # basic group/agg
    res = df.sql(
        """
        SELECT
          a,
          SUM(b),
          SUM(c),
          COUNT(*) AS n
        FROM self
        GROUP BY ALL
        ORDER BY ALL
        """
    )
    expected = pl.DataFrame(
        {
            "a": ["xx", "yy", "zz"],
            "b": [9, 6, 6],
            "c": [231, 165, 66],
            "n": [3, 2, 1],
        }
    )
    assert_frame_equal(expected, res, check_dtypes=False)

    # more involved determination of agg/group columns
    res = df.sql(
        """
        SELECT
          SUM(b) AS sum_b,
          SUM(c) AS sum_c,
          (SUM(b) + SUM(c)) / 2.0 AS sum_bc_over_2, -- nested agg
          a as grp, --aliased group key
        FROM self
        GROUP BY ALL
        ORDER BY grp
        """
    )
    expected = pl.DataFrame(
        {
            "sum_b": [9, 6, 6],
            "sum_c": [231, 165, 66],
            "sum_bc_over_2": [120.0, 85.5, 36.0],
            "grp": ["xx", "yy", "zz"],
        }
    )
    assert_frame_equal(expected, res.sort(by="grp"))


def test_group_by_all_multi() -> None:
    """GROUP BY ALL, ordinal positions, and explicit names are equivalent."""
    dt1 = date(1999, 12, 31)
    dt2 = date(2028, 7, 5)

    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "xx", "yy", "xx", "xx"],
            "dt": [dt1, dt1, dt1, dt2, dt2, dt2],
            "value": [10.5, -5.5, 20.5, 8.0, -3.0, 5.0],
        }
    )
    expected = pl.DataFrame(
        {
            "dt": [dt1, dt1, dt2, dt2],
            "key": ["xx", "yy", "xx", "yy"],
            "sum_value": [31.0, -5.5, 2.0, 8.0],
            "ninety_nine": [99, 99, 99, 99],
        },
        schema_overrides={"ninety_nine": pl.Int16},
    )

    # the following groupings should all be equivalent
    equivalent_groupings = ("ALL", "1, 2", "dt, key")
    for group in equivalent_groupings:
        res = df.sql(
            f"""
            SELECT dt, key, sum_value, ninety_nine::int2 FROM
            (
              SELECT
                dt,
                key,
                SUM(value) AS sum_value,
                99 AS ninety_nine
              FROM self
              GROUP BY {group}
              ORDER BY dt, key
            ) AS grp
            """
        )
        assert_frame_equal(expected, res)


def test_group_by_ordinal_position() -> None:
    """GROUP BY may reference projection columns by 1-based ordinal position."""
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx", "yy", "xx", "zz"],
            "b": [1, None, 3, 4, 5, 6],
            "c": [99, 99, 66, 66, 66, 66],
        }
    )
    expected = pl.LazyFrame(
        {
            "c": [66, 99],
            "total_b": [18, 1],
            "count_b": [4, 1],
            "count_star": [4, 2],
        }
    )

    with pl.SQLContext(frame=df) as ctx:
        # single ordinal key; COUNT(b) skips the null while COUNT(*) does not
        res1 = ctx.execute(
            """
            SELECT
              c,
              SUM(b) AS total_b,
              COUNT(b) AS count_b,
              COUNT(*) AS count_star
            FROM frame
            GROUP BY 1
            ORDER BY c
            """
        )
        assert_frame_equal(res1, expected, check_dtypes=False)

        # multiple ordinal keys, referenced out of order inside a CTE
        res2 = ctx.execute(
            """
            WITH "grp" AS (
              SELECT NULL::date as dt, c, SUM(b) AS total_b
              FROM frame
              GROUP BY 2, 1
            )
            SELECT c, total_b FROM grp ORDER BY c"""
        )
        assert_frame_equal(res2, expected.select(pl.nth(0, 1)))


def test_group_by_errors() -> None:
    """Invalid GROUP BY / HAVING clauses raise SQLSyntaxError."""
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx"],
            "b": [10, 20, 30],
            "c": [99, 99, 66],
        }
    )

    # negative ordinal group key
    with pytest.raises(
        SQLSyntaxError,
        match=r"negative ordinal values are invalid for GROUP BY; found -99",
    ):
        df.sql("SELECT a, SUM(b) FROM self GROUP BY -99, a")

    # string literal is not a valid group key
    with pytest.raises(
        SQLSyntaxError,
        match=r"GROUP BY requires a valid expression or positive ordinal; found '!!!'",
    ):
        df.sql("SELECT a, SUM(b) FROM self GROUP BY a, '!!!'")

    # projected column neither grouped nor aggregated
    with pytest.raises(
        SQLSyntaxError,
        match=r"'a' should participate in the GROUP BY clause or an aggregate function",
    ):
        df.sql("SELECT a, SUM(b) FROM self GROUP BY b")

    # HAVING requires a GROUP BY clause
    with pytest.raises(
        SQLSyntaxError,
        match=r"HAVING clause not valid outside of GROUP BY",
    ):
        df.sql("SELECT a, COUNT(a) AS n FROM self HAVING n > 1")


def test_group_by_output_struct() -> None:
    """A struct aggregate inside agg() yields a single struct column per group."""
    out = (
        pl.DataFrame({"g": [1], "x": [2], "y": [3]})
        .group_by("g")
        .agg(pl.struct(pl.col.x.min(), pl.col.y.sum()))
    )
    assert out.rows() == [(1, {"x": 2, "y": 3})]


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_list_cat_24049(maintain_order: bool) -> None:
    """Grouping on a List(Categorical) key aggregates correctly (issue #24049)."""
    schema = {"x": pl.List(pl.Categorical), "y": pl.Int32}
    df = pl.DataFrame(
        {
            "x": [["a"], ["b", "c"], ["a"], ["a"], ["d"], ["b", "c"]],
            "y": [1, 2, 3, 4, 5, 10],
        },
        schema=schema,
    )
    expected = pl.DataFrame(
        {"x": [["a"], ["b", "c"], ["d"]], "y": [8, 12, 5]},
        schema=schema,
    )

    res = df.group_by("x", maintain_order=maintain_order).agg(pl.col.y.sum())
    # row order is only guaranteed when maintain_order is set
    assert_frame_equal(res, expected, check_row_order=maintain_order)


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_struct_cat_24049(maintain_order: bool) -> None:
    """Grouping on a Struct-of-Categorical key aggregates correctly (issue #24049)."""
    a = {"k1": "a2", "k2": "a2"}
    b = {"k1": "b2", "k2": "b2"}
    c = {"k1": "c2", "k2": "c2"}
    struct_dtype = pl.Struct({"k1": pl.Categorical, "k2": pl.Categorical})

    df = pl.DataFrame(
        {
            "x": [a, b, a, a, c, b],
            "y": [1, 2, 3, 4, 5, 10],
        },
        schema={"x": struct_dtype, "y": pl.Int32},
    )
    expected = pl.DataFrame(
        {"x": [a, b, c], "y": [8, 12, 5]},
        schema={"x": struct_dtype, "y": pl.Int32},
    )

    res = df.group_by("x", maintain_order=maintain_order).agg(pl.col.y.sum())
    # row order is only guaranteed when maintain_order is set
    assert_frame_equal(res, expected, check_row_order=maintain_order)