Path: blob/main/py-polars/tests/unit/sql/test_set_ops.py
6939 views
from __future__ import annotations

import re

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError
from polars.testing import assert_frame_equal


def test_except_intersect() -> None:
    """Check positional EXCEPT / INTERSECT, including rows that contain nulls."""
    # The frames are only referenced by name from inside the SQL text, hence
    # the "unused variable" suppressions.
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})  # noqa: F841
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})  # noqa: F841

    res_e = pl.sql("SELECT x, y, z FROM df1 EXCEPT SELECT * FROM df2", eager=True)
    res_i = pl.sql("SELECT * FROM df1 INTERSECT SELECT x, y, z FROM df2", eager=True)

    # The duplicated (1, 4, 5) row in df1 appears only once in the intersection.
    assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # Same operations with the operands swapped, using the `TABLE` shorthand
    # and an inline VALUES table in place of df1.
    res_e = pl.sql("SELECT * FROM df2 EXCEPT TABLE df1", eager=True)
    res_i = pl.sql(
        """
        SELECT * FROM df2
        INTERSECT
        SELECT x::int8, y::int8, z::int8
        FROM (VALUES (1,2,5),(9,3,5),(1,4,5),(1,4,5)) AS df1(x,y,z)
        """,
        eager=True,
    )
    assert sorted(res_e.rows()) == [(1, 2, 7), (9, None, 6)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # check null behaviour: the (9, NULL) row is present on both sides and is
    # expected to be removed by EXCEPT, leaving only the (2, 2) row.
    with pl.SQLContext(
        tbl1=pl.DataFrame({"x": [2, 9, 1], "y": [2, None, 4]}),
        tbl2=pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4]}),
    ) as ctx:
        res = ctx.execute("SELECT * FROM tbl1 EXCEPT SELECT * FROM tbl2", eager=True)
        assert_frame_equal(pl.DataFrame({"x": [2], "y": [2]}), res)


def test_except_intersect_by_name() -> None:
    """Check EXCEPT / INTERSECT `BY NAME` with reordered and extra columns."""
    df1 = pl.DataFrame(  # noqa: F841
        {
            "x": [1, 9, 1, 1],
            "y": [2, 3, 4, 4],
            "z": [5, 5, 5, 5],
        }
    )
    # Same column names as df1 but in a different order, plus an extra "w"
    # column that df1 does not have.
    df2 = pl.DataFrame(  # noqa: F841
        {
            "y": [2, None, 4],
            "w": ["?", "!", "%"],
            "z": [7, 6, 5],
            "x": [1, 9, 1],
        }
    )
    res_e = pl.sql(
        "SELECT x, y, z FROM df1 EXCEPT BY NAME SELECT * FROM df2",
        eager=True,
    )
    res_i = pl.sql(
        "SELECT * FROM df1 INTERSECT BY NAME SELECT * FROM df2",
        eager=True,
    )
    assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)]
    assert sorted(res_i.rows()) == [(1, 4, 5)]

    # The result keeps the left-hand frame's columns and column order.
    assert res_e.columns == ["x", "y", "z"]
    assert res_i.columns == ["x", "y", "z"]


@pytest.mark.parametrize(
    ("op", "op_subtype"),
    [
        ("EXCEPT", "ALL"),
        ("EXCEPT", "ALL BY NAME"),
        ("INTERSECT", "ALL"),
        ("INTERSECT", "ALL BY NAME"),
    ],
)
def test_except_intersect_all_unsupported(op: str, op_subtype: str) -> None:
    """Check that the `ALL` variants of EXCEPT / INTERSECT raise an error."""
    df1 = pl.DataFrame({"n": [1, 1, 1, 2, 2, 2, 3]})  # noqa: F841
    df2 = pl.DataFrame({"n": [1, 1, 2, 2]})  # noqa: F841

    with pytest.raises(
        SQLInterfaceError,
        match=f"'{op} {op_subtype}' is not supported",
    ):
        pl.sql(f"SELECT * FROM df1 {op} {op_subtype} SELECT * FROM df2")


def test_update_statement_error() -> None:
    """Check that an UPDATE statement is rejected with an explicit error."""
    df_large = pl.DataFrame(
        {
            "FQDN": ["c.ORG.na", "a.COM.na"],
            "NS1": ["ns1.c.org.na", "ns1.d.net.na"],
            "NS2": ["ns2.c.org.na", "ns2.d.net.na"],
            "NS3": ["ns3.c.org.na", "ns3.d.net.na"],
        }
    )
    df_small = pl.DataFrame(
        {
            "FQDN": ["c.org.na"],
            "NS1": ["ns1.c.org.na|127.0.0.1"],
            "NS2": ["ns2.c.org.na|127.0.0.1"],
            "NS3": ["ns3.c.org.na|127.0.0.1"],
        }
    )

    # Create a context and register the tables
    ctx = pl.SQLContext()
    ctx.register("large", df_large)
    ctx.register("small", df_small)

    # `match` is applied as a regular expression; escape the expected message
    # so that the "." characters in the column references match literally.
    expected_error = re.escape(
        "'UPDATE large SET FQDN = u.FQDN, NS1 = u.NS1, NS2 = u.NS2, "
        "NS3 = u.NS3 FROM u WHERE large.FQDN = u.FQDN' "
        "operation is currently unsupported"
    )
    with pytest.raises(SQLInterfaceError, match=expected_error):
        ctx.execute("""
            WITH u AS (
                SELECT
                    small.FQDN,
                    small.NS1,
                    small.NS2,
                    small.NS3
                FROM small
                INNER JOIN large ON small.FQDN = large.FQDN
            )
            UPDATE large
            SET
                FQDN = u.FQDN,
                NS1 = u.NS1,
                NS2 = u.NS2,
                NS3 = u.NS3
            FROM u
            WHERE large.FQDN = u.FQDN
        """)


@pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"])
def test_except_intersect_errors(op: str) -> None:
    """Check the error messages raised for invalid set-operation queries."""
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})  # noqa: F841
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})  # noqa: F841

    # `UNION ALL` is a valid query (exercised by `test_union`), so the
    # "ALL is not supported" check only applies to EXCEPT and INTERSECT.
    if op != "UNION":
        with pytest.raises(
            SQLInterfaceError,
            match=f"'{op} ALL' is not supported",
        ):
            pl.sql(f"SELECT * FROM df1 {op} ALL SELECT * FROM df2", eager=False)

    # Both sides of a set operation must select the same number of columns.
    with pytest.raises(
        SQLInterfaceError,
        match=f"{op} requires equal number of columns in each table",
    ):
        pl.sql(f"SELECT x FROM df1 {op} SELECT x, y FROM df2", eager=False)


@pytest.mark.parametrize(
    ("cols1", "cols2", "union_subtype", "expected"),
    [
        (
            ["*"],
            ["*"],
            "",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["*"],
            ["frame2.*"],
            "ALL",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        (
            ["frame1.*"],
            ["c1", "c2"],
            "DISTINCT",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["*"],
            ["c2", "c1"],
            "ALL BY NAME",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        (
            ["c1", "c2"],
            ["c2", "c1"],
            "BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["c1", "c2"],
            ["c2", "c1"],
            "DISTINCT BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
    ],
)
def test_union(
    cols1: list[str],
    cols2: list[str],
    union_subtype: str,
    expected: list[tuple[int, str]],
) -> None:
    """
    Check the UNION variants against two frames that share one row.

    Parameters
    ----------
    cols1
        Select-list for the left-hand query (over ``frame1``).
    cols2
        Select-list for the right-hand query (over ``frame2``).
    union_subtype
        Modifier placed after ``UNION`` (empty string for a plain ``UNION``).
    expected
        Expected result rows, in sorted order.
    """
    with pl.SQLContext(
        frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}),
        frame2=pl.DataFrame({"c1": [2, 3], "c2": ["yy", "xx"]}),
        eager=True,
    ) as ctx:
        query = f"""
            SELECT {", ".join(cols1)} FROM frame1
            UNION {union_subtype}
            SELECT {", ".join(cols2)} FROM frame2
        """
        # Sort before comparing: the assertion is about row content only.
        assert sorted(ctx.execute(query).rows()) == expected