Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_window_functions.py
7884 views
1
from __future__ import annotations
2
3
import pytest
4
5
import polars as pl
6
from polars.exceptions import SQLInterfaceError
7
from polars.testing import assert_frame_equal
8
from tests.unit.sql import assert_sql_matches
9
10
11
@pytest.fixture
def df_test() -> pl.DataFrame:
    """Seven-row fixture spanning three categories ("A", "B", "C")."""
    data = {
        "id": [1, 2, 3, 4, 5, 6, 7],
        "category": ["A", "A", "A", "B", "B", "B", "C"],
        "value": [20, 10, 30, 15, 40, 25, 35],
    }
    return pl.DataFrame(data)
20
21
22
def test_over_with_order_by(df_test: pl.DataFrame) -> None:
    """OVER with only ORDER BY: a cumulative sum across the whole frame."""
    query = """
        SELECT
          id,
          value,
          SUM(value) OVER (ORDER BY value) AS sum_by_value
        FROM self
        ORDER BY id
    """
    expected = {
        "id": [1, 2, 3, 4, 5, 6, 7],
        "value": [20, 10, 30, 15, 40, 25, 35],
        "sum_by_value": [45, 10, 100, 25, 175, 70, 135],
    }
    assert_sql_matches(df_test, query=query, compare_with="sqlite", expected=expected)
41
42
43
def test_over_with_partition_by(df_test: pl.DataFrame) -> None:
    """PARTITION BY windows, including a named window (w0) shared by two aggregates."""
    frame = df_test.remove(pl.col("id") == 6)
    query = """
        SELECT
          category,
          value,
          ROW_NUMBER() OVER (PARTITION BY category ORDER BY value) AS row_num,
          COUNT(*) OVER w0 AS cat_count,
          SUM(value) OVER w0 AS cat_sum
        FROM self
        WINDOW w0 AS (PARTITION BY category)
        ORDER BY category, value
    """
    expected = {
        "category": ["A", "A", "A", "B", "B", "C"],
        "value": [10, 20, 30, 15, 40, 35],
        "row_num": [1, 2, 3, 1, 2, 1],
        "cat_count": [3, 3, 3, 2, 2, 1],
        "cat_sum": [60, 60, 60, 55, 55, 35],
    }
    assert_sql_matches(frame, query=query, compare_with="sqlite", expected=expected)
68
69
70
def test_over_with_cumulative_window_funcs(df_test: pl.DataFrame) -> None:
    """SUM/MIN/MAX over (PARTITION BY ... ORDER BY ...) behave cumulatively per group."""
    query = """
        SELECT
          category,
          value,
          SUM(value) OVER (PARTITION BY category ORDER BY value) AS cumsum,
          MIN(value) OVER (PARTITION BY category ORDER BY value) AS cummin,
          MAX(value) OVER (PARTITION BY category ORDER BY value) AS cummax
        FROM self
        ORDER BY category, value
    """
    expected = {
        "category": ["A", "A", "A", "B", "B", "B", "C"],
        "value": [10, 20, 30, 15, 25, 40, 35],
        "cumsum": [10, 30, 60, 15, 40, 80, 35],
        "cummin": [10, 10, 10, 15, 15, 15, 35],
        "cummax": [10, 20, 30, 15, 25, 40, 35],
    }
    assert_sql_matches(df_test, query=query, compare_with="sqlite", expected=expected)
93
94
95
def test_window_function_over_empty(df_test: pl.DataFrame) -> None:
    """OVER () with no spec aggregates across the entire frame for every row."""
    query = """
        SELECT
          id,
          COUNT(*) OVER () AS total_count,
          SUM(value) OVER () AS total_sum
        FROM self
        ORDER BY id
    """
    expected = {
        "id": [1, 2, 3, 4, 5, 6, 7],
        "total_count": [7, 7, 7, 7, 7, 7, 7],
        "total_sum": [175, 175, 175, 175, 175, 175, 175],
    }
    assert_sql_matches(df_test, query=query, compare_with="sqlite", expected=expected)
114
115
116
def test_window_function_order_by_asc_desc(df_test: pl.DataFrame) -> None:
    """Explicit ASC/DESC in the window ORDER BY flips the cumulative direction."""
    query = """
        SELECT
          id,
          value,
          SUM(value) OVER (ORDER BY value ASC) AS sum_asc,
          SUM(value) OVER (ORDER BY value DESC) AS sum_desc,
          ROW_NUMBER() OVER (ORDER BY value DESC) AS row_num_desc
        FROM self
        ORDER BY id
    """
    expected = {
        "id": [1, 2, 3, 4, 5, 6, 7],
        "value": [20, 10, 30, 15, 40, 25, 35],
        "sum_asc": [45, 10, 100, 25, 175, 70, 135],
        "sum_desc": [150, 175, 105, 165, 40, 130, 75],
        "row_num_desc": [5, 7, 3, 6, 1, 4, 2],
    }
    assert_sql_matches(df_test, query=query, compare_with="sqlite", expected=expected)
139
140
141
def test_window_function_misc_aggregations(df_test: pl.DataFrame) -> None:
    """Mix of COUNT/SUM/AVG partitioned windows alongside a global COUNT."""
    frame = df_test.filter(pl.col("id").is_in([1, 3, 4, 5, 7]))
    query = """
        SELECT
          category,
          value,
          COUNT(*) OVER (PARTITION BY category) AS cat_count,
          SUM(value) OVER (PARTITION BY category) AS cat_sum,
          AVG(value) OVER (PARTITION BY category) AS cat_avg,
          COUNT(*) OVER () AS total_count
        FROM self
        ORDER BY category, value
    """
    expected = {
        "category": ["A", "A", "B", "B", "C"],
        "value": [20, 30, 15, 40, 35],
        "cat_count": [2, 2, 2, 2, 1],
        "cat_sum": [50, 50, 55, 55, 35],
        "cat_avg": [25.0, 25.0, 27.5, 27.5, 35.0],
        "total_count": [5, 5, 5, 5, 5],
    }
    assert_sql_matches(frame, query=query, compare_with="sqlite", expected=expected)
167
168
169
def test_window_function_partition_by_multi() -> None:
    """PARTITION BY with two keys groups on the (region, category) pair."""
    frame = pl.DataFrame(
        {
            "region": ["North", "North", "North", "South", "South", "South"],
            "category": ["A", "A", "B", "A", "B", "B"],
            "value": [10, 20, 15, 30, 25, 35],
        }
    )
    query = """
        SELECT
          region,
          category,
          value,
          COUNT(*) OVER (PARTITION BY region, category) AS group_count,
          SUM(value) OVER (PARTITION BY region, category) AS group_sum
        FROM self
        ORDER BY region, category, value
    """
    expected = {
        "region": ["North", "North", "North", "South", "South", "South"],
        "category": ["A", "A", "B", "A", "B", "B"],
        "value": [10, 20, 15, 30, 25, 35],
        "group_count": [2, 2, 1, 1, 2, 2],
        "group_sum": [30, 30, 15, 30, 60, 60],
    }
    assert_sql_matches(frame, query=query, compare_with="sqlite", expected=expected)
199
200
201
def test_window_function_order_by_multi() -> None:
    """Window ORDER BY over multiple keys, checked in both sort directions."""
    frame = pl.DataFrame(
        {
            "category": ["A", "A", "A", "B", "B"],
            "subcategory": ["X", "Y", "X", "Y", "X"],
            "value": [10, 20, 15, 30, 25],
        }
    )
    # Note: Polars uses ROWS semantics, not RANGE semantics; we make that explicit in
    # the query below so we can compare the result with SQLite as relational databases
    # usually default to RANGE semantics if not given an explicit frame spec:
    #
    # RANGE >> gives peer groups the same value: (A,X) → [25, 25, ...]
    # ROWS >> gives each row its own cumulative: (A,X) → [10, 25, ...]
    asc_query = """
        SELECT
          category,
          subcategory,
          value,
          SUM(value) OVER (
            ORDER BY category ASC, subcategory ASC
            ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
          ) AS sum_asc
        FROM self
        ORDER BY category, subcategory, value
    """
    assert_sql_matches(
        frame,
        query=asc_query,
        compare_with="sqlite",
        expected={
            "category": ["A", "A", "A", "B", "B"],
            "subcategory": ["X", "X", "Y", "X", "Y"],
            "value": [10, 15, 20, 25, 30],
            "sum_asc": [10, 25, 45, 70, 100],
        },
    )

    desc_query = """
        SELECT
          category,
          subcategory,
          value,
          SUM(value) OVER (
            ORDER BY category DESC, subcategory DESC
            ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
          ) AS sum_desc
        FROM self
        ORDER BY category DESC, subcategory DESC, value
    """
    assert_sql_matches(
        frame,
        query=desc_query,
        compare_with="sqlite",
        expected={
            "category": ["B", "B", "A", "A", "A"],
            "subcategory": ["Y", "X", "Y", "X", "X"],
            "value": [30, 25, 20, 10, 15],
            "sum_desc": [30, 55, 75, 85, 100],
        },
    )
262
263
264
def test_window_function_with_nulls() -> None:
    """COUNT(*) vs COUNT(col) over partitions that contain NULL keys/values."""
    frame = pl.DataFrame(
        {
            "category": ["A", "A", None, "B", "B"],
            "value": [10, None, 15, 30, 25],
        }
    )
    # COUNT with PARTITION BY (where NULL is in the partition)
    query = """
        SELECT
          category,
          value,
          COUNT(*) OVER (PARTITION BY category) AS cat_count,
          COUNT(value) OVER (PARTITION BY category) AS value_count,
          COUNT(category) OVER () AS cat_count_global
        FROM self
        ORDER BY category NULLS LAST, value NULLS FIRST
    """
    expected = {
        "category": ["A", "A", "B", "B", None],
        "value": [None, 10, 25, 30, 15],
        "cat_count": [2, 2, 2, 2, 1],
        "value_count": [1, 1, 2, 2, 1],
        "cat_count_global": [4, 4, 4, 4, 4],
    }
    assert_sql_matches(
        frame,
        query=query,
        check_dtypes=False,
        compare_with="sqlite",
        expected=expected,
    )
295
296
297
def test_window_function_min_max(df_test: pl.DataFrame) -> None:
    """Partitioned and global MIN/MAX windows on a filtered frame."""
    frame = df_test.filter(pl.col("id").is_in([1, 3, 4, 5, 7]))
    query = """
        SELECT
          category,
          value,
          MIN(value) OVER (PARTITION BY category) AS cat_min,
          MAX(value) OVER (PARTITION BY category) AS cat_max,
          MIN(value) OVER () AS global_min,
          MAX(value) OVER () AS global_max
        FROM self
        ORDER BY category, value
    """
    expected = {
        "category": ["A", "A", "B", "B", "C"],
        "value": [20, 30, 15, 40, 35],
        "cat_min": [20, 20, 15, 15, 35],
        "cat_max": [30, 30, 40, 40, 35],
        "global_min": [15, 15, 15, 15, 15],
        "global_max": [40, 40, 40, 40, 40],
    }
    assert_sql_matches(frame, query=query, compare_with="sqlite", expected=expected)
323
324
325
def test_window_function_first_last() -> None:
    """FIRST_VALUE/LAST_VALUE with ASC/DESC ordering inside each partition.

    The "B" partition contains a NULL value, so the first/last picks depend on
    the ORDER BY direction of each window spec.
    """
    df = pl.DataFrame(
        {
            "idx": [6, 5, 4, 3, 2, 1, 0],
            "category": ["A", "A", "A", "A", "B", "B", "C"],
            "value": [10, 20, 15, 30, None, 25, 5],
        }
    )
    for first, last, expected_first_last in (
        (
            "FIRST_VALUE(value) OVER (PARTITION BY category ORDER BY idx ASC) AS first_val",
            "LAST_VALUE(value) OVER (PARTITION BY category ORDER BY idx DESC) AS last_val",
            {
                "first_val": [30, 30, 30, 30, 25, 25, 5],
                "last_val": [10, 15, 20, 30, 25, None, 5],
            },
        ),
        (
            "FIRST_VALUE(value) OVER (PARTITION BY category ORDER BY idx DESC) AS first_val",
            "LAST_VALUE(value) OVER (PARTITION BY category ORDER BY idx ASC) AS last_val",
            {
                "first_val": [10, 10, 10, 10, None, None, 5],
                "last_val": [10, 15, 20, 30, 25, None, 5],
            },
        ),
    ):
        # Fix: dropped the stray trailing comma that previously followed
        # "{last}" (immediately before FROM); it is not standard SQL and only
        # parsed due to lenient trailing-comma handling.
        query = f"""
            SELECT category, value, {first}, {last}
            FROM self ORDER BY category, value
        """
        expected = pl.DataFrame(
            {
                "category": ["A", "A", "A", "A", "B", "B", "C"],
                "value": [10, 15, 20, 30, 25, None, 5],
                **expected_first_last,
            }
        )
        assert_frame_equal(df.sql(query), expected)
        assert_sql_matches(df, query=query, compare_with="duckdb", expected=expected)
364
365
366
def test_window_function_over_clause_misc() -> None:
    """The four basic OVER spec shapes: empty, PARTITION BY only, ORDER BY only, both."""
    frame = pl.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "category": ["A", "A", "B", "B"],
            "value": [10, 20, 30, 40],
        }
    )

    # OVER with empty spec
    assert_sql_matches(
        frame,
        query="SELECT id, COUNT(*) OVER () AS cnt FROM self ORDER BY id",
        compare_with="sqlite",
        expected={"id": [1, 2, 3, 4], "cnt": [4, 4, 4, 4]},
    )

    # OVER with only PARTITION BY
    partition_query = """
        SELECT id, category, COUNT(*) OVER (PARTITION BY category) AS count
        FROM self ORDER BY id
    """
    assert_sql_matches(
        frame,
        query=partition_query,
        compare_with="sqlite",
        expected={
            "id": [1, 2, 3, 4],
            "category": ["A", "A", "B", "B"],
            "count": [2, 2, 2, 2],
        },
    )

    # OVER with only ORDER BY
    order_query = """
        SELECT id, value, SUM(value) OVER (ORDER BY value) AS sum_val
        FROM self ORDER BY id
    """
    assert_sql_matches(
        frame,
        query=order_query,
        compare_with="sqlite",
        expected={
            "id": [1, 2, 3, 4],
            "value": [10, 20, 30, 40],
            "sum_val": [10, 30, 60, 100],
        },
    )

    # OVER with both PARTITION BY and ORDER BY
    combined_query = """
        SELECT
          id,
          category,
          value,
          COUNT(*) OVER (PARTITION BY category ORDER BY value) AS cnt
        FROM self ORDER BY id
    """
    assert_sql_matches(
        frame,
        query=combined_query,
        compare_with="sqlite",
        expected={
            "id": [1, 2, 3, 4],
            "category": ["A", "A", "B", "B"],
            "value": [10, 20, 30, 40],
            "cnt": [1, 2, 1, 2],
        },
    )
436
437
438
def test_window_named_window(df_test: pl.DataFrame) -> None:
    """A single named WINDOW clause referenced by several aggregates."""
    # One named window, applied multiple times
    query = """
        SELECT
          category,
          value,
          SUM(value) OVER w AS cumsum,
          MIN(value) OVER w AS cummin,
          MAX(value) OVER w AS cummax
        FROM self
        WINDOW w AS (PARTITION BY category ORDER BY value)
        ORDER BY category, value
    """
    expected = pl.DataFrame(
        {
            "category": ["A", "A", "A", "B", "B", "B", "C"],
            "value": [10, 20, 30, 15, 25, 40, 35],
            "cumsum": [10, 30, 60, 15, 40, 80, 35],
            "cummin": [10, 10, 10, 15, 15, 15, 35],
            "cummax": [10, 20, 30, 15, 25, 40, 35],
        }
    )
    assert_sql_matches(df_test, query=query, compare_with="sqlite", expected=expected)
465
466
467
def test_window_multiple_named_windows(df_test: pl.DataFrame) -> None:
    """Several named windows with different specs used in one SELECT."""
    # Multiple named windows with different properties
    query = """
        SELECT
          category,
          value,
          AVG(value) OVER w1 AS category_avg,
          SUM(value) OVER w2 AS running_sum,
          COUNT(*) OVER w3 AS total_count
        FROM self
        WINDOW
          w1 AS (PARTITION BY category),
          w2 AS (ORDER BY value),
          w3 AS ()
        ORDER BY category, value
    """
    avg_b = 26.666667  # AVG of the "B" partition (15 + 25 + 40) / 3, rounded
    expected = pl.DataFrame(
        {
            "category": ["A", "A", "A", "B", "B", "B", "C"],
            "value": [10, 20, 30, 15, 25, 40, 35],
            "category_avg": [20.0, 20.0, 20.0, avg_b, avg_b, avg_b, 35.0],
            "running_sum": [10, 45, 100, 25, 70, 175, 135],
            "total_count": [7, 7, 7, 7, 7, 7, 7],
        }
    )
    assert_sql_matches(df_test, query=query, compare_with="sqlite", expected=expected)
505
506
507
def test_window_frame_validation() -> None:
    """Accepted vs rejected window frame specs in the SQL interface."""
    df = pl.DataFrame({"lbl": ["aa", "cc", "bb"], "value": [50, 75, -100]})

    # Omitted window frame => implicit ROWS semantics
    # (for Polars; for databases it usually implies RANGE semantics)
    accepted_queries = (
        """
        SELECT lbl, SUM(value) OVER (ORDER BY lbl) AS sum_value
        FROM self ORDER BY lbl ASC
        """,
        """
        SELECT lbl, SUM(value) OVER (
          ORDER BY lbl
          ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
        ) AS sum_value
        FROM self ORDER BY lbl ASC
        """,
    )
    for sql in accepted_queries:
        assert df.sql(sql).rows() == [("aa", 50), ("bb", -50), ("cc", 25)]
        assert_sql_matches(df, query=sql, compare_with="sqlite")

    # Rejected: RANGE frame (peer group semantics not supported)
    range_query = """
        SELECT lbl, SUM(value) OVER (
          ORDER BY lbl
          RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
        ) AS sum_value
        FROM self
    """
    with pytest.raises(
        SQLInterfaceError,
        match="RANGE-based window frames are not supported",
    ):
        df.sql(range_query)

    # Rejected: GROUPS frame
    groups_query = """
        SELECT lbl, SUM(value) OVER (
          ORDER BY lbl
          GROUPS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
        ) AS sum_value
        FROM self
    """
    with pytest.raises(
        SQLInterfaceError,
        match="GROUPS-based window frames are not supported",
    ):
        df.sql(groups_query)

    # Rejected: ROWS with incompatible bounds
    bounded_rows_query = """
        SELECT lbl, SUM(value) OVER (
          ORDER BY lbl
          ROWS BETWEEN 1 PRECEDING AND CURRENT ROW
        ) AS sum_value
        FROM self
    """
    with pytest.raises(
        SQLInterfaceError,
        match=(
            "only 'ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW' is currently "
            "supported; found 'ROWS BETWEEN 1 PRECEDING AND CURRENT ROW'"
        ),
    ):
        df.sql(bounded_rows_query)
572
573