Path: blob/main/py-polars/tests/unit/sql/test_group_by.py
from __future__ import annotations

from datetime import date
from pathlib import Path

import pytest

import polars as pl
from polars.exceptions import SQLSyntaxError
from polars.testing import assert_frame_equal
from tests.unit.sql import assert_sql_matches


@pytest.fixture
def foods_ipc_path() -> Path:
    return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc"


def test_group_by(foods_ipc_path: Path) -> None:
    lf = pl.scan_ipc(foods_ipc_path)

    ctx = pl.SQLContext(eager=True)
    ctx.register("foods", lf)

    out = ctx.execute(
        """
        SELECT
            count(category) as n,
            category,
            max(calories) as max_cal,
            median(calories) as median_cal,
            min(fats_g) as min_fats
        FROM foods
        GROUP BY category
        HAVING n > 5
        ORDER BY n, category DESC
        """
    )
    assert out.to_dict(as_series=False) == {
        "n": [7, 7, 8],
        "category": ["vegetables", "fruit", "seafood"],
        "max_cal": [45, 130, 200],
        "median_cal": [25.0, 50.0, 145.0],
        "min_fats": [0.0, 0.0, 1.5],
    }

    lf = pl.LazyFrame(
        {
            "grp": ["a", "b", "c", "c", "b"],
            "att": ["x", "y", "x", "y", "y"],
        }
    )
    assert ctx.tables() == ["foods"]

    ctx.register("test", lf)
    assert ctx.tables() == ["foods", "test"]

    out = ctx.execute(
        """
        SELECT
            grp,
            COUNT(DISTINCT att) AS n_dist_attr
        FROM test
        GROUP BY grp
        HAVING n_dist_attr > 1
        """
    )
    assert out.to_dict(as_series=False) == {"grp": ["c"], "n_dist_attr": [2]}


def test_group_by_all() -> None:
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx", "yy", "xx", "zz"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [99, 99, 66, 66, 66, 66],
        }
    )

    # basic group/agg
    res = df.sql(
        """
        SELECT
            a,
            SUM(b),
            SUM(c),
            COUNT(*) AS n
        FROM self
        GROUP BY ALL
        ORDER BY ALL
        """
    )
    expected = pl.DataFrame(
        {
            "a": ["xx", "yy", "zz"],
            "b": [9, 6, 6],
            "c": [231, 165, 66],
            "n": [3, 2, 1],
        }
    )
    assert_frame_equal(expected, res, check_dtypes=False)

    # more involved determination of agg/group columns
    res = df.sql(
        """
        SELECT
            SUM(b) AS sum_b,
            SUM(c) AS sum_c,
            (SUM(b) + SUM(c)) / 2.0 AS sum_bc_over_2,  -- nested agg
            a as grp,  -- aliased group key
        FROM self
        GROUP BY ALL
        ORDER BY grp
        """
    )
    expected = pl.DataFrame(
        {
            "sum_b": [9, 6, 6],
            "sum_c": [231, 165, 66],
            "sum_bc_over_2": [120.0, 85.5, 36.0],
            "grp": ["xx", "yy", "zz"],
        }
    )
    assert_frame_equal(expected, res.sort(by="grp"))


def test_group_by_all_multi() -> None:
    dt1 = date(1999, 12, 31)
    dt2 = date(2028, 7, 5)

    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "xx", "yy", "xx", "xx"],
            "dt": [dt1, dt1, dt1, dt2, dt2, dt2],
            "value": [10.5, -5.5, 20.5, 8.0, -3.0, 5.0],
        }
    )
    expected = pl.DataFrame(
        {
            "dt": [dt1, dt1, dt2, dt2],
            "key": ["xx", "yy", "xx", "yy"],
            "sum_value": [31.0, -5.5, 2.0, 8.0],
            "ninety_nine": [99, 99, 99, 99],
        },
        schema_overrides={"ninety_nine": pl.Int16},
    )

    # the following groupings should all be equivalent
    for group in (
        "ALL",
        "1, 2",
        "dt, key",
    ):
        res = df.sql(
            f"""
            SELECT dt, key, sum_value, ninety_nine::int2 FROM
            (
                SELECT
                    dt,
                    key,
                    SUM(value) AS sum_value,
                    99 AS ninety_nine
                FROM self
                GROUP BY {group}
                ORDER BY dt, key
            ) AS grp
            """
        )
        assert_frame_equal(expected, res)


def test_group_by_ordinal_position() -> None:
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx", "yy", "xx", "zz"],
            "b": [1, None, 3, 4, 5, 6],
            "c": [99, 99, 66, 66, 66, 66],
        }
    )
    expected = pl.LazyFrame(
        {
            "c": [66, 99],
            "total_b": [18, 1],
            "count_b": [4, 1],
            "count_star": [4, 2],
        }
    )

    with pl.SQLContext(frame=df) as ctx:
        res1 = ctx.execute(
            """
            SELECT
                c,
                SUM(b) AS total_b,
                COUNT(b) AS count_b,
                COUNT(*) AS count_star
            FROM frame
            GROUP BY 1
            ORDER BY c
            """
        )
        assert_frame_equal(res1, expected, check_dtypes=False)

        res2 = ctx.execute(
            """
            WITH "grp" AS (
                SELECT NULL::date as dt, c, SUM(b) AS total_b
                FROM frame
                GROUP BY 2, 1
            )
            SELECT c, total_b FROM grp ORDER BY c
            """
        )
        assert_frame_equal(res2, expected.select(pl.nth(0, 1)))


def test_group_by_errors() -> None:
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx"],
            "b": [10, 20, 30],
            "c": [99, 99, 66],
        }
    )
    with pytest.raises(
        SQLSyntaxError,
        match=r"negative ordinal values are invalid for GROUP BY; found -99",
    ):
        df.sql("SELECT a, SUM(b) FROM self GROUP BY -99, a")

    with pytest.raises(
        SQLSyntaxError,
        match=r"GROUP BY requires a valid expression or positive ordinal; found '!!!'",
    ):
        df.sql("SELECT a, SUM(b) FROM self GROUP BY a, '!!!'")

    with pytest.raises(
        SQLSyntaxError,
        match=r"'a' should participate in the GROUP BY clause or an aggregate function",
    ):
        df.sql("SELECT a, SUM(b) FROM self GROUP BY b")

    with pytest.raises(
        SQLSyntaxError,
        match=r"HAVING clause not valid outside of GROUP BY",
    ):
        df.sql("SELECT a, COUNT(a) AS n FROM self HAVING n > 1")


def test_group_by_having_aggregate_not_in_select() -> None:
    """Test HAVING with aggregate functions not present in SELECT."""
    df = pl.DataFrame(
        {"grp": ["a", "a", "a", "b", "b", "c"], "val": [1, 2, 3, 4, 5, 6]}
    )
    # COUNT(*) not in SELECT - only group 'a' has 3 rows
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING COUNT(*) > 2",
        compare_with="sqlite",
        expected={"grp": ["a"]},
    )

    # SUM not in SELECT
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING SUM(val) > 5 ORDER BY grp",
        compare_with="sqlite",
        expected={"grp": ["a", "b", "c"]},
    )

    # AVG not in SELECT
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING AVG(val) > 4 ORDER BY grp",
        compare_with="sqlite",
        expected={"grp": ["b", "c"]},
    )

    # MIN/MAX not in SELECT
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING MIN(val) >= 4 ORDER BY grp",
        compare_with="sqlite",
        expected={"grp": ["b", "c"]},
    )


def test_group_by_having_aggregate_in_select() -> None:
    """Test HAVING properly references an aggregate already computed in SELECT."""
    df = pl.DataFrame(
        {"grp": ["a", "a", "a", "b", "b", "c"], "val": [1, 2, 3, 4, 5, 6]}
    )
    # COUNT(*) in SELECT and HAVING
    for count_expr in ("COUNT(*)", "cnt"):
        assert_sql_matches(
            df,
            query=f"SELECT grp, COUNT(*) AS cnt FROM self GROUP BY grp HAVING {count_expr} > 2",
            compare_with="sqlite",
            expected={"grp": ["a"], "cnt": [3]},
        )

    # SUM in SELECT and HAVING
    for sum_expr in ("total", "SUM(val)"):
        assert_sql_matches(
            df,
            query=f"SELECT grp, SUM(val) AS total FROM self GROUP BY grp HAVING {sum_expr} > 5 ORDER BY grp",
            compare_with="sqlite",
            expected={"grp": ["a", "b", "c"], "total": [6, 9, 6]},
        )


def test_group_by_having_multiple_aggregates() -> None:
    """Test HAVING with multiple aggregate conditions."""
    df = pl.DataFrame(
        {"grp": ["a", "a", "a", "b", "b", "c"], "val": [1, 2, 3, 4, 5, 6]}
    )
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING COUNT(*) >= 2 AND SUM(val) > 5 ORDER BY grp",
        compare_with="sqlite",
        expected={"grp": ["a", "b"]},
    )
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING COUNT(*) = 1 OR SUM(val) >= 9 ORDER BY grp",
        compare_with="sqlite",
        expected={"grp": ["b", "c"]},
    )


def test_group_by_having_compound_expressions() -> None:
    """Test HAVING with compound expressions involving aggregates."""
    df = pl.DataFrame(
        {"grp": ["a", "a", "c", "b", "b"], "val": [10, 20, 100, 5, 15]},
    )
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING SUM(val) / COUNT(*) > 10 ORDER BY grp",
        compare_with="sqlite",
        expected={"grp": ["a", "c"]},
    )
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING MAX(val) - MIN(val) > 5 ORDER BY grp DESC",
        compare_with="sqlite",
        expected={"grp": ["b", "a"]},
    )
    for sum_expr, count_expr in (
        ("SUM(val)", "COUNT(*)"),
        ("total", "COUNT(*)"),
        ("SUM(val)", "n"),
        ("total", "n"),
    ):
        assert_sql_matches(
            df,
            query=f"""
                SELECT grp, SUM(val) AS total, COUNT(*) AS n
                FROM self
                GROUP BY grp
                HAVING {sum_expr} / {count_expr} > 10 ORDER BY grp
            """,
            compare_with="sqlite",
            expected={
                "grp": ["a", "c"],
                "total": [30, 100],
                "n": [2, 1],
            },
        )


def test_group_by_having_with_nulls() -> None:
    """Test HAVING behaviour with NULL values."""
    df = pl.DataFrame(
        {"grp": ["a", "b", "a", "b", "c"], "val": [None, None, 1, None, 5]}
    )
    # COUNT(*) counts all rows, including NULLs...
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING COUNT(*) > 1 ORDER BY grp",
        compare_with="sqlite",
        expected={"grp": ["a", "b"]},
    )

    # ...whereas COUNT(col) excludes NULLs
    assert_sql_matches(
        df,
        query="SELECT grp FROM self GROUP BY grp HAVING COUNT(val) > 0 ORDER BY grp",
        compare_with="sqlite",
        expected={"grp": ["a", "c"]},
    )


@pytest.mark.parametrize(
    ("having_clause", "expected"),
    [
        # basic count conditions
        ("COUNT(*) > 2", [1]),
        ("COUNT(*) >= 2 AND COUNT(*) <= 3", [1, 2]),
        ("(COUNT(*) > 1)", [1, 2]),
        ("NOT COUNT(*) < 2", [1, 2]),
        # range / membership
        ("COUNT(*) BETWEEN 2 AND 3", [1, 2]),
        ("COUNT(*) NOT BETWEEN 1 AND 2", [1]),
        ("COUNT(*) IN (1, 3)", [1, 3]),
        ("COUNT(*) NOT IN (1, 2)", [1]),
        # conditional
        ("CASE WHEN COUNT(*) > 2 THEN 1 ELSE 0 END = 1", [1]),
    ],
)
def test_group_by_having_misc_01(
    having_clause: str,
    expected: list[int],
) -> None:
    df = pl.DataFrame({"a": [1, 1, 1, 2, 2, 3]})
    assert_sql_matches(
        df,
        query=f"SELECT a FROM self GROUP BY a HAVING {having_clause} ORDER BY a",
        compare_with="sqlite",
        expected={"a": expected},
    )


@pytest.mark.parametrize(
    ("having_clause", "expected"),
    [
        ("SUM(b) > 50", [1, 3]),
        ("AVG(b) > 15", [1, 3]),
        ("ABS(SUM(b)) > 50", [1, 3]),
        ("ROUND(ABS(AVG(b))) > 15", [1, 3]),
        ("ABS(SUM(b)) + ABS(AVG(b)) > 100", [3]),
        ("CASE WHEN SUM(b) < 10 THEN 0 ELSE SUM(b) END > 50", [1, 3]),
    ],
)
def test_group_by_having_misc_02(
    having_clause: str,
    expected: list[int],
) -> None:
    df = pl.DataFrame({"a": [1, 1, 1, 2, 2, 3], "b": [10, 20, 30, 5, 15, 100]})
    assert_sql_matches(
        df,
        query=f"SELECT a FROM self GROUP BY a HAVING {having_clause} ORDER BY a",
        compare_with="sqlite",
        expected={"a": expected},
    )


@pytest.mark.parametrize(
    ("having_clause", "expected"),
    [
        ("MAX(b) IS NULL", [1]),
        ("MAX(b) IS NOT NULL", [2]),
    ],
)
def test_group_by_having_misc_03(
    having_clause: str,
    expected: list[int],
) -> None:
    df = pl.DataFrame({"a": [1, 1, 2], "b": [None, None, 5]})
    assert_sql_matches(
        df,
        query=f"SELECT a FROM self GROUP BY a HAVING {having_clause}",
        compare_with="sqlite",
        expected={"a": expected},
    )


def test_group_by_output_struct() -> None:
    df = pl.DataFrame({"g": [1], "x": [2], "y": [3]})
    out = df.group_by("g").agg(pl.struct(pl.col.x.min(), pl.col.y.sum()))
    assert out.rows() == [(1, {"x": 2, "y": 3})]


@pytest.mark.parametrize(
    "maintain_order",
    [False, True],
)
def test_group_by_list_cat_24049(maintain_order: bool) -> None:
    df = pl.DataFrame(
        {
            "x": [["a"], ["b", "c"], ["a"], ["a"], ["d"], ["b", "c"]],
            "y": [1, 2, 3, 4, 5, 10],
        },
        schema={"x": pl.List(pl.Categorical), "y": pl.Int32},
    )

    expected = pl.DataFrame(
        {"x": [["a"], ["b", "c"], ["d"]], "y": [8, 12, 5]},
        schema={"x": pl.List(pl.Categorical), "y": pl.Int32},
    )
    assert_frame_equal(
        df.group_by("x", maintain_order=maintain_order).agg(pl.col.y.sum()),
        expected,
        check_row_order=maintain_order,
    )


@pytest.mark.parametrize(
    "maintain_order",
    [False, True],
)
def test_group_by_struct_cat_24049(maintain_order: bool) -> None:
    a = {"k1": "a2", "k2": "a2"}
    b = {"k1": "b2", "k2": "b2"}
    c = {"k1": "c2", "k2": "c2"}
    s = pl.Struct({"k1": pl.Categorical, "k2": pl.Categorical})
    df = pl.DataFrame(
        {
            "x": [a, b, a, a, c, b],
            "y": [1, 2, 3, 4, 5, 10],
        },
        schema={"x": s, "y": pl.Int32},
    )

    expected = pl.DataFrame(
        {"x": [a, b, c], "y": [8, 12, 5]},
        schema={"x": s, "y": pl.Int32},
    )
    assert_frame_equal(
        df.group_by("x", maintain_order=maintain_order).agg(pl.col.y.sum()),
        expected,
        check_row_order=maintain_order,
    )


def test_group_by_aggregate_name_is_group_key() -> None:
    """Unaliased aggregation with a column that's also used in the GROUP BY key."""
    df = pl.DataFrame({"c0": [1, 2]})

    # 'COUNT(col)' where 'col' is also part of the group key
    for query in (
        "SELECT COUNT(c0) FROM self GROUP BY c0",
        "SELECT COUNT(c0) AS c0 FROM self GROUP BY c0",
    ):
        assert_sql_matches(
            df,
            query=query,
            compare_with="sqlite",
            check_column_names=False,
            expected={"c0": [1, 1]},
        )

    # Same condition with a table prefix (and a different aggfunc)
    query = "SELECT SUM(self.c0) FROM self GROUP BY self.c0"
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        check_row_order=False,
        check_column_names=False,
        expected={"c0": [1, 2]},
    )


@pytest.mark.parametrize(
    "query",
    [
        # GROUP BY referencing SELECT alias for arithmetic expression
        "SELECT COUNT(*) AS n, value / 10 AS bucket FROM self GROUP BY bucket ORDER BY bucket",
        # Multiple aliased expressions in GROUP BY
        "SELECT COUNT(*) AS n, value / 10 AS tens, value % 3 AS rem FROM self GROUP BY tens, rem ORDER BY tens, rem",
        # GROUP BY alias with additional aggregation
        "SELECT SUM(id) AS total, value / 20 AS grp FROM self GROUP BY grp ORDER BY grp",
        # GROUP BY ordinal position with aliased column
        "SELECT value / 10 AS bucket, COUNT(*) AS n FROM self GROUP BY 1 ORDER BY 1",
        # GROUP BY ordinal with multiple aliased columns
        "SELECT id % 2 AS parity, value / 10 AS tens, SUM(id) AS total FROM self GROUP BY 1, 2 ORDER BY 1, 2",
    ],
)
def test_group_by_select_alias(query: str) -> None:
    """Test GROUP BY can reference SELECT aliases for computed expressions."""
    df = pl.DataFrame(
        {
            "id": [1, 2, 3, 4, 5],
            "value": [10, 20, 30, 40, 50],
        }
    )
    assert_sql_matches(df, query=query, compare_with="sqlite")