# Path: blob/main/py-polars/tests/unit/sql/test_subqueries.py
# (page metadata: 8407 views)
import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal


@pytest.mark.parametrize(
    ("cols", "join_type", "constraint"),
    [
        ("x", "INNER", ""),
        ("y", "INNER", ""),
        ("x", "LEFT", "WHERE y IN (0,1,2,3,4,5)"),
        ("y", "LEFT", "WHERE y >= 0"),
        ("df1.*", "FULL", "WHERE y >= 0"),
        ("df2.*", "FULL", "WHERE x >= 0"),
        ("* EXCLUDE y", "LEFT", "WHERE y >= 0"),
        ("* EXCLUDE x", "LEFT", "WHERE x >= 0"),
    ],
)
def test_from_subquery(cols: str, join_type: str, constraint: str) -> None:
    """
    Join two aliased derived tables in the FROM clause.

    Each parametrized projection / join type / filter combination is expected
    to yield a result whose `to_series()` values, once sorted, are 0..3.
    """
    df1 = pl.DataFrame({"x": [-1, 0, 3, 1, 2, -1]})
    df2 = pl.DataFrame({"y": [0, 1, 2, 3]})

    sql = pl.SQLContext(df1=df1, df2=df2)
    res = sql.execute(
        query=f"""
        SELECT {cols} FROM (SELECT * FROM df1) AS df1
        {join_type} JOIN (SELECT * FROM df2) AS df2
        ON df1.x = df2.y {constraint}
        """,
        eager=True,
    )
    assert sorted(res.to_series()) == [0, 1, 2, 3]


@pytest.mark.may_fail_cloud  # reason: with_context
def test_in_subquery() -> None:
    """
    Exercise `IN (subquery)` / `NOT IN (subquery)` predicates.

    Covers a single subquery, two AND-ed subqueries, expressions on either
    side of the predicate, negation, string columns, and the error raised
    when the subquery selects more than one column.
    """
    df = pl.DataFrame(
        {
            "x": [1, 2, 3, 4, 5, 6],
            "y": [2, 3, 4, 5, 6, 7],
        }
    )
    df_other = pl.DataFrame(
        {
            "w": [1, 2, 3, 4, 5, 6],
            "z": [2, 3, 4, 5, 6, 7],
        }
    )
    df_chars = pl.DataFrame(
        {
            "one": ["a", "b", "c", "d", "e", "f"],
            "two": ["b", "c", "d", "e", "f", "g"],
        }
    )

    ctx = pl.SQLContext(df=df, df_other=df_other, df_chars=df_chars)

    # subquery against the same frame
    res_same = ctx.execute(
        query="""
        SELECT df.x as x
        FROM df
        WHERE x IN (SELECT y FROM df)
        """,
        eager=True,
    )
    df_expected_same = pl.DataFrame({"x": [2, 3, 4, 5, 6]})
    assert_frame_equal(
        left=df_expected_same,
        right=res_same,
    )

    # two subquery predicates combined with AND
    res_double = ctx.execute(
        query="""
        SELECT df.x as x
        FROM df
        WHERE x IN (SELECT y FROM df)
        AND y IN (SELECT w FROM df_other)
        """,
        eager=True,
    )
    df_expected_double = pl.DataFrame({"x": [2, 3, 4, 5]})
    assert_frame_equal(
        left=df_expected_double,
        right=res_double,
    )

    # expressions on the left-hand side and inside the subquery
    res_expressions = ctx.execute(
        query="""
        SELECT
        df.x as x
        FROM df
        WHERE x+1 IN (SELECT y FROM df)
        AND y IN (SELECT w-1 FROM df_other)
        """,
        eager=True,
    )
    df_expected_expressions = pl.DataFrame({"x": [1, 2, 3, 4]})
    assert_frame_equal(
        left=df_expected_expressions,
        right=res_expressions,
    )

    # negated membership
    res_not_in = ctx.execute(
        query="""
        SELECT
        df.x as x
        FROM df
        WHERE x NOT IN (SELECT y-5 FROM df)
        AND y NOT IN (SELECT w+5 FROM df_other)
        """,
        eager=True,
    )
    df_not_in = pl.DataFrame({"x": [3, 4]})
    assert_frame_equal(
        left=df_not_in,
        right=res_not_in,
    )

    # string columns
    res_chars = ctx.execute(
        query="""
        SELECT
        df_chars.one
        FROM df_chars
        WHERE one IN (SELECT two FROM df_chars)
        """,
        eager=True,
    )
    df_expected_chars = pl.DataFrame({"one": ["b", "c", "d", "e", "f"]})
    assert_frame_equal(
        left=res_chars,
        right=df_expected_chars,
    )

    # a multi-column subquery is rejected
    with pytest.raises(
        expected_exception=SQLSyntaxError,
        match="SQL subquery returns more than one column",
    ):
        ctx.execute(
            query="""
            SELECT
            df_chars.one
            FROM df_chars
            WHERE one IN (SELECT one, two FROM df_chars)
            """
        ).collect()


def test_subquery_20732() -> None:
    """Check `IN (SELECT MAX(...))` against a concatenated LazyFrame."""
    lf = pl.concat(
        [
            pl.LazyFrame([{"id": 1, "s": "a"}]),
            pl.LazyFrame([{"id": 2, "s": "b"}]),
        ]
    )
    res = pl.sql("SELECT * FROM lf WHERE id IN (SELECT MAX(id) FROM lf)", eager=True)
    assert res.to_dict(as_series=False) == {"id": [2], "s": ["b"]}


def test_unsupported_subquery_comparisons() -> None:
    """
    Test that comparing against a subquery gives a helpful error message.

    `=` and `!=` must additionally suggest `IN` / `NOT IN`; the ordering
    operators are checked with the subquery on either side of the operator.
    """
    df = pl.DataFrame({"value": [2000, 2000]})

    # NOTE(review): `e` is not a column of `df`; these cases appear to rely on
    # the operator check raising before column resolution -- confirm.
    for op, suggestion in (("=", "IN"), ("!=", "NOT IN")):
        with pytest.raises(
            expected_exception=SQLSyntaxError,
            match=rf"subquery comparisons with '{op}' are not supported; use '{suggestion}' instead",
        ):
            pl.sql(f"SELECT * FROM df WHERE value {op} (SELECT MAX(e) FROM df)")

    for op in ("<", "<=", ">", ">="):
        # subquery on the left-hand side
        with pytest.raises(
            expected_exception=SQLSyntaxError,
            match=rf"subquery comparisons with '{op}' are not supported",
        ):
            pl.sql(f"SELECT * FROM df WHERE (SELECT MAX(e) FROM df) {op} value")

        # subquery on the right-hand side
        with pytest.raises(
            expected_exception=SQLSyntaxError,
            match=rf"subquery comparisons with '{op}' are not supported",
        ):
            pl.sql(f"SELECT * FROM df WHERE value {op} (SELECT MAX(value) FROM df)")


def test_derived_table_without_alias() -> None:
    """Check that an unaliased derived table can be selected from."""
    df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

    # basic unaliased subquery
    with pl.SQLContext(df=df) as ctx:
        res = ctx.execute("SELECT * FROM (SELECT a, b FROM df) ORDER BY a", eager=True)
        assert_frame_equal(res, df)

        # set operation without subquery aliases
        res = ctx.execute(
            """
            SELECT * FROM (
                SELECT a, b FROM df WHERE a <= 2
                UNION ALL
                SELECT a, b FROM df WHERE a > 2
            )
            ORDER BY a
            """
        ).collect()
        assert_frame_equal(res, df)

        # unqualified (but unambiguous) column refs from unaliased derived table
        res = ctx.execute("SELECT a FROM (SELECT a, b FROM df) ORDER BY a", eager=True)
        assert_frame_equal(res, df.select("a"))


def test_derived_table_alias_errors() -> None:
    """
    Check the errors raised for invalid use of unaliased derived tables.

    Joining on an unnamed derived table, or qualifying a reference with an
    alias that was never declared, must raise `SQLInterfaceError`.
    """
    df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

    # joining on unaliased derived table should raise
    for join_type in ("INNER", "LEFT", "CROSS"):
        constraint = "" if join_type == "CROSS" else "ON df.a = a2"
        with pytest.raises(
            expected_exception=SQLInterfaceError,
            match="cannot JOIN on unnamed relation",
        ):
            pl.sql(
                query=f"""
                SELECT * FROM df
                {join_type} JOIN (SELECT a AS a2 FROM df) {constraint}
                """
            ).collect()

    # unaliased derived tables in a join
    with pytest.raises(
        expected_exception=SQLInterfaceError,
        match="cannot JOIN on unnamed relation",
    ):
        pl.sql(
            query="""
            SELECT *
            FROM (SELECT a FROM df)
            INNER JOIN (SELECT b FROM df) ON a = b
            """,
        ).collect()

    # qualified wildcard on nonexistent alias
    with pytest.raises(
        expected_exception=SQLInterfaceError,
        match="no table or struct column named 'sq' found",
    ):
        pl.sql(
            query="SELECT sq.* FROM (SELECT a, b FROM df)",
            eager=True,
        )

    # qualified column reference on nonexistent alias
    with pytest.raises(
        expected_exception=SQLInterfaceError,
        match="no table or struct column named 'sq' found",
    ):
        pl.sql(
            query="SELECT sq.a FROM (SELECT a, b FROM df)",
            eager=True,
        )

    # qualified reference in different clauses
    with pytest.raises(
        expected_exception=SQLInterfaceError,
        match="no table or struct column named 'sq' found",
    ):
        pl.sql(
            query="SELECT a FROM (SELECT a, b FROM df) WHERE sq.a > 1",
            eager=True,
        )

    with pytest.raises(
        expected_exception=SQLInterfaceError,
        match="no table or struct column named 'sq' found",
    ):
        pl.sql(
            query="SELECT a, COUNT(*) FROM (SELECT a, b FROM df) GROUP BY sq.a",
            eager=True,
        )

    with pytest.raises(
        expected_exception=SQLInterfaceError,
        match="no table or struct column named 'sq' found",
    ):
        pl.sql(
            query="SELECT a FROM (SELECT a, b FROM df) ORDER BY sq.a",
            eager=True,
        )