Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_set_ops.py
8396 views
1
from __future__ import annotations
2
3
import pytest
4
5
import polars as pl
6
from polars.exceptions import SQLInterfaceError
7
from polars.testing import assert_frame_equal
8
from tests.unit.sql import assert_sql_matches
9
10
11
def test_except_intersect() -> None:
    """Check EXCEPT / INTERSECT, including the `TABLE` shorthand and null handling."""
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})

    res_e = pl.sql("SELECT x, y, z FROM df1 EXCEPT SELECT * FROM df2", eager=True)
    res_i = pl.sql("SELECT * FROM df1 INTERSECT SELECT x, y, z FROM df2", eager=True)

    # results are distinct: df1's duplicated (1, 4, 5) row appears only once
    assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # "TABLE df1" shorthand, and an INTERSECT against an explicitly-cast VALUES list
    res_e = pl.sql("SELECT * FROM df2 EXCEPT TABLE df1", eager=True)
    res_i = pl.sql(
        """
        SELECT * FROM df2
        INTERSECT
        SELECT x::int8, y::int8, z::int8
        FROM (VALUES (1,2,5),(9,3,5),(1,4,5),(1,4,5)) AS df1(x,y,z)
        """,
        eager=True,
    )
    assert sorted(res_e.rows()) == [(1, 2, 7), (9, None, 6)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # check the behaviour of nulls: null values compare as equal in set operations,
    # so the (9, null) row present in both tables is removed by EXCEPT
    with pl.SQLContext(
        tbl1=pl.DataFrame({"x": [2, 9, 1], "y": [2, None, 4]}),
        tbl2=pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4]}),
    ) as ctx:
        res = ctx.execute("SELECT * FROM tbl1 EXCEPT SELECT * FROM tbl2", eager=True)
        assert_frame_equal(pl.DataFrame({"x": [2], "y": [2]}), res)
41
42
43
def test_except_intersect_by_name() -> None:
    """EXCEPT / INTERSECT ``BY NAME`` align columns by name, not by position."""
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})
    # same columns as df1 but in a different order, plus an extra "w" column
    df2 = pl.DataFrame(
        {"y": [2, None, 4], "w": ["?", "!", "%"], "z": [7, 6, 5], "x": [1, 9, 1]}
    )

    except_query = "SELECT x, y, z FROM df1 EXCEPT BY NAME SELECT * FROM df2"
    intersect_query = "SELECT * FROM df1 INTERSECT BY NAME SELECT * FROM df2"
    res_e = pl.sql(except_query, eager=True)
    res_i = pl.sql(intersect_query, eager=True)

    assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # output column order follows the left-hand side (df1)
    for res in (res_e, res_i):
        assert res.columns == ["x", "y", "z"]
71
72
73
@pytest.mark.parametrize(
    ("op", "op_subtype"),
    [
        (set_op, subtype)
        for set_op in ("EXCEPT", "INTERSECT")
        for subtype in ("ALL", "ALL BY NAME")
    ],
)
def test_except_intersect_all_unsupported(op: str, op_subtype: str) -> None:
    """The ``ALL`` forms of EXCEPT / INTERSECT raise a SQLInterfaceError."""
    df1 = pl.DataFrame({"n": [1, 1, 1, 2, 2, 2, 3]})
    df2 = pl.DataFrame({"n": [1, 1, 2, 2]})

    set_op = f"{op} {op_subtype}"
    query = f"SELECT * FROM df1 {set_op} SELECT * FROM df2"
    with pytest.raises(SQLInterfaceError, match=f"'{set_op}' is not supported"):
        pl.sql(query, eager=True)
91
92
93
def test_update_statement_error() -> None:
    """An UPDATE statement (with a CTE and FROM clause) is reported as unsupported."""
    large_frame = pl.DataFrame(
        {
            "FQDN": ["c.ORG.na", "a.COM.na"],
            "NS1": ["ns1.c.org.na", "ns1.d.net.na"],
            "NS2": ["ns2.c.org.na", "ns2.d.net.na"],
            "NS3": ["ns3.c.org.na", "ns3.d.net.na"],
        }
    )
    small_frame = pl.DataFrame(
        {
            "FQDN": ["c.org.na"],
            "NS1": ["ns1.c.org.na|127.0.0.1"],
            "NS2": ["ns2.c.org.na|127.0.0.1"],
            "NS3": ["ns3.c.org.na|127.0.0.1"],
        }
    )

    # register both frames under the table names the statement refers to
    ctx = pl.SQLContext()
    for table_name, frame in (("large", large_frame), ("small", small_frame)):
        ctx.register(table_name, frame)

    update_stmt = """
        WITH u AS (
            SELECT
                small.FQDN,
                small.NS1,
                small.NS2,
                small.NS3
            FROM small
            INNER JOIN large ON small.FQDN = large.FQDN
        )
        UPDATE large
        SET
            FQDN = u.FQDN,
            NS1 = u.NS1,
            NS2 = u.NS2,
            NS3 = u.NS3
        FROM u
        WHERE large.FQDN = u.FQDN
    """
    with pytest.raises(
        SQLInterfaceError,
        match=r"'UPDATE large SET FQDN = .+ operation is currently unsupported",
    ):
        ctx.execute(update_stmt)
139
140
141
@pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"])
def test_except_intersect_union_errors(op: str) -> None:
    """Check the errors raised for unsupported ``ALL`` forms and mismatched widths."""
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})

    # "UNION ALL" is valid, so the "ALL" error is only checked for the other ops
    if op != "UNION":
        all_query = f"SELECT * FROM df1 {op} ALL SELECT * FROM df2"
        with pytest.raises(SQLInterfaceError, match=f"'{op} ALL' is not supported"):
            pl.sql(all_query, eager=False).collect()

    # both sides of any set operation must have the same number of columns
    width_query = f"SELECT x FROM df1 {op} SELECT x, y FROM df2"
    with pytest.raises(
        SQLInterfaceError,
        match=f"{op} requires equal number of columns in each table",
    ):
        pl.sql(width_query, eager=False).collect()
160
161
162
@pytest.mark.parametrize(
    ("cols1", "cols2", "union_subtype", "expected"),
    [
        (
            ["*"],
            ["*"],
            "",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["*"],
            ["frame2.*"],
            "ALL",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        (
            ["frame1.*"],
            ["c1", "c2"],
            "DISTINCT",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["*"],
            ["c2", "c1"],
            "ALL BY NAME",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        (
            ["c1", "c2"],
            ["c1 AS x1", "c2 AS x2"],
            "",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["c1", "c2"],
            ["c2", "c1"],
            "BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["c1", "c2"],
            ["c2", "c1"],
            "DISTINCT BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
    ],
)
def test_union(
    cols1: list[str],
    cols2: list[str],
    union_subtype: str,
    expected: list[tuple[int, str]],
) -> None:
    """Check UNION variants (distinct/ALL, positional/BY NAME) between two frames.

    Both frames contain the row ``(2, "yy")``. It appears once in the distinct
    results and twice under ``ALL``. The ``BY NAME`` cases select the right-hand
    columns in reversed order, and the result must still line up by name.
    """
    with pl.SQLContext(
        frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}),
        frame2=pl.DataFrame({"c1": [2, 3], "c2": ["yy", "xx"]}),
        eager=True,
    ) as ctx:
        query = f"""
            SELECT {", ".join(cols1)} FROM frame1
            UNION {union_subtype}
            SELECT {", ".join(cols2)} FROM frame2
        """
        assert sorted(ctx.execute(query).rows()) == expected
226
227
228
def test_union_nonmatching_colnames() -> None:
    """UNION of frames whose column names differ is resolved by position.

    SQL allows "UNION" (aka: polars `concat`) on column names that don't match;
    the result takes its column names from the first table.
    """
    frame_a = pl.DataFrame(
        data={"Value": [100, 200], "Tag": ["hello", "foo"]},
        schema_overrides={"Value": pl.Int16},
    )
    frame_b = pl.DataFrame(
        data={"Number": [300, 400], "String": ["world", "bar"]},
        schema_overrides={"Number": pl.Int32},
    )
    with pl.SQLContext(df1=frame_a, df2=frame_b, eager=True) as ctx:
        res = ctx.execute(
            query="""
                SELECT u.* FROM (
                    SELECT * FROM df1
                    UNION
                    SELECT * FROM df2
                ) u ORDER BY Value
            """
        )
        # the Int16 "Value" and Int32 "Number" columns combine into Int32
        expected_schema = {"Value": pl.Int32, "Tag": pl.String}
        expected_rows = [(100, "hello"), (200, "foo"), (300, "world"), (400, "bar")]
        assert res.schema == expected_schema
        assert res.rows() == expected_rows
261
262
263
def test_union_with_join_state_isolation() -> None:
    """Each UNION branch resolves its JOIN with isolated state.

    Table aliases from one branch must not leak into the other. Only the first
    branch finds a match (``a.k = c.k = 0``). The second (``b.k = 1``) matches
    nothing, so exactly one row is expected.
    """
    sql = """
        -- start CTEs
        WITH
          a AS (SELECT 0 AS k),
          b AS (SELECT 1 AS k),
          c AS (SELECT 0 AS k)
        -- end of CTEs
        SELECT a.k FROM a JOIN c ON a.k = c.k
        UNION ALL
        SELECT b.k FROM b JOIN c ON b.k = c.k
    """
    result = pl.sql(query=sql, eager=True)
    assert result.to_series().to_list() == [0]
281
282
283
def test_set_operations_order_by() -> None:
    """Check ORDER BY / LIMIT / FETCH applied to the results of set operations.

    Every case goes through ``assert_sql_matches``. ``compare_with`` names
    sqlite or duckdb, presumably as the reference engine for the comparison.
    """
    df1 = pl.DataFrame({"id": [1, 2, 3], "value": [100, 200, 300]})
    df2 = pl.DataFrame({"id": [4, 5, 6], "value": [400, 500, 600]})
    df3 = pl.DataFrame({"id": [2, 3, 4], "value": [200, 300, 400]})

    # overall ORDER BY applies to the combined UNION result
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM df1
            UNION ALL
            SELECT * FROM df2
            ORDER BY id DESC
        """,
        expected={
            "id": [6, 5, 4, 3, 2, 1],
            "value": [600, 500, 400, 300, 200, 100],
        },
        compare_with="sqlite",
    )

    # ORDER BY with LIMIT on the final result
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM df1
            UNION ALL
            SELECT * FROM df2
            ORDER BY value DESC
            LIMIT 3
        """,
        expected={"id": [6, 5, 4], "value": [600, 500, 400]},
        compare_with="sqlite",
    )

    # ORDER BY with FETCH on the final result
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM df1
            UNION ALL
            SELECT * FROM df2
            ORDER BY value DESC
            FETCH FIRST 3 ROWS ONLY
        """,
        expected={"id": [6, 5, 4], "value": [600, 500, 400]},
        compare_with="duckdb",
    )

    # nested ORDER BY + LIMIT in subqueries (top-N from each side), then an
    # overall ORDER BY on the combined result
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM (SELECT * FROM df1 ORDER BY value DESC LIMIT 2) AS top1
            UNION ALL
            SELECT * FROM (SELECT * FROM df2 ORDER BY value ASC LIMIT 2) AS top2
            ORDER BY id
        """,
        expected={"id": [2, 3, 4, 5], "value": [200, 300, 400, 500]},
        compare_with="sqlite",
    )

    # nested ORDER BY/LIMIT in subqueries, with an outer ORDER BY/LIMIT
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM (
                SELECT * FROM (SELECT * FROM df1 ORDER BY value DESC LIMIT 2) t1
                UNION ALL
                SELECT * FROM (SELECT * FROM df2 ORDER BY value ASC LIMIT 2) t2
            ) t3
            ORDER BY id
            LIMIT 3
        """,
        expected={"id": [2, 3, 4], "value": [200, 300, 400]},
        compare_with="sqlite",
    )

    # EXCEPT with ORDER BY
    assert_sql_matches(
        frames={"df1": df1, "df3": df3},
        query="""
            SELECT * FROM df1
            EXCEPT
            SELECT * FROM df3
            ORDER BY id
        """,
        expected={"id": [1], "value": [100]},
        compare_with="sqlite",
    )

    # INTERSECT with ORDER BY (df1 ∩ df3 = {(2, 200), (3, 300)})
    assert_sql_matches(
        frames={"df1": df1, "df3": df3},
        query="""
            SELECT * FROM df1
            INTERSECT
            SELECT * FROM df3
            ORDER BY id DESC
        """,
        expected={"id": [3, 2], "value": [300, 200]},
        compare_with="sqlite",
    )

    # UNION mixed with INTERSECT, then ORDER BY and FETCH. INTERSECT binds more
    # tightly than UNION, so this is df1 ∪ (df2 ∩ df3) = df1 ∪ {(4, 400)}
    # (left-to-right evaluation would instead give only ids 2, 3, 4)
    assert_sql_matches(
        frames={"df1": df1, "df2": df2, "df3": df3},
        query="""
            (
                SELECT * FROM df1
                UNION
                SELECT * FROM df2
                INTERSECT
                SELECT * FROM df3
            )
            ORDER BY id
            FETCH FIRST 4 ROWS ONLY
        """,
        expected={
            "id": [1, 2, 3, 4],
            "value": [100, 200, 300, 400],
        },
        compare_with="duckdb",
    )

    # chained UNION with overall ORDER BY, both without and with enclosing
    # parentheses around the chain
    for open_paren, close_paren, compare_with in (
        ("", "", "sqlite"),
        ("", "", "duckdb"),
        ("(", ")", "duckdb"),
    ):
        assert_sql_matches(
            frames={"df1": df1, "df2": df2, "df3": df3},
            query=f"""
                {open_paren}
                SELECT * FROM df1
                UNION
                SELECT * FROM df2
                UNION
                SELECT * FROM df3
                {close_paren}
                ORDER BY value
            """,
            expected={
                "id": [1, 2, 3, 4, 5, 6],
                "value": [100, 200, 300, 400, 500, 600],
            },
            compare_with=compare_with,  # type: ignore[arg-type]
        )

    # UNION with ORDER BY on an expression (wrapped in a subquery)
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM (
                SELECT id, value FROM df1
                UNION ALL
                SELECT id, value FROM df2
            ) AS combined
            ORDER BY value % 200, id
        """,
        expected={
            "id": [2, 4, 6, 1, 3, 5],
            "value": [200, 400, 600, 100, 300, 500],
        },
        compare_with="sqlite",
    )
450
451