Path: blob/main/py-polars/tests/unit/sql/test_group_by.py
6939 views
from __future__ import annotations

from datetime import date
from pathlib import Path

import pytest

import polars as pl
from polars.exceptions import SQLSyntaxError
from polars.testing import assert_frame_equal


@pytest.fixture
def foods_ipc_path() -> Path:
    """Return the path of ``io/files/foods1.ipc``, two levels above this file."""
    unit_tests_dir = Path(__file__).parent.parent
    return unit_tests_dir / "io" / "files" / "foods1.ipc"


def test_group_by(foods_ipc_path: Path) -> None:
    """Check GROUP BY / HAVING results, and table registration, on a SQLContext."""
    ctx = pl.SQLContext(eager=True)
    ctx.register("foods", pl.scan_ipc(foods_ipc_path))

    food_stats = ctx.execute(
        """
        SELECT
            count(category) as n,
            category,
            max(calories) as max_cal,
            median(calories) as median_cal,
            min(fats_g) as min_fats
        FROM foods
        GROUP BY category
        HAVING n > 5
        ORDER BY n, category DESC
        """
    )
    expected_food_stats = {
        "n": [7, 7, 8],
        "category": ["vegetables", "fruit", "seafood"],
        "max_cal": [45, 130, 200],
        "median_cal": [25.0, 50.0, 145.0],
        "min_fats": [0.0, 0.0, 1.5],
    }
    assert food_stats.to_dict(as_series=False) == expected_food_stats

    attrs_lf = pl.LazyFrame(
        {
            "grp": ["a", "b", "c", "c", "b"],
            "att": ["x", "y", "x", "y", "y"],
        }
    )
    # Only "foods" is known to the context until the second frame is registered.
    assert ctx.tables() == ["foods"]

    ctx.register("test", attrs_lf)
    assert ctx.tables() == ["foods", "test"]

    distinct_counts = ctx.execute(
        """
        SELECT
            grp,
            COUNT(DISTINCT att) AS n_dist_attr
        FROM test
        GROUP BY grp
        HAVING n_dist_attr > 1
        """
    )
    expected_distinct_counts = {"grp": ["c"], "n_dist_attr": [2]}
    assert distinct_counts.to_dict(as_series=False) == expected_distinct_counts


def test_group_by_all() -> None:
    """Check that ``GROUP BY ALL`` yields the expected aggregates."""
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx", "yy", "xx", "zz"],
            "b": [1, 2, 3, 4, 5, 6],
            "c": [99, 99, 66, 66, 66, 66],
        }
    )

    # basic group/agg
    expected_basic = pl.DataFrame(
        {
            "a": ["xx", "yy", "zz"],
            "b": [9, 6, 6],
            "c": [231, 165, 66],
            "n": [3, 2, 1],
        }
    )
    basic_res = df.sql(
        """
        SELECT
            a,
            SUM(b),
            SUM(c),
            COUNT(*) AS n
        FROM self
        GROUP BY ALL
        ORDER BY ALL
        """
    )
    assert_frame_equal(expected_basic, basic_res, check_dtypes=False)

    # more involved determination of agg/group columns
    expected_nested = pl.DataFrame(
        {
            "sum_b": [9, 6, 6],
            "sum_c": [231, 165, 66],
            "sum_bc_over_2": [120.0, 85.5, 36.0],
            "grp": ["xx", "yy", "zz"],
        }
    )
    nested_res = df.sql(
        """
        SELECT
            SUM(b) AS sum_b,
            SUM(c) AS sum_c,
            (SUM(b) + SUM(c)) / 2.0 AS sum_bc_over_2, -- nested agg
            a as grp, --aliased group key
        FROM self
        GROUP BY ALL
        ORDER BY grp
        """
    )
    nested_sorted = nested_res.sort(by="grp")
    assert_frame_equal(expected_nested, nested_sorted)


def test_group_by_all_multi() -> None:
    """Check that ALL, ordinal and named group keys give the same result."""
    dt1 = date(1999, 12, 31)
    dt2 = date(2028, 7, 5)

    expected = pl.DataFrame(
        {
            "dt": [dt1, dt1, dt2, dt2],
            "key": ["xx", "yy", "xx", "yy"],
            "sum_value": [31.0, -5.5, 2.0, 8.0],
            "ninety_nine": [99, 99, 99, 99],
        },
        schema_overrides={"ninety_nine": pl.Int16},
    )
    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "xx", "yy", "xx", "xx"],
            "dt": [dt1, dt1, dt1, dt2, dt2, dt2],
            "value": [10.5, -5.5, 20.5, 8.0, -3.0, 5.0],
        }
    )

    # the following groupings should all be equivalent
    equivalent_groupings = ("ALL", "1, 2", "dt, key")
    for group in equivalent_groupings:
        res = df.sql(
            f"""
            SELECT dt, key, sum_value, ninety_nine::int2 FROM
            (
                SELECT
                    dt,
                    key,
                    SUM(value) AS sum_value,
                    99 AS ninety_nine
                FROM self
                GROUP BY {group}
                ORDER BY dt, key
            ) AS grp
            """
        )
        assert_frame_equal(expected, res)


def test_group_by_ordinal_position() -> None:
    """Check GROUP BY with ordinal positions, directly and inside a CTE."""
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx", "yy", "xx", "zz"],
            "b": [1, None, 3, 4, 5, 6],
            "c": [99, 99, 66, 66, 66, 66],
        }
    )
    expected = pl.LazyFrame(
        {
            "c": [66, 99],
            "total_b": [18, 1],
            "count_b": [4, 1],
            "count_star": [4, 2],
        }
    )
    # The CTE query below only returns the first two expected columns.
    expected_first_two = expected.select(pl.nth(0, 1))

    with pl.SQLContext(frame=df) as ctx:
        single_ordinal = ctx.execute(
            """
            SELECT
                c,
                SUM(b) AS total_b,
                COUNT(b) AS count_b,
                COUNT(*) AS count_star
            FROM frame
            GROUP BY 1
            ORDER BY c
            """
        )
        assert_frame_equal(single_ordinal, expected, check_dtypes=False)

        cte_ordinals = ctx.execute(
            """
            WITH "grp" AS (
                SELECT NULL::date as dt, c, SUM(b) AS total_b
                FROM frame
                GROUP BY 2, 1
            )
            SELECT c, total_b FROM grp ORDER BY c"""
        )
        assert_frame_equal(cte_ordinals, expected_first_two)


def test_group_by_errors() -> None:
    """Check that invalid GROUP BY / HAVING queries raise ``SQLSyntaxError``."""
    df = pl.DataFrame(
        {
            "a": ["xx", "yy", "xx"],
            "b": [10, 20, 30],
            "c": [99, 99, 66],
        }
    )

    # (query, regex the raised error message must match), run in order
    invalid_queries = [
        (
            "SELECT a, SUM(b) FROM self GROUP BY -99, a",
            r"negative ordinal values are invalid for GROUP BY; found -99",
        ),
        (
            "SELECT a, SUM(b) FROM self GROUP BY a, '!!!'",
            r"GROUP BY requires a valid expression or positive ordinal; found '!!!'",
        ),
        (
            "SELECT a, SUM(b) FROM self GROUP BY b",
            r"'a' should participate in the GROUP BY clause or an aggregate function",
        ),
        (
            "SELECT a, COUNT(a) AS n FROM self HAVING n > 1",
            r"HAVING clause not valid outside of GROUP BY",
        ),
    ]
    for query, message_pattern in invalid_queries:
        with pytest.raises(SQLSyntaxError, match=message_pattern):
            df.sql(query)


def test_group_by_output_struct() -> None:
    """Check that aggregating into a struct yields one dict-valued row per group."""
    df = pl.DataFrame({"g": [1], "x": [2], "y": [3]})
    struct_agg = pl.struct(pl.col.x.min(), pl.col.y.sum())
    out = df.group_by("g").agg(struct_agg)
    assert out.rows() == [(1, {"x": 2, "y": 3})]


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_list_cat_24049(maintain_order: bool) -> None:
    """Check summing ``y`` when grouping by a ``List(Categorical)`` column."""
    schema = {"x": pl.List(pl.Categorical), "y": pl.Int32}
    df = pl.DataFrame(
        {
            "x": [["a"], ["b", "c"], ["a"], ["a"], ["d"], ["b", "c"]],
            "y": [1, 2, 3, 4, 5, 10],
        },
        schema=schema,
    )
    expected = pl.DataFrame(
        {"x": [["a"], ["b", "c"], ["d"]], "y": [8, 12, 5]},
        schema=schema,
    )

    grouped = df.group_by("x", maintain_order=maintain_order).agg(pl.col.y.sum())
    # Row order is only compared when the group-by was asked to maintain it.
    assert_frame_equal(grouped, expected, check_row_order=maintain_order)


@pytest.mark.parametrize("maintain_order", [False, True])
def test_group_by_struct_cat_24049(maintain_order: bool) -> None:
    """Check summing ``y`` when grouping by a struct of Categorical fields."""
    a = {"k1": "a2", "k2": "a2"}
    b = {"k1": "b2", "k2": "b2"}
    c = {"k1": "c2", "k2": "c2"}
    s = pl.Struct({"k1": pl.Categorical, "k2": pl.Categorical})
    schema = {"x": s, "y": pl.Int32}

    df = pl.DataFrame(
        {
            "x": [a, b, a, a, c, b],
            "y": [1, 2, 3, 4, 5, 10],
        },
        schema=schema,
    )
    expected = pl.DataFrame(
        {"x": [a, b, c], "y": [8, 12, 5]},
        schema=schema,
    )

    grouped = df.group_by("x", maintain_order=maintain_order).agg(pl.col.y.sum())
    # Row order is only compared when the group-by was asked to maintain it.
    assert_frame_equal(grouped, expected, check_row_order=maintain_order)