# Path: blob/main/py-polars/tests/unit/sql/test_qualify.py
# 7884 views
"""Tests for the SQL ``QUALIFY`` clause in the Polars SQL interface.

Most tests run a query through ``assert_sql_matches`` (a project helper
imported from ``tests.unit.sql``), passing ``compare_with="duckdb"`` and
the expected result as a column -> values mapping.
"""

from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import SQLSyntaxError
from tests.unit.sql import assert_sql_matches


@pytest.fixture
def df_test() -> pl.DataFrame:
    """Six-row frame with two categories ("A" and "B") of three rows each."""
    return pl.DataFrame(
        {
            "id": [1, 2, 3, 4, 5, 6],
            "category": ["A", "A", "A", "B", "B", "B"],
            "value": [100, 200, 150, 300, 250, 400],
        }
    )


@pytest.mark.parametrize(
    "qualify_clause",
    [
        pytest.param(
            "value > AVG(value) OVER (PARTITION BY category)",
            id="above_avg",
        ),
        pytest.param(
            "value = MAX(value) OVER (PARTITION BY category)",
            id="equals_max",
        ),
        pytest.param(
            "value > AVG(value) OVER (PARTITION BY category) AND value < 500",
            id="compound_expr",
        ),
    ],
)
def test_qualify_constraints(df_test: pl.DataFrame, qualify_clause: str) -> None:
    """Each clause keeps only the top-valued row of each category."""
    # Category averages are 150 (A) and ~316.67 (B), so "above average"
    # and "equals max" both select ids 2 and 6 for this data.
    assert_sql_matches(
        {"df": df_test},
        query=f"""
            SELECT id, category, value
            FROM df
            QUALIFY {qualify_clause}
            ORDER BY category, value
        """,
        compare_with="duckdb",
        expected={
            "id": [2, 6],
            "category": ["A", "B"],
            "value": [200, 400],
        },
    )


def test_qualify_distinct() -> None:
    """QUALIFY combined with SELECT DISTINCT yields one row per category."""
    df = pl.DataFrame(
        {
            "id": [1, 2, 3, 4, 5, 6],
            "category": ["A", "A", "B", "B", "C", "C"],
            "value": [100, 100, 200, 200, 300, 300],
        }
    )
    # Both rows of every category tie for the max, so DISTINCT is what
    # collapses the result down to a single row per category.
    assert_sql_matches(
        {"df": df},
        query="""
            SELECT DISTINCT category, value
            FROM df
            QUALIFY value = MAX(value) OVER (PARTITION BY category)
            ORDER BY category
        """,
        compare_with="duckdb",
        expected={
            "category": ["A", "B", "C"],
            "value": [100, 200, 300],
        },
    )


@pytest.mark.parametrize(
    "qualify_clause",
    [
        pytest.param(
            "400 < SUM(value) OVER (PARTITION BY category)",
            id="sum_window",
        ),
        pytest.param(
            "COUNT(*) OVER (PARTITION BY category) = 3",
            id="count_window",
        ),
    ],
)
def test_qualify_matches_all_rows(df_test: pl.DataFrame, qualify_clause: str) -> None:
    """Clauses that hold for every partition leave all six rows in place."""
    # Category sums are 450 (A) and 950 (B); each category has 3 rows.
    assert_sql_matches(
        {"df": df_test},
        query=f"""
            SELECT id, category, value
            FROM df
            QUALIFY {qualify_clause}
            ORDER BY id DESC
        """,
        compare_with="duckdb",
        expected={
            "id": [6, 5, 4, 3, 2, 1],
            "category": ["B", "B", "B", "A", "A", "A"],
            "value": [400, 250, 300, 150, 200, 100],
        },
    )


def test_qualify_multiple_clauses(df_test: pl.DataFrame) -> None:
    """QUALIFY accepts AND / OR combinations of plain and window predicates."""
    # AND: a plain column predicate together with a window predicate.
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT id, category, value
            FROM df
            QUALIFY
              value >= 300
              AND SUM(value) OVER (PARTITION BY category) > 500
            ORDER BY value
        """,
        compare_with="duckdb",
        expected={
            "id": [4, 6],
            "category": ["B", "B"],
            "value": [300, 400],
        },
    )
    # OR: two window predicates (per-category max or min).
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT id, category, value
            FROM df
            QUALIFY
              value = MAX(value) OVER (PARTITION BY category)
              OR value = MIN(value) OVER (PARTITION BY category)
            ORDER BY id
        """,
        compare_with="duckdb",
        expected={
            "id": [1, 2, 5, 6],
            "category": ["A", "A", "B", "B"],
            "value": [100, 200, 250, 400],
        },
    )


@pytest.mark.parametrize(
    "qualify_clause",
    [
        pytest.param(
            "value > MAX(value) OVER (PARTITION BY category)",
            id="greater_than_max",
        ),
        pytest.param(
            "value < MIN(value) OVER (PARTITION BY category)",
            id="less_than_min",
        ),
    ],
)
def test_qualify_returns_no_rows(df_test: pl.DataFrame, qualify_clause: str) -> None:
    """Unsatisfiable clauses produce an empty result with the same columns."""
    assert_sql_matches(
        {"df": df_test},
        query=f"""
            SELECT id, category, value
            FROM df QUALIFY {qualify_clause}
        """,
        compare_with="duckdb",
        expected={"id": [], "category": [], "value": []},
    )


def test_qualify_using_select_alias(df_test: pl.DataFrame) -> None:
    """QUALIFY can refer to a window expression via its SELECT alias."""
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT
              id,
              category,
              value,
              MAX(value) OVER (PARTITION BY category) as max_value
            FROM df
            QUALIFY value = max_value
            ORDER BY category
        """,
        compare_with="duckdb",
        expected={
            "id": [2, 6],
            "category": ["A", "B"],
            "value": [200, 400],
            "max_value": [200, 400],
        },
    )


@pytest.mark.parametrize(
    "qualify_clause",
    [
        pytest.param(
            "value > avg_value AND COUNT(*) OVER (PARTITION BY category) = 3",
            id="mixed_alias_and_explicit",
        ),
        pytest.param(
            "value > AVG(value) OVER (PARTITION BY category)",
            id="window_in_select",
        ),
    ],
)
def test_qualify_miscellaneous(df_test: pl.DataFrame, qualify_clause: str) -> None:
    """Aliased and explicit window expressions can be used (and mixed)."""
    assert_sql_matches(
        {"df": df_test},
        query=f"""
            SELECT
              id,
              category,
              value,
              AVG(value) OVER (PARTITION BY category) as avg_value
            FROM df
            QUALIFY {qualify_clause}
            ORDER BY category
        """,
        compare_with="duckdb",
        expected={
            "id": [2, 6],
            "category": ["A", "B"],
            "value": [200, 400],
            "avg_value": [150.0, 316.6666666666667],
        },
    )


def test_qualify_with_internal_cumulative_sum() -> None:
    """QUALIFY on a running SUM ordered by ``id``."""
    # Input rows are deliberately not in id order; the running total by id
    # is 10, 30, 60, 100, 150, so only ids 1-3 satisfy "<= 60".
    df = pl.DataFrame(
        {
            "id": [1, 3, 4, 2, 5],
            "value": [10, 30, 40, 20, 50],
        }
    )
    assert_sql_matches(
        {"df": df},
        query="""
            SELECT id, value
            FROM df
            QUALIFY SUM(value) OVER (ORDER BY id) <= 60
            ORDER BY id
        """,
        compare_with="duckdb",
        expected={
            "id": [1, 2, 3],
            "value": [10, 20, 30],
        },
    )


def test_qualify_with_alias_and_comparison(df_test: pl.DataFrame) -> None:
    """QUALIFY compares an aliased window aggregate against a literal."""
    # Only category "B" (sum 950) exceeds 500; category "A" sums to 450.
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT id, SUM(value) OVER (PARTITION BY category) as total
            FROM df QUALIFY total > 500
            ORDER BY id DESC
        """,
        compare_with="duckdb",
        expected={
            "id": [6, 5, 4],
            "total": [950, 950, 950],
        },
    )


def test_qualify_with_where_clause(df_test: pl.DataFrame) -> None:
    """QUALIFY together with WHERE: the window sees only WHERE-filtered rows."""
    # WHERE leaves ids 4, 5, 6 (all "B"); the max over those is 400, so the
    # expected output excludes id 6 and keeps ids 5 and 4.
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT id, category, value
            FROM df WHERE value > 200
            QUALIFY value != MAX(value) OVER (PARTITION BY category)
            ORDER BY value
        """,
        compare_with="duckdb",
        expected={
            "id": [5, 4],
            "category": ["B", "B"],
            "value": [250, 300],
        },
    )


def test_qualify_expected_errors(df_test: pl.DataFrame) -> None:
    """A QUALIFY clause without any window function raises SQLSyntaxError."""
    ctx = pl.SQLContext(df=df_test, eager=True)
    with pytest.raises(
        SQLSyntaxError,
        match="QUALIFY clause must reference window functions",
    ):
        ctx.execute("SELECT id, category, value FROM df QUALIFY value > 200")