Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_qualify.py
7884 views
1
from __future__ import annotations
2
3
import pytest
4
5
import polars as pl
6
from polars.exceptions import SQLSyntaxError
7
from tests.unit.sql import assert_sql_matches
8
9
10
@pytest.fixture
def df_test() -> pl.DataFrame:
    """Return a six-row frame: ids 1-6, three rows each in categories A and B."""
    columns = {
        "id": [1, 2, 3, 4, 5, 6],
        "category": ["A", "A", "A", "B", "B", "B"],
        "value": [100, 200, 150, 300, 250, 400],
    }
    return pl.DataFrame(columns)
19
20
21
@pytest.mark.parametrize(
    "qualify_clause",
    [
        pytest.param(
            "value > AVG(value) OVER (PARTITION BY category)",
            id="above_avg",
        ),
        pytest.param(
            "value = MAX(value) OVER (PARTITION BY category)",
            id="equals_max",
        ),
        pytest.param(
            "value > AVG(value) OVER (PARTITION BY category) AND value < 500",
            id="compound_expr",
        ),
    ],
)
def test_qualify_constraints(df_test: pl.DataFrame, qualify_clause: str) -> None:
    """Each parametrized predicate must keep exactly ids 2 and 6."""
    # One surviving row per category: value 200 in "A" and value 400 in "B".
    surviving_rows = {
        "id": [2, 6],
        "category": ["A", "B"],
        "value": [200, 400],
    }
    assert_sql_matches(
        {"df": df_test},
        query=f"""
            SELECT id, category, value
            FROM df
            QUALIFY {qualify_clause}
            ORDER BY category, value
        """,
        compare_with="duckdb",
        expected=surviving_rows,
    )
54
55
56
def test_qualify_distinct() -> None:
    """SELECT DISTINCT combined with QUALIFY yields one row per category."""
    # Every category holds two rows with the same value, so both rows equal the
    # per-category maximum and only DISTINCT reduces them to a single row.
    duplicated = {
        "id": [1, 2, 3, 4, 5, 6],
        "category": ["A", "A", "B", "B", "C", "C"],
        "value": [100, 100, 200, 200, 300, 300],
    }
    df = pl.DataFrame(duplicated)

    one_row_per_category = {
        "category": ["A", "B", "C"],
        "value": [100, 200, 300],
    }
    assert_sql_matches(
        {"df": df},
        query="""
            SELECT DISTINCT category, value
            FROM df
            QUALIFY value = MAX(value) OVER (PARTITION BY category)
            ORDER BY category
        """,
        compare_with="duckdb",
        expected=one_row_per_category,
    )
78
79
80
@pytest.mark.parametrize(
    "qualify_clause",
    [
        pytest.param(
            "400 < SUM(value) OVER (PARTITION BY category)",
            id="sum_window",
        ),
        pytest.param(
            "COUNT(*) OVER (PARTITION BY category) = 3",
            id="count_window",
        ),
    ],
)
def test_qualify_matches_all_rows(df_test: pl.DataFrame, qualify_clause: str) -> None:
    """Predicates that hold for every row must return all six rows."""
    # All six input rows, listed in descending id order to match the ORDER BY.
    every_row_reversed = {
        "id": [6, 5, 4, 3, 2, 1],
        "category": ["B", "B", "B", "A", "A", "A"],
        "value": [400, 250, 300, 150, 200, 100],
    }
    assert_sql_matches(
        {"df": df_test},
        query=f"""
            SELECT id, category, value
            FROM df
            QUALIFY {qualify_clause}
            ORDER BY id DESC
        """,
        compare_with="duckdb",
        expected=every_row_reversed,
    )
109
110
111
def test_qualify_multiple_clauses(df_test: pl.DataFrame) -> None:
    """QUALIFY predicates joined with AND, then with OR, filter as expected."""
    # AND: a plain column comparison combined with a windowed SUM.
    and_result = {
        "id": [4, 6],
        "category": ["B", "B"],
        "value": [300, 400],
    }
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT id, category, value
            FROM df
            QUALIFY
                value >= 300
                AND SUM(value) OVER (PARTITION BY category) > 500
            ORDER BY value
        """,
        compare_with="duckdb",
        expected=and_result,
    )

    # OR: keep the rows holding either extreme (max or min) of each category.
    or_result = {
        "id": [1, 2, 5, 6],
        "category": ["A", "A", "B", "B"],
        "value": [100, 200, 250, 400],
    }
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT id, category, value
            FROM df
            QUALIFY
                value = MAX(value) OVER (PARTITION BY category)
                OR value = MIN(value) OVER (PARTITION BY category)
            ORDER BY id
        """,
        compare_with="duckdb",
        expected=or_result,
    )
146
147
148
@pytest.mark.parametrize(
    "qualify_clause",
    [
        pytest.param(
            "value > MAX(value) OVER (PARTITION BY category)",
            id="greater_than_max",
        ),
        pytest.param(
            "value < MIN(value) OVER (PARTITION BY category)",
            id="less_than_min",
        ),
    ],
)
def test_qualify_returns_no_rows(df_test: pl.DataFrame, qualify_clause: str) -> None:
    """Predicates no row can satisfy must produce an empty three-column result."""
    # No value is above its partition's maximum or below its minimum.
    empty_result = {"id": [], "category": [], "value": []}
    assert_sql_matches(
        {"df": df_test},
        query=f"""
            SELECT id, category, value
            FROM df QUALIFY {qualify_clause}
        """,
        compare_with="duckdb",
        expected=empty_result,
    )
171
172
173
def test_qualify_using_select_alias(df_test: pl.DataFrame) -> None:
    """QUALIFY may refer to a window expression through its SELECT alias."""
    # The aliased window column is part of the projected output as well.
    rows_at_category_max = {
        "id": [2, 6],
        "category": ["A", "B"],
        "value": [200, 400],
        "max_value": [200, 400],
    }
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT
                id,
                category,
                value,
                MAX(value) OVER (PARTITION BY category) as max_value
            FROM df
            QUALIFY value = max_value
            ORDER BY category
        """,
        compare_with="duckdb",
        expected=rows_at_category_max,
    )
194
195
196
@pytest.mark.parametrize(
    "qualify_clause",
    [
        pytest.param(
            "value > avg_value AND COUNT(*) OVER (PARTITION BY category) = 3",
            id="mixed_alias_and_explicit",
        ),
        pytest.param(
            "value > AVG(value) OVER (PARTITION BY category)",
            id="window_in_select",
        ),
    ],
)
def test_qualify_miscellaneous(df_test: pl.DataFrame, qualify_clause: str) -> None:
    """Both predicate spellings must keep the rows above their category average.

    One predicate mixes the ``avg_value`` alias with an explicit window call;
    the other repeats the window expression that the SELECT list already has.
    """
    above_average_rows = {
        "id": [2, 6],
        "category": ["A", "B"],
        "value": [200, 400],
        "avg_value": [150.0, 316.6666666666667],
    }
    assert_sql_matches(
        {"df": df_test},
        query=f"""
            SELECT
                id,
                category,
                value,
                AVG(value) OVER (PARTITION BY category) as avg_value
            FROM df
            QUALIFY {qualify_clause}
            ORDER BY category
        """,
        compare_with="duckdb",
        expected=above_average_rows,
    )
230
231
232
def test_qualify_with_internal_cumulative_sum() -> None:
    """An ordered SUM window in QUALIFY is expected to keep ids 1, 2 and 3."""
    # Rows are deliberately not in id order in the input frame.
    unordered = {
        "id": [1, 3, 4, 2, 5],
        "value": [10, 30, 40, 20, 50],
    }
    df = pl.DataFrame(unordered)

    # Taken in id order, the running totals are 10, 30, 60, 100, 150.
    within_limit = {
        "id": [1, 2, 3],
        "value": [10, 20, 30],
    }
    assert_sql_matches(
        {"df": df},
        query="""
            SELECT id, value
            FROM df
            QUALIFY SUM(value) OVER (ORDER BY id) <= 60
            ORDER BY id
        """,
        compare_with="duckdb",
        expected=within_limit,
    )
253
254
255
def test_qualify_with_alias_and_comparison(df_test: pl.DataFrame) -> None:
    """Comparing an aliased window SUM to a constant keeps only category B rows."""
    # ids 4-6 are the category "B" rows; their values add up to 950.
    rows_over_threshold = {
        "id": [6, 5, 4],
        "total": [950, 950, 950],
    }
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT id, SUM(value) OVER (PARTITION BY category) as total
            FROM df QUALIFY total > 500
            ORDER BY id DESC
        """,
        compare_with="duckdb",
        expected=rows_over_threshold,
    )
269
270
271
def test_qualify_with_where_clause(df_test: pl.DataFrame) -> None:
    """A query with both WHERE and QUALIFY is expected to return ids 5 and 4."""
    # Only ids 4, 5 and 6 have value > 200; of those, id 6 holds the maximum.
    non_max_rows = {
        "id": [5, 4],
        "category": ["B", "B"],
        "value": [250, 300],
    }
    assert_sql_matches(
        {"df": df_test},
        query="""
            SELECT id, category, value
            FROM df WHERE value > 200
            QUALIFY value != MAX(value) OVER (PARTITION BY category)
            ORDER BY value
        """,
        compare_with="duckdb",
        expected=non_max_rows,
    )
287
288
289
def test_qualify_expected_errors(df_test: pl.DataFrame) -> None:
    """A QUALIFY predicate with no window function must raise SQLSyntaxError."""
    ctx = pl.SQLContext(df=df_test, eager=True)

    # The predicate is a plain column comparison with no window function in it.
    windowless_query = "SELECT id, category, value FROM df QUALIFY value > 200"
    expected_message = "QUALIFY clause must reference window functions"

    with pytest.raises(SQLSyntaxError, match=expected_message):
        ctx.execute(windowless_query)
296
297