Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_group_by.py
8406 views
1
from __future__ import annotations
2
3
from datetime import date
4
from pathlib import Path
5
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import SQLSyntaxError
10
from polars.testing import assert_frame_equal
11
from tests.unit.sql import assert_sql_matches
12
13
14
@pytest.fixture
def foods_ipc_path() -> Path:
    """Path to the 'foods1' IPC data file shared by these SQL group-by tests."""
    io_files_dir = Path(__file__).parent.parent / "io" / "files"
    return io_files_dir / "foods1.ipc"
17
18
19
def test_group_by(foods_ipc_path: Path) -> None:
    """Basic GROUP BY / HAVING behaviour through a SQLContext with registered frames."""
    foods = pl.scan_ipc(foods_ipc_path)

    ctx = pl.SQLContext(eager=True)
    ctx.register("foods", foods)

    # aggregate per category, filtering on the aliased count via HAVING
    res = ctx.execute(
        """
        SELECT
            count(category) as n,
            category,
            max(calories) as max_cal,
            median(calories) as median_cal,
            min(fats_g) as min_fats
        FROM foods
        GROUP BY category
        HAVING n > 5
        ORDER BY n, category DESC
        """
    )
    assert res.to_dict(as_series=False) == {
        "n": [7, 7, 8],
        "category": ["vegetables", "fruit", "seafood"],
        "max_cal": [45, 130, 200],
        "median_cal": [25.0, 50.0, 145.0],
        "min_fats": [0.0, 0.0, 1.5],
    }

    extra = pl.LazyFrame(
        {
            "grp": ["a", "b", "c", "c", "b"],
            "att": ["x", "y", "x", "y", "y"],
        }
    )
    # only the originally-registered table is visible before the second register
    assert ctx.tables() == ["foods"]

    ctx.register("test", extra)
    assert ctx.tables() == ["foods", "test"]

    res = ctx.execute(
        """
        SELECT
            grp,
            COUNT(DISTINCT att) AS n_dist_attr
        FROM test
        GROUP BY grp
        HAVING n_dist_attr > 1
        """
    )
    assert res.to_dict(as_series=False) == {"grp": ["c"], "n_dist_attr": [2]}
69
70
71
def test_group_by_all() -> None:
    """GROUP BY ALL should infer the grouping keys from the SELECT list."""
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx", "yy", "xx", "zz"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [99, 99, 66, 66, 66, 66],
        }
    )

    # basic group/agg
    expected = pl.DataFrame(
        {
            "a": ["xx", "yy", "zz"],
            "b": [9, 6, 6],
            "c": [231, 165, 66],
            "n": [3, 2, 1],
        }
    )
    out = df.sql(
        """
        SELECT
            a,
            SUM(b),
            SUM(c),
            COUNT(*) AS n
        FROM self
        GROUP BY ALL
        ORDER BY ALL
        """
    )
    assert_frame_equal(expected, out, check_dtypes=False)

    # more involved determination of agg/group columns
    expected = pl.DataFrame(
        {
            "sum_b": [9, 6, 6],
            "sum_c": [231, 165, 66],
            "sum_bc_over_2": [120.0, 85.5, 36.0],
            "grp": ["xx", "yy", "zz"],
        }
    )
    out = df.sql(
        """
        SELECT
            SUM(b) AS sum_b,
            SUM(c) AS sum_c,
            (SUM(b) + SUM(c)) / 2.0 AS sum_bc_over_2, -- nested agg
            a as grp, --aliased group key
        FROM self
        GROUP BY ALL
        ORDER BY grp
        """
    )
    assert_frame_equal(expected, out.sort(by="grp"))
125
126
127
def test_group_by_all_multi() -> None:
    """GROUP BY ALL, ordinals, and explicit names should be interchangeable."""
    dt1 = date(1999, 12, 31)
    dt2 = date(2028, 7, 5)

    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "xx", "yy", "xx", "xx"],
            "dt": [dt1, dt1, dt1, dt2, dt2, dt2],
            "value": [10.5, -5.5, 20.5, 8.0, -3.0, 5.0],
        }
    )
    expected = pl.DataFrame(
        {
            "dt": [dt1, dt1, dt2, dt2],
            "key": ["xx", "yy", "xx", "yy"],
            "sum_value": [31.0, -5.5, 2.0, 8.0],
            "ninety_nine": [99, 99, 99, 99],
        },
        schema_overrides={"ninety_nine": pl.Int16},
    )

    # the following groupings should all be equivalent
    for group in ("ALL", "1, 2", "dt, key"):
        res = df.sql(
            f"""
            SELECT dt, key, sum_value, ninety_nine::int2 FROM
            (
              SELECT
                dt,
                key,
                SUM(value) AS sum_value,
                99 AS ninety_nine
              FROM self
              GROUP BY {group}
              ORDER BY dt, key
            ) AS grp
            """
        )
        assert_frame_equal(expected, res)
170
171
172
def test_group_by_ordinal_position() -> None:
    """GROUP BY may reference SELECT items by (1-based) ordinal position."""
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx", "yy", "xx", "zz"],
            "b": [1, None, 3, 4, 5, 6],
            "c": [99, 99, 66, 66, 66, 66],
        }
    )
    expected = pl.LazyFrame(
        {
            "c": [66, 99],
            "total_b": [18, 1],
            "count_b": [4, 1],
            "count_star": [4, 2],
        }
    )

    with pl.SQLContext(frame=df) as ctx:
        out1 = ctx.execute(
            """
            SELECT
              c,
              SUM(b) AS total_b,
              COUNT(b) AS count_b,
              COUNT(*) AS count_star
            FROM frame
            GROUP BY 1
            ORDER BY c
            """
        )
        assert_frame_equal(out1, expected, check_dtypes=False)

        # ordinals inside a CTE, in non-sequential order
        out2 = ctx.execute(
            """
            WITH "grp" AS (
              SELECT NULL::date as dt, c, SUM(b) AS total_b
              FROM frame
              GROUP BY 2, 1
            )
            SELECT c, total_b FROM grp ORDER BY c"""
        )
        assert_frame_equal(out2, expected.select(pl.nth(0, 1)))
214
215
216
def test_group_by_errors() -> None:
    """Invalid GROUP BY / HAVING usage should raise SQLSyntaxError."""
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx"],
            "b": [10, 20, 30],
            "c": [99, 99, 66],
        }
    )
    for match, query in (
        (
            r"negative ordinal values are invalid for GROUP BY; found -99",
            "SELECT a, SUM(b) FROM self GROUP BY -99, a",
        ),
        (
            r"GROUP BY requires a valid expression or positive ordinal; found '!!!'",
            "SELECT a, SUM(b) FROM self GROUP BY a, '!!!'",
        ),
        (
            r"'a' should participate in the GROUP BY clause or an aggregate function",
            "SELECT a, SUM(b) FROM self GROUP BY b",
        ),
        (
            r"HAVING clause not valid outside of GROUP BY",
            "SELECT a, COUNT(a) AS n FROM self HAVING n > 1",
        ),
    ):
        with pytest.raises(SQLSyntaxError, match=match):
            df.sql(query)
247
248
249
def test_group_by_having_aggregate_not_in_select() -> None:
    """Test HAVING with aggregate functions not present in SELECT."""
    df = pl.DataFrame(
        {"grp": ["a", "a", "a", "b", "b", "c"], "val": [1, 2, 3, 4, 5, 6]}
    )
    cases = [
        # COUNT(*) not in SELECT - only group 'a' has 3 rows
        ("SELECT grp FROM self GROUP BY grp HAVING COUNT(*) > 2", ["a"]),
        # SUM not in SELECT
        (
            "SELECT grp FROM self GROUP BY grp HAVING SUM(val) > 5 ORDER BY grp",
            ["a", "b", "c"],
        ),
        # AVG not in SELECT
        (
            "SELECT grp FROM self GROUP BY grp HAVING AVG(val) > 4 ORDER BY grp",
            ["b", "c"],
        ),
        # MIN/MAX not in SELECT
        (
            "SELECT grp FROM self GROUP BY grp HAVING MIN(val) >= 4 ORDER BY grp",
            ["b", "c"],
        ),
    ]
    for query, expected_groups in cases:
        assert_sql_matches(
            df,
            query=query,
            compare_with="sqlite",
            expected={"grp": expected_groups},
        )
285
286
287
def test_group_by_having_aggregate_in_select() -> None:
    """Test HAVING properly references an aggregate already computed in SELECT."""
    df = pl.DataFrame(
        {"grp": ["a", "a", "a", "b", "b", "c"], "val": [1, 2, 3, 4, 5, 6]}
    )
    # COUNT(*) in SELECT and HAVING; alias and raw form are equivalent
    for count_ref in ("COUNT(*)", "cnt"):
        assert_sql_matches(
            df,
            query=f"SELECT grp, COUNT(*) AS cnt FROM self GROUP BY grp HAVING {count_ref} > 2",
            compare_with="sqlite",
            expected={"grp": ["a"], "cnt": [3]},
        )

    # SUM in SELECT and HAVING; alias and raw form are equivalent
    for sum_ref in ("total", "SUM(val)"):
        assert_sql_matches(
            df,
            query=f"SELECT grp, SUM(val) AS total FROM self GROUP BY grp HAVING {sum_ref} > 5 ORDER BY grp",
            compare_with="sqlite",
            expected={"grp": ["a", "b", "c"], "total": [6, 9, 6]},
        )
309
310
311
def test_group_by_having_multiple_aggregates() -> None:
    """Test HAVING with multiple aggregate conditions."""
    df = pl.DataFrame(
        {"grp": ["a", "a", "a", "b", "b", "c"], "val": [1, 2, 3, 4, 5, 6]}
    )
    for query, expected_groups in (
        (
            "SELECT grp FROM self GROUP BY grp HAVING COUNT(*) >= 2 AND SUM(val) > 5 ORDER BY grp",
            ["a", "b"],
        ),
        (
            "SELECT grp FROM self GROUP BY grp HAVING COUNT(*) = 1 OR SUM(val) >= 9 ORDER BY grp",
            ["b", "c"],
        ),
    ):
        assert_sql_matches(
            df,
            query=query,
            compare_with="sqlite",
            expected={"grp": expected_groups},
        )
328
329
330
def test_group_by_having_compound_expressions() -> None:
    """Test HAVING with compound expressions involving aggregates."""
    df = pl.DataFrame(
        {"grp": ["a", "a", "c", "b", "b"], "val": [10, 20, 100, 5, 15]},
    )
    # arithmetic combining two aggregates
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING SUM(val) / COUNT(*) > 10 ORDER BY grp",
        compare_with="sqlite",
        expected={"grp": ["a", "c"]},
    )
    # difference of aggregates, with descending output order
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING MAX(val) - MIN(val) > 5 ORDER BY grp DESC",
        compare_with="sqlite",
        expected={"grp": ["b", "a"]},
    )
    # aliases and raw aggregates should be freely mixable inside HAVING
    for sum_ref, count_ref in (
        ("SUM(val)", "COUNT(*)"),
        ("total", "COUNT(*)"),
        ("SUM(val)", "n"),
        ("total", "n"),
    ):
        assert_sql_matches(
            df,
            query=f"""
            SELECT grp, SUM(val) AS total, COUNT(*) AS n
            FROM self
            GROUP BY grp
            HAVING {sum_ref} / {count_ref} > 10 ORDER BY grp
            """,
            compare_with="sqlite",
            expected={
                "grp": ["a", "c"],
                "total": [30, 100],
                "n": [2, 1],
            },
        )
368
369
370
def test_group_by_having_with_nulls() -> None:
    """Test HAVING behaviour with NULL values."""
    df = pl.DataFrame(
        {"grp": ["a", "b", "a", "b", "c"], "val": [None, None, 1, None, 5]}
    )
    for query, expected_groups in (
        # COUNT(*) counts all rows, including NULLs...
        (
            "SELECT grp FROM self GROUP BY grp HAVING COUNT(*) > 1 ORDER BY grp",
            ["a", "b"],
        ),
        # ...whereas COUNT(col) excludes NULLs
        (
            "SELECT grp FROM self GROUP BY grp HAVING COUNT(val) > 0 ORDER BY grp",
            ["a", "c"],
        ),
    ):
        assert_sql_matches(
            df,
            query=query,
            compare_with="sqlite",
            expected={"grp": expected_groups},
        )
390
391
392
@pytest.mark.parametrize(
    ("having_clause", "expected"),
    [
        # basic count conditions
        ("COUNT(*) > 2", [1]),
        ("COUNT(*) >= 2 AND COUNT(*) <= 3", [1, 2]),
        ("(COUNT(*) > 1)", [1, 2]),
        ("NOT COUNT(*) < 2", [1, 2]),
        # range / membership
        ("COUNT(*) BETWEEN 2 AND 3", [1, 2]),
        ("COUNT(*) NOT BETWEEN 1 AND 2", [1]),
        ("COUNT(*) IN (1, 3)", [1, 3]),
        ("COUNT(*) NOT IN (1, 2)", [1]),
        # conditional
        ("CASE WHEN COUNT(*) > 2 THEN 1 ELSE 0 END = 1", [1]),
    ],
)
def test_group_by_having_misc_01(
    having_clause: str,
    expected: list[int],
) -> None:
    """Assorted HAVING predicates built around COUNT(*)."""
    df = pl.DataFrame({"a": [1, 1, 1, 2, 2, 3]})
    query = f"SELECT a FROM self GROUP BY a HAVING {having_clause} ORDER BY a"
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={"a": expected},
    )
420
421
422
@pytest.mark.parametrize(
    ("having_clause", "expected"),
    [
        ("SUM(b) > 50", [1, 3]),
        ("AVG(b) > 15", [1, 3]),
        ("ABS(SUM(b)) > 50", [1, 3]),
        ("ROUND(ABS(AVG(b))) > 15", [1, 3]),
        ("ABS(SUM(b)) + ABS(AVG(b)) > 100", [3]),
        ("CASE WHEN SUM(b) < 10 THEN 0 ELSE SUM(b) END > 50", [1, 3]),
    ],
)
def test_group_by_having_misc_02(
    having_clause: str,
    expected: list[int],
) -> None:
    """HAVING predicates wrapping aggregates in scalar functions and CASE."""
    df = pl.DataFrame({"a": [1, 1, 1, 2, 2, 3], "b": [10, 20, 30, 5, 15, 100]})
    query = f"SELECT a FROM self GROUP BY a HAVING {having_clause} ORDER BY a"
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={"a": expected},
    )
444
445
446
@pytest.mark.parametrize(
    ("having_clause", "expected"),
    [
        ("MAX(b) IS NULL", [1]),
        ("MAX(b) IS NOT NULL", [2]),
    ],
)
def test_group_by_having_misc_03(
    having_clause: str,
    expected: list[int],
) -> None:
    """HAVING with IS NULL / IS NOT NULL checks on an aggregate result."""
    df = pl.DataFrame({"a": [1, 1, 2], "b": [None, None, 5]})
    query = f"SELECT a FROM self GROUP BY a HAVING {having_clause}"
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={"a": expected},
    )
464
465
466
def test_group_by_output_struct() -> None:
    """Aggregating into a struct should yield one struct value per group."""
    df = pl.DataFrame({"g": [1], "x": [2], "y": [3]})
    grouped = df.group_by("g").agg(pl.struct(pl.col.x.min(), pl.col.y.sum()))
    assert grouped.rows() == [(1, {"x": 2, "y": 3})]
470
471
472
@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_list_cat_24049(maintain_order: bool) -> None:
    """Group by a List(Categorical) key column (regression test, see #24049)."""
    df = pl.DataFrame(
        {
            "x": [["a"], ["b", "c"], ["a"], ["a"], ["d"], ["b", "c"]],
            "y": [1, 2, 3, 4, 5, 10],
        },
        schema={"x": pl.List(pl.Categorical), "y": pl.Int32},
    )
    expected = pl.DataFrame(
        {"x": [["a"], ["b", "c"], ["d"]], "y": [8, 12, 5]},
        schema={"x": pl.List(pl.Categorical), "y": pl.Int32},
    )
    result = df.group_by("x", maintain_order=maintain_order).agg(pl.col.y.sum())
    # row order is only defined when maintain_order is requested
    assert_frame_equal(result, expected, check_row_order=maintain_order)
494
495
496
@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_struct_cat_24049(maintain_order: bool) -> None:
    """Group by a Struct key with Categorical fields (regression test, see #24049)."""
    a = {"k1": "a2", "k2": "a2"}
    b = {"k1": "b2", "k2": "b2"}
    c = {"k1": "c2", "k2": "c2"}
    struct_dtype = pl.Struct({"k1": pl.Categorical, "k2": pl.Categorical})
    df = pl.DataFrame(
        {
            "x": [a, b, a, a, c, b],
            "y": [1, 2, 3, 4, 5, 10],
        },
        schema={"x": struct_dtype, "y": pl.Int32},
    )
    expected = pl.DataFrame(
        {"x": [a, b, c], "y": [8, 12, 5]},
        schema={"x": struct_dtype, "y": pl.Int32},
    )
    result = df.group_by("x", maintain_order=maintain_order).agg(pl.col.y.sum())
    # row order is only defined when maintain_order is requested
    assert_frame_equal(result, expected, check_row_order=maintain_order)
522
523
524
def test_group_by_aggregate_name_is_group_key() -> None:
    """Unaliased aggregation with a column that's also used in the GROUP BY key."""
    df = pl.DataFrame({"c0": [1, 2]})

    # 'COUNT(col)' where 'col' is also part of the group key
    # (fixed comment typo: "the the" -> "the")
    for query in (
        "SELECT COUNT(c0) FROM self GROUP BY c0",
        "SELECT COUNT(c0) AS c0 FROM self GROUP BY c0",
    ):
        assert_sql_matches(
            df,
            query=query,
            compare_with="sqlite",
            check_column_names=False,
            expected={"c0": [1, 1]},
        )

    # Same condition with a table prefix (and a different aggfunc)
    query = "SELECT SUM(self.c0) FROM self GROUP BY self.c0"
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        check_row_order=False,
        check_column_names=False,
        expected={"c0": [1, 2]},
    )
551
552
553
@pytest.mark.parametrize(
    "query",
    [
        # GROUP BY referencing SELECT alias for arithmetic expression
        "SELECT COUNT(*) AS n, value / 10 AS bucket FROM self GROUP BY bucket ORDER BY bucket",
        # Multiple aliased expressions in GROUP BY
        "SELECT COUNT(*) AS n, value / 10 AS tens, value % 3 AS rem FROM self GROUP BY tens, rem ORDER BY tens, rem",
        # GROUP BY alias with additional aggregation
        "SELECT SUM(id) AS total, value / 20 AS grp FROM self GROUP BY grp ORDER BY grp",
        # GROUP BY ordinal position with aliased column
        "SELECT value / 10 AS bucket, COUNT(*) AS n FROM self GROUP BY 1 ORDER BY 1",
        # GROUP BY ordinal with multiple aliased columns
        "SELECT id % 2 AS parity, value / 10 AS tens, SUM(id) AS total FROM self GROUP BY 1, 2 ORDER BY 1, 2",
    ],
)
def test_group_by_select_alias(query: str) -> None:
    """Test GROUP BY can reference SELECT aliases for computed expressions."""
    df = pl.DataFrame({"id": [1, 2, 3, 4, 5], "value": [10, 20, 30, 40, 50]})
    assert_sql_matches(df, query=query, compare_with="sqlite")
577
578