# Path: blob/main/py-polars/tests/unit/sql/test_set_ops.py
# 8396 views
from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError
from polars.testing import assert_frame_equal
from tests.unit.sql import assert_sql_matches


def test_except_intersect() -> None:
    """Check EXCEPT / INTERSECT results, including rows that contain nulls."""
    # NOTE: no frames are passed to `pl.sql`; the queries refer to the local
    # DataFrame variables (df1, df2) by name.
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})

    res_e = pl.sql("SELECT x, y, z FROM df1 EXCEPT SELECT * FROM df2", eager=True)
    res_i = pl.sql("SELECT * FROM df1 INTERSECT SELECT x, y, z FROM df2", eager=True)

    # row order is not asserted, hence the sort; the duplicated (1, 4, 5) row
    # from df1 appears only once in the INTERSECT result
    assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # same operations with the operands swapped, using the `TABLE <name>`
    # shorthand and an inline VALUES clause in place of df1
    res_e = pl.sql("SELECT * FROM df2 EXCEPT TABLE df1", eager=True)
    res_i = pl.sql(
        """
        SELECT * FROM df2
        INTERSECT
        SELECT x::int8, y::int8, z::int8
        FROM (VALUES (1,2,5),(9,3,5),(1,4,5),(1,4,5)) AS df1(x,y,z)
        """,
        eager=True,
    )
    assert sorted(res_e.rows()) == [(1, 2, 7), (9, None, 6)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # nulls match nulls in set operations: the (9, NULL) row is present in both
    # tables and is therefore removed by EXCEPT, leaving only (2, 2)
    with pl.SQLContext(
        tbl1=pl.DataFrame({"x": [2, 9, 1], "y": [2, None, 4]}),
        tbl2=pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4]}),
    ) as ctx:
        res = ctx.execute("SELECT * FROM tbl1 EXCEPT SELECT * FROM tbl2", eager=True)
        assert_frame_equal(pl.DataFrame({"x": [2], "y": [2]}), res)


def test_except_intersect_by_name() -> None:
    """Check EXCEPT / INTERSECT `BY NAME` against a differently-shaped frame."""
    df1 = pl.DataFrame(
        {
            "x": [1, 9, 1, 1],
            "y": [2, 3, 4, 4],
            "z": [5, 5, 5, 5],
        }
    )
    # df2 has its columns in a different order, plus an extra column ("w")
    df2 = pl.DataFrame(
        {
            "y": [2, None, 4],
            "w": ["?", "!", "%"],
            "z": [7, 6, 5],
            "x": [1, 9, 1],
        }
    )
    res_e = pl.sql(
        "SELECT x, y, z FROM df1 EXCEPT BY NAME SELECT * FROM df2",
        eager=True,
    )
    res_i = pl.sql(
        "SELECT * FROM df1 INTERSECT BY NAME SELECT * FROM df2",
        eager=True,
    )
    assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # the result keeps the columns (and column order) of the left-hand side
    assert res_e.columns == ["x", "y", "z"]
    assert res_i.columns == ["x", "y", "z"]


@pytest.mark.parametrize(
    ("op", "op_subtype"),
    [
        ("EXCEPT", "ALL"),
        ("EXCEPT", "ALL BY NAME"),
        ("INTERSECT", "ALL"),
        ("INTERSECT", "ALL BY NAME"),
    ],
)
def test_except_intersect_all_unsupported(op: str, op_subtype: str) -> None:
    """Check that the `ALL` variants of EXCEPT / INTERSECT raise an error."""
    df1 = pl.DataFrame({"n": [1, 1, 1, 2, 2, 2, 3]})
    df2 = pl.DataFrame({"n": [1, 1, 2, 2]})

    with pytest.raises(
        SQLInterfaceError,
        match=f"'{op} {op_subtype}' is not supported",
    ):
        pl.sql(f"SELECT * FROM df1 {op} {op_subtype} SELECT * FROM df2", eager=True)


def test_update_statement_error() -> None:
    """Check that an UPDATE statement is rejected with a clear error."""
    df_large = pl.DataFrame(
        {
            "FQDN": ["c.ORG.na", "a.COM.na"],
            "NS1": ["ns1.c.org.na", "ns1.d.net.na"],
            "NS2": ["ns2.c.org.na", "ns2.d.net.na"],
            "NS3": ["ns3.c.org.na", "ns3.d.net.na"],
        }
    )
    df_small = pl.DataFrame(
        {
            "FQDN": ["c.org.na"],
            "NS1": ["ns1.c.org.na|127.0.0.1"],
            "NS2": ["ns2.c.org.na|127.0.0.1"],
            "NS3": ["ns3.c.org.na|127.0.0.1"],
        }
    )

    # Create a context and register the tables
    ctx = pl.SQLContext()
    ctx.register("large", df_large)
    ctx.register("small", df_small)

    with pytest.raises(
        SQLInterfaceError,
        match=r"'UPDATE large SET FQDN = .+ operation is currently unsupported",
    ):
        ctx.execute("""
            WITH u AS (
                SELECT
                    small.FQDN,
                    small.NS1,
                    small.NS2,
                    small.NS3
                FROM small
                INNER JOIN large ON small.FQDN = large.FQDN
            )
            UPDATE large
            SET
                FQDN = u.FQDN,
                NS1 = u.NS1,
                NS2 = u.NS2,
                NS3 = u.NS3
            FROM u
            WHERE large.FQDN = u.FQDN
        """)


@pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"])
def test_except_intersect_union_errors(op: str) -> None:
    """Check the errors raised for unsupported or malformed set operations."""
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})

    # `UNION ALL` is a valid operation (see `test_union`), so the "ALL is not
    # supported" error is only expected for EXCEPT and INTERSECT
    if op != "UNION":
        with pytest.raises(
            SQLInterfaceError,
            match=f"'{op} ALL' is not supported",
        ):
            pl.sql(
                f"SELECT * FROM df1 {op} ALL SELECT * FROM df2", eager=False
            ).collect()

    # both sides of a (positional) set operation need the same column count
    with pytest.raises(
        SQLInterfaceError,
        match=f"{op} requires equal number of columns in each table",
    ):
        pl.sql(f"SELECT x FROM df1 {op} SELECT x, y FROM df2", eager=False).collect()


@pytest.mark.parametrize(
    ("cols1", "cols2", "union_subtype", "expected"),
    [
        (
            ["*"],
            ["*"],
            "",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["*"],
            ["frame2.*"],
            "ALL",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        (
            ["frame1.*"],
            ["c1", "c2"],
            "DISTINCT",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["*"],
            ["c2", "c1"],
            "ALL BY NAME",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        (
            ["c1", "c2"],
            ["c1 AS x1", "c2 AS x2"],
            "",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["c1", "c2"],
            ["c2", "c1"],
            "BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        pytest.param(
            ["c1", "c2"],
            ["c2", "c1"],
            "DISTINCT BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
    ],
)
def test_union(
    cols1: list[str],
    cols2: list[str],
    union_subtype: str,
    expected: list[tuple[int, str]],
) -> None:
    """Check UNION and its ALL / DISTINCT / BY NAME variants.

    `cols1` and `cols2` are the select-lists used for each side of the UNION,
    `union_subtype` is the modifier placed after the UNION keyword, and
    `expected` holds the sorted rows the query should produce.
    """
    with pl.SQLContext(
        frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}),
        frame2=pl.DataFrame({"c1": [2, 3], "c2": ["yy", "xx"]}),
        eager=True,
    ) as ctx:
        query = f"""
            SELECT {", ".join(cols1)} FROM frame1
            UNION {union_subtype}
            SELECT {", ".join(cols2)} FROM frame2
        """
        # row order is not asserted, hence the sort
        assert sorted(ctx.execute(query).rows()) == expected


def test_union_nonmatching_colnames() -> None:
    """Check positional UNION of frames with different column names/dtypes."""
    # SQL allows "UNION" (aka: polars `concat`) on column names that don't match;
    # this behaves positionally, with column names coming from the first table
    with pl.SQLContext(
        df1=pl.DataFrame(
            data={"Value": [100, 200], "Tag": ["hello", "foo"]},
            schema_overrides={"Value": pl.Int16},
        ),
        df2=pl.DataFrame(
            data={"Number": [300, 400], "String": ["world", "bar"]},
            schema_overrides={"Number": pl.Int32},
        ),
        eager=True,
    ) as ctx:
        res = ctx.execute(
            query="""
                SELECT u.* FROM (
                    SELECT * FROM df1
                    UNION
                    SELECT * FROM df2
                ) u ORDER BY Value
            """
        )
        # the Int16 / Int32 inputs come out as the wider Int32
        assert res.schema == {
            "Value": pl.Int32,
            "Tag": pl.String,
        }
        assert res.rows() == [
            (100, "hello"),
            (200, "foo"),
            (300, "world"),
            (400, "bar"),
        ]


def test_union_with_join_state_isolation() -> None:
    """Check that JOIN state does not leak between the branches of a UNION."""
    # confirm each branch of a UNION executes with isolated join state;
    # ensures that aliases from one branch don't leak into the other
    res = pl.sql(
        query="""
            -- start CTEs
            WITH
            a AS (SELECT 0 AS k),
            b AS (SELECT 1 AS k),
            c AS (SELECT 0 AS k)
            -- end of CTEs
            SELECT a.k FROM a JOIN c ON a.k = c.k
            UNION ALL
            SELECT b.k FROM b JOIN c ON b.k = c.k
        """,
        eager=True,
    )
    # only the first branch yields a row (a.k = c.k = 0); b.k = 1 has no match
    assert res.to_series().to_list() == [0]


def test_set_operations_order_by() -> None:
    """Check ORDER BY / LIMIT / FETCH handling around set operations."""
    df1 = pl.DataFrame({"id": [1, 2, 3], "value": [100, 200, 300]})
    df2 = pl.DataFrame({"id": [4, 5, 6], "value": [400, 500, 600]})
    df3 = pl.DataFrame({"id": [2, 3, 4], "value": [200, 300, 400]})

    # overall ORDER BY applies to the combined UNION result
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM df1
            UNION ALL
            SELECT * FROM df2
            ORDER BY id DESC
        """,
        expected={
            "id": [6, 5, 4, 3, 2, 1],
            "value": [600, 500, 400, 300, 200, 100],
        },
        compare_with="sqlite",
    )

    # ORDER BY with LIMIT on the final result
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM df1
            UNION ALL
            SELECT * FROM df2
            ORDER BY value DESC
            LIMIT 3
        """,
        expected={"id": [6, 5, 4], "value": [600, 500, 400]},
        compare_with="sqlite",
    )

    # ORDER BY with FETCH on the final result
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM df1
            UNION ALL
            SELECT * FROM df2
            ORDER BY value DESC
            FETCH FIRST 3 ROWS ONLY
        """,
        expected={"id": [6, 5, 4], "value": [600, 500, 400]},
        compare_with="duckdb",
    )

    # Nested ORDER BY in subqueries (top-N from each side) with LIMIT
    assert_sql_matches(
        frames={"df1": df1, "df2": df2},
        query="""
            SELECT * FROM (SELECT * FROM df1 ORDER BY value DESC LIMIT 2) AS top1
            UNION ALL
            SELECT * FROM (SELECT * FROM df2 ORDER BY value ASC LIMIT 2) AS top2
            ORDER BY id
        """,
        expected={"id": [2, 3, 4, 5], "value": [200, 300, 400, 500]},
        compare_with="sqlite",
    )

    # Nested ORDER BY in subqueries with LIMIT, with an outer ORDER BY/LIMIT
    assert_sql_matches(
        {"df1": df1, "df2": df2},
        query="""
            SELECT * FROM (
                SELECT * FROM (SELECT * FROM df1 ORDER BY value DESC LIMIT 2) t1
                UNION ALL
                SELECT * FROM (SELECT * FROM df2 ORDER BY value ASC LIMIT 2) t2
            ) t3
            ORDER BY id
            LIMIT 3
        """,
        expected={"id": [2, 3, 4], "value": [200, 300, 400]},
        compare_with="sqlite",
    )

    # EXCEPT with ORDER BY
    assert_sql_matches(
        {"df1": df1, "df3": df3},
        query="""
            SELECT * FROM df1
            EXCEPT
            SELECT * FROM df3
            ORDER BY id
        """,
        expected={"id": [1], "value": [100]},
        compare_with="sqlite",
    )

    # INTERSECT with ORDER BY
    assert_sql_matches(
        {"df1": df1, "df3": df3},
        query="""
            SELECT * FROM df1
            INTERSECT
            SELECT * FROM df3
            ORDER BY id DESC
        """,
        expected={"id": [3, 2], "value": [300, 200]},
        compare_with="sqlite",
    )

    # Mixed UNION / INTERSECT with ORDER BY and FETCH. The expected rows are
    # df1 ∪ (df2 ∩ df3) = {1, 2, 3} ∪ {4}, i.e. INTERSECT is applied before
    # UNION; grouping left-to-right would give {2, 3, 4} instead.
    # NOTE(review): presumably compared against duckdb (not sqlite) because of
    # that precedence rule - confirm.
    assert_sql_matches(
        {"df1": df1, "df2": df2, "df3": df3},
        query="""
            (
                SELECT * FROM df1
                UNION
                SELECT * FROM df2
                INTERSECT
                SELECT * FROM df3
            )
            ORDER BY id
            FETCH FIRST 4 ROWS ONLY
        """,
        expected={
            "id": [1, 2, 3, 4],
            "value": [100, 200, 300, 400],
        },
        compare_with="duckdb",
    )

    # Chained UNION with overall ORDER BY (with and without wrapping parens)
    for open_paren, close_paren, compare_with in (
        ("", "", "sqlite"),
        ("", "", "duckdb"),
        ("(", ")", "duckdb"),
    ):
        assert_sql_matches(
            {"df1": df1, "df2": df2, "df3": df3},
            query=f"""
                {open_paren}
                SELECT * FROM df1
                UNION
                SELECT * FROM df2
                UNION
                SELECT * FROM df3
                {close_paren}
                ORDER BY value
            """,
            expected={
                "id": [1, 2, 3, 4, 5, 6],
                "value": [100, 200, 300, 400, 500, 600],
            },
            compare_with=compare_with,  # type: ignore[arg-type]
        )

    # UNION with ORDER BY on expression (wrapped in subquery)
    assert_sql_matches(
        {"df1": df1, "df2": df2},
        query="""
            SELECT * FROM (
                SELECT id, value FROM df1
                UNION ALL
                SELECT id, value FROM df2
            ) AS combined
            ORDER BY value % 200, id
        """,
        expected={
            "id": [2, 4, 6, 1, 3, 5],
            "value": [200, 400, 600, 100, 300, 500],
        },
        compare_with="sqlite",
    )