Path: py-polars/tests/unit/sql/test_window_functions.py
from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError
from polars.testing import assert_frame_equal
from tests.unit.sql import assert_sql_matches


@pytest.fixture
def df_test() -> pl.DataFrame:
    return pl.DataFrame(
        {
            "id": [1, 2, 3, 4, 5, 6, 7],
            "category": ["A", "A", "A", "B", "B", "B", "C"],
            "value": [20, 10, 30, 15, 40, 25, 35],
        }
    )


def test_over_with_order_by(df_test: pl.DataFrame) -> None:
    query = """
    SELECT
      id,
      value,
      SUM(value) OVER (ORDER BY value) AS sum_by_value
    FROM self
    ORDER BY id
    """
    assert_sql_matches(
        df_test,
        query=query,
        compare_with="sqlite",
        expected={
            "id": [1, 2, 3, 4, 5, 6, 7],
            "value": [20, 10, 30, 15, 40, 25, 35],
            "sum_by_value": [45, 10, 100, 25, 175, 70, 135],
        },
    )


def test_over_with_partition_by(df_test: pl.DataFrame) -> None:
    df = df_test.remove(pl.col("id") == 6)
    query = """
    SELECT
      category,
      value,
      ROW_NUMBER() OVER (PARTITION BY category ORDER BY value) AS row_num,
      COUNT(*) OVER w0 AS cat_count,
      SUM(value) OVER w0 AS cat_sum
    FROM self
    WINDOW w0 AS (PARTITION BY category)
    ORDER BY category, value
    """
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={
            "category": ["A", "A", "A", "B", "B", "C"],
            "value": [10, 20, 30, 15, 40, 35],
            "row_num": [1, 2, 3, 1, 2, 1],
            "cat_count": [3, 3, 3, 2, 2, 1],
            "cat_sum": [60, 60, 60, 55, 55, 35],
        },
    )


def test_over_with_cumulative_window_funcs(df_test: pl.DataFrame) -> None:
    query = """
    SELECT
      category,
      value,
      SUM(value) OVER (PARTITION BY category ORDER BY value) AS cumsum,
      MIN(value) OVER (PARTITION BY category ORDER BY value) AS cummin,
      MAX(value) OVER (PARTITION BY category ORDER BY value) AS cummax
    FROM self
    ORDER BY category, value
    """
    assert_sql_matches(
        df_test,
        query=query,
        compare_with="sqlite",
        expected={
            "category": ["A", "A", "A", "B", "B", "B", "C"],
            "value": [10, 20, 30, 15, 25, 40, 35],
            "cumsum": [10, 30, 60, 15, 40, 80, 35],
            "cummin": [10, 10, 10, 15, 15, 15, 35],
            "cummax": [10, 20, 30, 15, 25, 40, 35],
        },
    )


def test_window_function_over_empty(df_test: pl.DataFrame) -> None:
    query = """
    SELECT
      id,
      COUNT(*) OVER () AS total_count,
      SUM(value) OVER () AS total_sum
    FROM self
    ORDER BY id
    """
    assert_sql_matches(
        df_test,
        query=query,
        compare_with="sqlite",
        expected={
            "id": [1, 2, 3, 4, 5, 6, 7],
            "total_count": [7, 7, 7, 7, 7, 7, 7],
            "total_sum": [175, 175, 175, 175, 175, 175, 175],
        },
    )


def test_window_function_order_by_asc_desc(df_test: pl.DataFrame) -> None:
    query = """
    SELECT
      id,
      value,
      SUM(value) OVER (ORDER BY value ASC) AS sum_asc,
      SUM(value) OVER (ORDER BY value DESC) AS sum_desc,
      ROW_NUMBER() OVER (ORDER BY value DESC) AS row_num_desc
    FROM self
    ORDER BY id
    """
    assert_sql_matches(
        df_test,
        query=query,
        compare_with="sqlite",
        expected={
            "id": [1, 2, 3, 4, 5, 6, 7],
            "value": [20, 10, 30, 15, 40, 25, 35],
            "sum_asc": [45, 10, 100, 25, 175, 70, 135],
            "sum_desc": [150, 175, 105, 165, 40, 130, 75],
            "row_num_desc": [5, 7, 3, 6, 1, 4, 2],
        },
    )


def test_window_function_misc_aggregations(df_test: pl.DataFrame) -> None:
    df = df_test.filter(pl.col("id").is_in([1, 3, 4, 5, 7]))
    query = """
    SELECT
      category,
      value,
      COUNT(*) OVER (PARTITION BY category) AS cat_count,
      SUM(value) OVER (PARTITION BY category) AS cat_sum,
      AVG(value) OVER (PARTITION BY category) AS cat_avg,
      COUNT(*) OVER () AS total_count
    FROM self
    ORDER BY category, value
    """
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={
            "category": ["A", "A", "B", "B", "C"],
            "value": [20, 30, 15, 40, 35],
            "cat_count": [2, 2, 2, 2, 1],
            "cat_sum": [50, 50, 55, 55, 35],
            "cat_avg": [25.0, 25.0, 27.5, 27.5, 35.0],
            "total_count": [5, 5, 5, 5, 5],
        },
    )
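

# Sketch: assuming `SUM(value) OVER (PARTITION BY category)` maps onto the
# native expression `pl.col("value").sum().over("category")`, both paths
# should produce the same frame; the test name here is illustrative.
def test_over_partition_matches_native_expr(df_test: pl.DataFrame) -> None:
    out_sql = df_test.sql(
        "SELECT category, SUM(value) OVER (PARTITION BY category) AS cat_sum FROM self"
    )
    out_native = df_test.select(
        "category",
        cat_sum=pl.col("value").sum().over("category"),
    )
    assert_frame_equal(out_sql, out_native)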


def test_window_function_partition_by_multi() -> None:
    df = pl.DataFrame(
        {
            "region": ["North", "North", "North", "South", "South", "South"],
            "category": ["A", "A", "B", "A", "B", "B"],
            "value": [10, 20, 15, 30, 25, 35],
        }
    )
    query = """
    SELECT
      region,
      category,
      value,
      COUNT(*) OVER (PARTITION BY region, category) AS group_count,
      SUM(value) OVER (PARTITION BY region, category) AS group_sum
    FROM self
    ORDER BY region, category, value
    """
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={
            "region": ["North", "North", "North", "South", "South", "South"],
            "category": ["A", "A", "B", "A", "B", "B"],
            "value": [10, 20, 15, 30, 25, 35],
            "group_count": [2, 2, 1, 1, 2, 2],
            "group_sum": [30, 30, 15, 30, 60, 60],
        },
    )


def test_window_function_order_by_multi() -> None:
    df = pl.DataFrame(
        {
            "category": ["A", "A", "A", "B", "B"],
            "subcategory": ["X", "Y", "X", "Y", "X"],
            "value": [10, 20, 15, 30, 25],
        }
    )
    # Note: Polars uses ROWS semantics, not RANGE semantics; we make that explicit
    # in the query below so we can compare the result with SQLite, as relational
    # databases usually default to RANGE semantics if not given an explicit frame
    # spec:
    #
    # RANGE >> gives peer groups the same value: (A,X) → [25, 25, ...]
    # ROWS  >> gives each row its own cumulative: (A,X) → [10, 25, ...]
    query = """
    SELECT
      category,
      subcategory,
      value,
      SUM(value) OVER (
        ORDER BY category ASC, subcategory ASC
        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
      ) AS sum_asc
    FROM self
    ORDER BY category, subcategory, value
    """
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={
            "category": ["A", "A", "A", "B", "B"],
            "subcategory": ["X", "X", "Y", "X", "Y"],
            "value": [10, 15, 20, 25, 30],
            "sum_asc": [10, 25, 45, 70, 100],
        },
    )

    query = """
    SELECT
      category,
      subcategory,
      value,
      SUM(value) OVER (
        ORDER BY category DESC, subcategory DESC
        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
      ) AS sum_desc
    FROM self
    ORDER BY category DESC, subcategory DESC, value
    """
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={
            "category": ["B", "B", "A", "A", "A"],
            "subcategory": ["Y", "X", "Y", "X", "X"],
            "value": [30, 25, 20, 10, 15],
            "sum_desc": [30, 55, 75, 85, 100],
        },
    )


def test_window_function_with_nulls() -> None:
    df = pl.DataFrame(
        {
            "category": ["A", "A", None, "B", "B"],
            "value": [10, None, 15, 30, 25],
        }
    )
    # COUNT with PARTITION BY (where NULL is in the partition)
    query = """
    SELECT
      category,
      value,
      COUNT(*) OVER (PARTITION BY category) AS cat_count,
      COUNT(value) OVER (PARTITION BY category) AS value_count,
      COUNT(category) OVER () AS cat_count_global
    FROM self
    ORDER BY category NULLS LAST, value NULLS FIRST
    """
    assert_sql_matches(
        df,
        query=query,
        check_dtypes=False,
        compare_with="sqlite",
        expected={
            "category": ["A", "A", "B", "B", None],
            "value": [None, 10, 25, 30, 15],
            "cat_count": [2, 2, 2, 2, 1],
            "value_count": [1, 1, 2, 2, 1],
            "cat_count_global": [4, 4, 4, 4, 4],
        },
    )
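

# Sketch: COUNT(*) counts every row in the partition while COUNT(value)
# skips nulls; assuming `pl.len()` and `Expr.count` are the native
# counterparts, a minimal frame makes the distinction visible.
def test_count_star_vs_count_value_native() -> None:
    df = pl.DataFrame({"category": ["A", "A", "B"], "value": [1, None, 2]})
    out = df.select(
        cat_count=pl.len().over("category"),
        value_count=pl.col("value").count().over("category"),
    )
    # group "A" has two rows but only one non-null value
    assert out.rows() == [(2, 1), (2, 1), (1, 1)]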


def test_window_function_min_max(df_test: pl.DataFrame) -> None:
    df = df_test.filter(pl.col("id").is_in([1, 3, 4, 5, 7]))
    query = """
    SELECT
      category,
      value,
      MIN(value) OVER (PARTITION BY category) AS cat_min,
      MAX(value) OVER (PARTITION BY category) AS cat_max,
      MIN(value) OVER () AS global_min,
      MAX(value) OVER () AS global_max
    FROM self
    ORDER BY category, value
    """
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={
            "category": ["A", "A", "B", "B", "C"],
            "value": [20, 30, 15, 40, 35],
            "cat_min": [20, 20, 15, 15, 35],
            "cat_max": [30, 30, 40, 40, 35],
            "global_min": [15, 15, 15, 15, 15],
            "global_max": [40, 40, 40, 40, 40],
        },
    )


def test_window_function_first_last() -> None:
    df = pl.DataFrame(
        {
            "idx": [6, 5, 4, 3, 2, 1, 0],
            "category": ["A", "A", "A", "A", "B", "B", "C"],
            "value": [10, 20, 15, 30, None, 25, 5],
        }
    )
    for first, last, expected_first_last in (
        (
            "FIRST_VALUE(value) OVER (PARTITION BY category ORDER BY idx ASC) AS first_val",
            "LAST_VALUE(value) OVER (PARTITION BY category ORDER BY idx DESC) AS last_val",
            {
                "first_val": [30, 30, 30, 30, 25, 25, 5],
                "last_val": [10, 15, 20, 30, 25, None, 5],
            },
        ),
        (
            "FIRST_VALUE(value) OVER (PARTITION BY category ORDER BY idx DESC) AS first_val",
            "LAST_VALUE(value) OVER (PARTITION BY category ORDER BY idx ASC) AS last_val",
            {
                "first_val": [10, 10, 10, 10, None, None, 5],
                "last_val": [10, 15, 20, 30, 25, None, 5],
            },
        ),
    ):
        query = f"""
        SELECT category, value, {first}, {last}
        FROM self ORDER BY category, value
        """
        expected = pl.DataFrame(
            {
                "category": ["A", "A", "A", "A", "B", "B", "C"],
                "value": [10, 15, 20, 30, 25, None, 5],
                **expected_first_last,
            }
        )
        assert_frame_equal(df.sql(query), expected)
        assert_sql_matches(df, query=query, compare_with="duckdb", expected=expected)


def test_window_function_over_clause_misc() -> None:
    df = pl.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "category": ["A", "A", "B", "B"],
            "value": [10, 20, 30, 40],
        }
    )

    # OVER with empty spec
    query = "SELECT id, COUNT(*) OVER () AS cnt FROM self ORDER BY id"
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={"id": [1, 2, 3, 4], "cnt": [4, 4, 4, 4]},
    )

    # OVER with only PARTITION BY
    query = """
    SELECT id, category, COUNT(*) OVER (PARTITION BY category) AS count
    FROM self ORDER BY id
    """
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={
            "id": [1, 2, 3, 4],
            "category": ["A", "A", "B", "B"],
            "count": [2, 2, 2, 2],
        },
    )

    # OVER with only ORDER BY
    query = """
    SELECT id, value, SUM(value) OVER (ORDER BY value) AS sum_val
    FROM self ORDER BY id
    """
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={
            "id": [1, 2, 3, 4],
            "value": [10, 20, 30, 40],
            "sum_val": [10, 30, 60, 100],
        },
    )

    # OVER with both PARTITION BY and ORDER BY
    query = """
    SELECT
      id,
      category,
      value,
      COUNT(*) OVER (PARTITION BY category ORDER BY value) AS cnt
    FROM self ORDER BY id
    """
    assert_sql_matches(
        df,
        query=query,
        compare_with="sqlite",
        expected={
            "id": [1, 2, 3, 4],
            "category": ["A", "A", "B", "B"],
            "value": [10, 20, 30, 40],
            "cnt": [1, 2, 1, 2],
        },
    )
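

# Sketch: assuming `FIRST_VALUE(value) OVER (PARTITION BY category ORDER BY idx)`
# corresponds to sorting within each window before taking the first element,
# the native spelling would be `pl.col("value").sort_by("idx").first().over(...)`;
# this checks only the native expression, on illustrative data.
def test_first_value_native_equivalent() -> None:
    df = pl.DataFrame(
        {
            "idx": [2, 1, 4, 3],
            "category": ["A", "A", "B", "B"],
            "value": [10, 20, 30, 40],
        }
    )
    out = df.select(
        first_val=pl.col("value").sort_by("idx").first().over("category")
    )
    # per category, the value at the smallest idx is broadcast to all rows
    assert out.to_series().to_list() == [20, 20, 40, 40]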


def test_window_named_window(df_test: pl.DataFrame) -> None:
    # One named window, applied multiple times
    query = """
    SELECT
      category,
      value,
      SUM(value) OVER w AS cumsum,
      MIN(value) OVER w AS cummin,
      MAX(value) OVER w AS cummax
    FROM self
    WINDOW w AS (PARTITION BY category ORDER BY value)
    ORDER BY category, value
    """
    assert_sql_matches(
        df_test,
        query=query,
        compare_with="sqlite",
        expected=pl.DataFrame(
            {
                "category": ["A", "A", "A", "B", "B", "B", "C"],
                "value": [10, 20, 30, 15, 25, 40, 35],
                "cumsum": [10, 30, 60, 15, 40, 80, 35],
                "cummin": [10, 10, 10, 15, 15, 15, 35],
                "cummax": [10, 20, 30, 15, 25, 40, 35],
            }
        ),
    )


def test_window_multiple_named_windows(df_test: pl.DataFrame) -> None:
    # Multiple named windows with different properties
    query = """
    SELECT
      category,
      value,
      AVG(value) OVER w1 AS category_avg,
      SUM(value) OVER w2 AS running_sum,
      COUNT(*) OVER w3 AS total_count
    FROM self
    WINDOW
      w1 AS (PARTITION BY category),
      w2 AS (ORDER BY value),
      w3 AS ()
    ORDER BY category, value
    """
    assert_sql_matches(
        df_test,
        query=query,
        compare_with="sqlite",
        expected=pl.DataFrame(
            {
                "category": ["A", "A", "A", "B", "B", "B", "C"],
                "value": [10, 20, 30, 15, 25, 40, 35],
                "category_avg": [
                    20.0,
                    20.0,
                    20.0,
                    26.666667,
                    26.666667,
                    26.666667,
                    35.0,
                ],
                "running_sum": [10, 45, 100, 25, 70, 175, 135],
                "total_count": [7, 7, 7, 7, 7, 7, 7],
            }
        ),
    )


def test_window_frame_validation() -> None:
    df = pl.DataFrame({"lbl": ["aa", "cc", "bb"], "value": [50, 75, -100]})

    # Omitted window frame => implicit ROWS semantics
    # (for Polars; for databases it usually implies RANGE semantics)
    for query in (
        """
        SELECT lbl, SUM(value) OVER (ORDER BY lbl) AS sum_value
        FROM self ORDER BY lbl ASC
        """,
        """
        SELECT lbl, SUM(value) OVER (
          ORDER BY lbl
          ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
        ) AS sum_value
        FROM self ORDER BY lbl ASC
        """,
    ):
        assert df.sql(query).rows() == [("aa", 50), ("bb", -50), ("cc", 25)]
        assert_sql_matches(df, query=query, compare_with="sqlite")

    # Rejected: RANGE frame (peer group semantics not supported)
    query = """
    SELECT lbl, SUM(value) OVER (
      ORDER BY lbl
      RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    ) AS sum_value
    FROM self
    """
    with pytest.raises(
        SQLInterfaceError,
        match="RANGE-based window frames are not supported",
    ):
        df.sql(query)

    # Rejected: GROUPS frame
    query = """
    SELECT lbl, SUM(value) OVER (
      ORDER BY lbl
      GROUPS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    ) AS sum_value
    FROM self
    """
    with pytest.raises(
        SQLInterfaceError,
        match="GROUPS-based window frames are not supported",
    ):
        df.sql(query)

    # Rejected: ROWS with incompatible bounds
    query = """
    SELECT lbl, SUM(value) OVER (
      ORDER BY lbl
      ROWS BETWEEN 1 PRECEDING AND CURRENT ROW
    ) AS sum_value
    FROM self
    """
    with pytest.raises(
        SQLInterfaceError,
        match=(
            "only 'ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW' is currently "
            "supported; found 'ROWS BETWEEN 1 PRECEDING AND CURRENT ROW'"
        ),
    ):
        df.sql(query)
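

# Sketch: with the implicit ROWS frame, `SUM(value) OVER (ORDER BY lbl)`
# behaves as a plain cumulative sum over the sorted rows; assuming
# `Expr.cum_sum` as the native counterpart, the two results should line up
# on the same data used in test_window_frame_validation.
def test_rows_frame_matches_cum_sum() -> None:
    df = pl.DataFrame({"lbl": ["aa", "cc", "bb"], "value": [50, 75, -100]})
    out_sql = df.sql(
        """
        SELECT lbl, SUM(value) OVER (ORDER BY lbl) AS sum_value
        FROM self ORDER BY lbl
        """
    )
    out_native = df.sort("lbl").select(
        "lbl", sum_value=pl.col("value").cum_sum()
    )
    assert_frame_equal(out_sql, out_native)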