# Path: blob/main/py-polars/tests/unit/sql/test_miscellaneous.py
"""Miscellaneous tests for the polars SQL interface.

Exercises ``pl.sql``, ``DataFrame.sql``, ``LazyFrame.sql`` and ``pl.SQLContext``.
"""

from __future__ import annotations

from datetime import date
from pathlib import Path
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder

if TYPE_CHECKING:
    from polars.datatypes import DataType


@pytest.fixture
def foods_ipc_path() -> Path:
    return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc"


def test_any_all() -> None:
    # quantified comparisons (`<op> ALL(...)` / `<op> ANY(...)`) against `df.y`
    df = pl.DataFrame(
        {
            "x": [-1, 0, 1, 2, 3, 4],
            "y": [1, 0, 0, 1, 2, 3],
        }
    )
    res = pl.sql(
        """
        SELECT
          x >= ALL(df.y) AS "All Geq",
          x > ALL(df.y) AS "All G",
          x < ALL(df.y) AS "All L",
          x <= ALL(df.y) AS "All Leq",
          x >= ANY(df.y) AS "Any Geq",
          x > ANY(df.y) AS "Any G",
          x < ANY(df.y) AS "Any L",
          x <= ANY(df.y) AS "Any Leq",
          x == ANY(df.y) AS "Any eq",
          x != ANY(df.y) AS "Any Neq",
        FROM df
        """,
    ).collect()

    assert res.to_dict(as_series=False) == {
        "All Geq": [0, 0, 0, 0, 1, 1],
        "All G": [0, 0, 0, 0, 0, 1],
        "All L": [1, 0, 0, 0, 0, 0],
        "All Leq": [1, 1, 0, 0, 0, 0],
        "Any Geq": [0, 1, 1, 1, 1, 1],
        "Any G": [0, 0, 1, 1, 1, 1],
        "Any L": [1, 1, 1, 1, 0, 0],
        "Any Leq": [1, 1, 1, 1, 1, 0],
        "Any eq": [0, 1, 1, 1, 1, 0],
        # note: `!= ANY` is expected to behave as "not in" (x absent from y)
        "Any Neq": [1, 0, 0, 0, 0, 1],
    }


@pytest.mark.parametrize(
    ("data", "schema"),
    [
        ({"x": [1, 2, 3, 4]}, None),
        ({"x": [9, 8, 7, 6]}, {"x": pl.Int8}),
        ({"x": ["aa", "bb"]}, {"x": pl.Struct}),
        ({"x": [None, None], "y": [None, None]}, {"x": pl.Date, "y": pl.Float64}),
    ],
)
def test_boolean_where_clauses(
    data: dict[str, Any], schema: dict[str, DataType] | None
) -> None:
    # constant-true predicates keep every row; constant-false ones keep none
    df = pl.DataFrame(data=data, schema=schema)
    empty_df = df.clear()

    for true in ("TRUE", "1=1", "2 == 2", "'xx' = 'xx'", "TRUE AND 1=1"):
        assert_frame_equal(df, df.sql(f"SELECT * FROM self WHERE {true}"))

    for false in ("false", "1!=1", "2 != 2", "'xx' != 'xx'", "FALSE OR 1!=1"):
        assert_frame_equal(empty_df, df.sql(f"SELECT * FROM self WHERE {false}"))


def test_count() -> None:
    # COUNT(col) skips nulls while COUNT(*) counts rows; same for DISTINCT
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1, 1, 22, 22, 333],
            "c": [1, 1, None, None, 2],
        }
    )
    res = df.sql(
        """
        SELECT
          -- count
          COUNT(a) AS count_a,
          COUNT(b) AS count_b,
          COUNT(c) AS count_c,
          COUNT(*) AS count_star,
          COUNT(NULL) AS count_null,
          -- count distinct
          COUNT(DISTINCT a) AS count_unique_a,
          COUNT(DISTINCT b) AS count_unique_b,
          COUNT(DISTINCT c) AS count_unique_c,
          COUNT(DISTINCT NULL) AS count_unique_null,
        FROM self
        """,
    )
    assert res.to_dict(as_series=False) == {
        "count_a": [5],
        "count_b": [5],
        "count_c": [3],
        "count_star": [5],
        "count_null": [0],
        "count_unique_a": [5],
        "count_unique_b": [3],
        "count_unique_c": [2],
        "count_unique_null": [0],
    }

    # all-null column: only COUNT(*) is non-zero
    df = pl.DataFrame({"x": [None, None, None]})
    res = df.sql(
        """
        SELECT
          COUNT(x) AS count_x,
          COUNT(*) AS count_star,
          COUNT(DISTINCT x) AS count_unique_x
        FROM self
        """
    )
    assert res.to_dict(as_series=False) == {
        "count_x": [0],
        "count_star": [3],
        "count_unique_x": [0],
    }


def test_cte_aliasing() -> None:
    # CTEs built over frames from the calling scope, referenced via table aliases
    df1 = pl.DataFrame({"colx": ["aa", "bb"], "coly": [40, 30]})
    df2 = pl.DataFrame({"colx": "aa", "colz": 20})
    df3 = pl.sql(
        query="""
            WITH
              test1 AS (SELECT * FROM df1),
              test2 AS (SELECT * FROM df2),
              test3 AS (
                SELECT ROW_NUMBER() OVER (ORDER BY t1.colx) AS n, t1.colx, t2.colz
                FROM test1 t1
                LEFT JOIN test2 t2 ON t1.colx = t2.colx
              )
            SELECT * FROM test3 t3 ORDER BY colx DESC
        """,
        eager=True,
    )
    expected = [(2, "bb", None), (1, "aa", 20)]
    assert expected == df3.rows()


def test_distinct() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 2, 2, 3],
            "b": [1, 2, 3, 4, 5, 6],
        }
    )
    ctx = pl.SQLContext(register_globals=True, eager=True)
    res1 = ctx.execute("SELECT DISTINCT a FROM df ORDER BY a DESC")
    assert_frame_equal(
        left=df.select("a").unique().sort(by="a", descending=True),
        right=res1,
    )

    res2 = ctx.execute(
        """
        SELECT DISTINCT
          a * 2 AS two_a,
          b / 2 AS half_b
        FROM df
        ORDER BY two_a ASC, half_b DESC
        """,
    )
    # `b / 2` on this integer column is expected to give truncated integers,
    # so several (a, b) rows collapse into the same DISTINCT output row
    assert res2.to_dict(as_series=False) == {
        "two_a": [2, 2, 4, 6],
        "half_b": [1, 0, 2, 3],
    }

    # test unregistration
    ctx.unregister("df")
    with pytest.raises(SQLInterfaceError, match="relation 'df' was not found"):
        ctx.execute("SELECT * FROM df")


def test_frame_sql_globals_error() -> None:
    # `DataFrame.sql` cannot resolve other frames by name, but `pl.sql`
    # finds them in the calling scope
    df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df2 = pl.DataFrame({"a": [2, 3, 4], "b": [7, 6, 5]})

    query = """
        SELECT df1.a, df2.b
        FROM df2 JOIN df1 ON df1.a = df2.a
        ORDER BY b DESC
    """
    with pytest.raises(SQLInterfaceError, match=r"relation.*not found.*"):
        df1.sql(query=query)

    res = pl.sql(query=query, eager=True)
    assert res.to_dict(as_series=False) == {"a": [2, 3], "b": [7, 6]}


def test_global_misc_lookup() -> None:
    # check that `col` in global namespace is not incorrectly identified
    # as supporting pycapsule (as it can look like it has *any* attr)
    from polars import col  # noqa: F401

    df = pl.DataFrame({"col": [90, 80, 70]})
    df_res = pl.sql("SELECT col FROM df WHERE col > 75", eager=True)
    assert df_res.rows() == [(90,), (80,)]


def test_in_no_ops_11946() -> None:
    # `IN (...)` filter on a LazyFrame registered under a custom table name
    lf = pl.LazyFrame(
        [
            {"i1": 1},
            {"i1": 2},
            {"i1": 3},
        ]
    )
    out = lf.sql(
        query="SELECT * FROM frame_data WHERE i1 in (1, 3)",
        table_name="frame_data",
    ).collect()
    assert out.to_dict(as_series=False) == {"i1": [1, 3]}


def test_limit_offset() -> None:
    n_values = 11
    lf = pl.LazyFrame({"a": range(n_values), "b": reversed(range(n_values))})
    ctx = pl.SQLContext(tbl=lf)

    assert ctx.execute("SELECT * FROM tbl LIMIT 3 OFFSET 4", eager=True).rows() == [
        (4, 6),
        (5, 5),
        (6, 4),
    ]
    # LIMIT/OFFSET must agree with `slice`, including a window that runs past
    # the end (8, 5) and an offset equal to the frame height (empty result)
    for offset, limit in [(0, 3), (1, n_values), (2, 3), (5, 3), (8, 5), (n_values, 1)]:
        out = ctx.execute(
            f"SELECT * FROM tbl LIMIT {limit} OFFSET {offset}", eager=True
        )
        assert_frame_equal(out, lf.slice(offset, limit).collect())
        assert len(out) == min(limit, n_values - offset)


def test_nested_subquery_table_leakage() -> None:
    a = pl.LazyFrame({"id": [1, 2, 3]})
    b = pl.LazyFrame({"val": [2, 3, 4]})

    ctx = pl.SQLContext(a=a, b=b)
    res = ctx.execute(
        """
        SELECT *
        FROM a
        WHERE id IN (
          SELECT derived.val
          FROM (SELECT val FROM b) AS derived
        )
        """,
        eager=True,
    )
    # the subquery itself must resolve correctly: only ids present in `b.val`
    assert res.to_dict(as_series=False) == {"id": [2, 3]}

    # after execution of the above query, confirm that we don't see the
    # inner "derived" table alias still being registered in the context
    with pytest.raises(
        SQLInterfaceError,
        match="relation 'derived' was not found",
    ):
        ctx.execute("SELECT * FROM derived")


def test_register_context() -> None:
    # context manager usage should unregister tables created in each
    # scope on context exit; supports arbitrary levels of nesting.
    with pl.SQLContext() as ctx:
        _lf1 = pl.LazyFrame({"a": [1, 2, 3], "b": ["m", "n", "o"]})
        _lf2 = pl.LazyFrame({"a": [2, 3, 4], "c": ["p", "q", "r"]})

        ctx.register_globals()
        assert ctx.tables() == ["_lf1", "_lf2"]

        with ctx:
            _lf3 = pl.LazyFrame({"a": [3, 4, 5], "b": ["s", "t", "u"]})
            _lf4 = pl.LazyFrame({"a": [4, 5, 6], "c": ["v", "w", "x"]})
            ctx.register_globals(n=2)
            assert ctx.tables() == ["_lf1", "_lf2", "_lf3", "_lf4"]

        assert ctx.tables() == ["_lf1", "_lf2"]

    assert ctx.tables() == []


def test_sql_on_compatible_frame_types() -> None:
    df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # create various different frame types
    dfp = df.to_pandas()
    dfa = df.to_arrow()
    dfb = dfa.to_batches()[0]
    dfo = PyCapsuleStreamHolder(df)

    # run polars sql query against all frame types; `dfs` is referenced by
    # name from inside the SQL, hence the noqa for the "unused" loop variable
    for dfs in (  # noqa: B007
        (df["a"] * 2).rename("c"),  # polars series
        (dfp["a"] * 2).rename("c"),  # pandas series
    ):
        res = pl.sql(
            """
            SELECT a, b, SUM(c) AS cc FROM (
              SELECT * FROM df               -- polars frame
                UNION ALL SELECT * FROM dfp  -- pandas frame
                UNION ALL SELECT * FROM dfa  -- pyarrow table
                UNION ALL SELECT * FROM dfb  -- pyarrow record batch
                UNION ALL SELECT * FROM dfo  -- arbitrary pycapsule object
            ) tbl
            INNER JOIN dfs ON dfs.c == tbl.b -- join on pandas/polars series
            GROUP BY "a", "b"
            ORDER BY "a", "b"
            """
        ).collect()

        # verify every series variant, not just the last one iterated
        expected = pl.DataFrame({"a": [1, 3], "b": [4, 6], "cc": [20, 30]})
        assert_frame_equal(left=expected, right=res)

    # register and operate on non-polars frames
    for obj in (dfa, dfp):
        with pl.SQLContext(obj=obj) as ctx:
            res = ctx.execute("SELECT * FROM obj", eager=True)
            assert_frame_equal(df, res)

    # `register_globals` doesn't register all compatible objects (eg: pandas)
    with pytest.raises(SQLInterfaceError, match="relation 'dfp' was not found"):
        pl.SQLContext(register_globals=True).execute("SELECT * FROM dfp")


def test_nested_cte_column_aliasing() -> None:
    # trace through nested CTEs with multiple levels of column & table aliasing
    df = pl.sql(
        """
        WITH
          x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)),
          y (m, n) AS (
            WITH z(c, d) AS (SELECT a, b FROM x)
              SELECT d*2 AS d2, c*3 AS c3 FROM z
          )
        SELECT n, m FROM y
        """,
        eager=True,
    )
    assert df.to_dict(as_series=False) == {
        "n": [3, 9],
        "m": [4, 8],
    }


def test_invalid_derived_table_column_aliases() -> None:
    # the number of column aliases must match the derived table's columns;
    # an alias with no column list is accepted
    values_query = "SELECT * FROM (VALUES (1,2), (3,4))"

    with pytest.raises(
        SQLSyntaxError,
        match=r"columns \(5\) in alias 'tbl' does not match .* the table/query \(2\)",
    ):
        pl.sql(f"{values_query} AS tbl(a, b, c, d, e)")

    assert pl.sql(f"{values_query} tbl", eager=True).rows() == [(1, 2), (3, 4)]


def test_values_clause_table_registration() -> None:
    with pl.SQLContext(frames=None, eager=True) as ctx:
        # initially no tables are registered
        assert ctx.tables() == []

        # confirm that VALUES clause derived table is registered, post-query
        res1 = ctx.execute("SELECT * FROM (VALUES (-1,1)) AS tbl(x, y)")
        assert ctx.tables() == ["tbl"]

        # and confirm that we can select from it by the registered name
        res2 = ctx.execute("SELECT x, y FROM tbl")
        for res in (res1, res2):
            assert res.to_dict(as_series=False) == {"x": [-1], "y": [1]}


def test_read_csv(tmp_path: Path) -> None:
    # check empty string vs null, parsing of dates, etc
    df = pl.DataFrame(
        {
            "label": ["lorem", None, "", "ipsum"],
            "num": [-1, None, 0, 1],
            "dt": [
                date(1969, 7, 5),
                date(1999, 12, 31),
                date(2077, 10, 10),
                None,
            ],
        }
    )
    csv_target = tmp_path / "test_sql_read.csv"
    df.write_csv(csv_target)

    res = pl.sql(f"SELECT * FROM read_csv('{csv_target}')").collect()
    assert_frame_equal(df, res)

    with pytest.raises(
        SQLSyntaxError,
        match="`read_csv` expects a single file path; found 3 arguments",
    ):
        pl.sql("SELECT * FROM read_csv('a','b','c')")


def test_global_variable_inference_17398() -> None:
    # a frame referenced only inside a CTE is still resolved from the caller's scope
    users = pl.DataFrame({"id": "1"})

    res = pl.sql(
        query="""
            WITH user_by_email AS (SELECT id FROM users)
            SELECT * FROM user_by_email
        """,
        eager=True,
    )
    assert_frame_equal(res, users)


@pytest.mark.parametrize(
    "query",
    [
        "SELECT invalid_column FROM self",
        "SELECT key, invalid_column FROM self",
        "SELECT invalid_column * 2 FROM self",
        "SELECT * FROM self ORDER BY invalid_column",
        "SELECT * FROM self WHERE invalid_column = 200",
        "SELECT * FROM self WHERE invalid_column = '200'",
        "SELECT key, SUM(n) AS sum_n FROM self GROUP BY invalid_column",
    ],
)
def test_invalid_cols(query: str) -> None:
    # a missing column in any clause must raise ColumnNotFoundError
    df = pl.DataFrame(
        {
            "key": ["xx", "xx", "yy"],
            "n": ["100", "200", "300"],
        }
    )
    with pytest.raises(ColumnNotFoundError, match="invalid_column"):
        df.sql(query)


@pytest.mark.parametrize("filter_expr", ["", "WHERE 1 = 1", "WHERE a == 1 OR a != 1"])
@pytest.mark.parametrize("order_expr", ["", "ORDER BY 1", "ORDER BY a"])
def test_select_output_heights_20058_21084(filter_expr: str, order_expr: str) -> None:
    df = pl.DataFrame({"a": [1, 2, 3]})

    # Queries that maintain original height

    assert_frame_equal(
        df.sql(f"SELECT 1 as a FROM self {filter_expr} {order_expr}").cast(pl.Int64),
        pl.select(a=pl.Series([1, 1, 1])),
    )

    assert_frame_equal(
        df.sql(f"SELECT 1 + 1 as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
            pl.Int64
        ),
        pl.DataFrame({"a": [2, 2, 2], "b": [1, 1, 1]}),
    )

    # Queries that aggregate to unit height

    assert_frame_equal(
        df.sql(f"SELECT COUNT(*) as a FROM self {filter_expr} {order_expr}").cast(
            pl.Int64
        ),
        pl.DataFrame({"a": 3}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT COUNT(*) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 3, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(a) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 1, "b": 1}),
    )

    assert_frame_equal(
        df.sql(f"SELECT SUM(a) as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
            pl.Int64
        ),
        pl.DataFrame({"a": 6, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(1) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 1, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(1) + 1 as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 2, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(1 + 1) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 2, "b": 1}),
    )


def test_select_explode_height_filter_order_by() -> None:
    # Note: `UNNEST()` in SQL is equivalent to `pl.DataFrame.explode()`;
    # the ordering is applied after the explosion/unnest.
    df = pl.DataFrame(
        {
            "list_long": [[1, 2, 3], [4, 5, 6]],
            "sort_key": [2, 1],
            "filter_mask": [False, True],
            "filter_mask_all_true": True,
        }
    )

    # Unnest/explode is applied at the dataframe level, sort is applied afterward
    assert_frame_equal(
        df.sql("SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key"),
        pl.Series("list", [4, 5, 6, 1, 2, 3]).to_frame(),
    )

    # No NULLS: since order is applied after explode on the dataframe level,
    # NULLS FIRST leaves the result unchanged
    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key NULLS FIRST"
        ),
        pl.Series("list", [4, 5, 6, 1, 2, 3]).to_frame(),
    )

    # Literals are broadcast to the output height of UNNEST:
    assert_frame_equal(
        df.sql("SELECT UNNEST(list_long) as list, 1 as x FROM self ORDER BY sort_key"),
        pl.select(pl.Series("list", [4, 5, 6, 1, 2, 3]), x=1),
    )

    # Note: Filter applies before projections in SQL
    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask ORDER BY sort_key"
        ),
        pl.Series("list", [4, 5, 6]).to_frame(),
    )

    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask_all_true ORDER BY sort_key"
        ),
        pl.Series("list", [4, 5, 6, 1, 2, 3]).to_frame(),
    )


@pytest.mark.parametrize(
    ("query", "result"),
    [
        (
            """SELECT a, COUNT(*) OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 3, 3, 3],
        ),
        (
            """SELECT a, COUNT() OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 3, 3, 3],
        ),
        (
            """SELECT a, COUNT(i) OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 1, 1, 1],
        ),
        (
            """SELECT a, COUNT(DISTINCT i) OVER (PARTITION BY a) AS b FROM self""",
            [2, 2, 2, 1, 1, 1, 1],
        ),
    ],
)
def test_count_partition_22665(query: str, result: list[Any]) -> None:
    # COUNT variants used as window functions over `PARTITION BY a`
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 2, 3, 3, 3],
            "i": [0, 0, 1, 2, 3, None, None],
        }
    )
    out = df.sql(query).select("b")
    expected = pl.DataFrame({"b": result}).cast({"b": pl.get_index_type()})
    assert_frame_equal(out, expected)


@pytest.mark.parametrize(
    "query",
    [
        # ClickHouse-specific PREWHERE clause
        "SELECT x, y FROM df PREWHERE z IS NOT NULL",
        # LATERAL VIEW syntax
        "SELECT * FROM person LATERAL VIEW EXPLODE(ARRAY(0,125)) tableName AS age",
        # Oracle-style hierarchical queries
        """
        SELECT employee_id, employee_name, manager_id, LEVEL AS hierarchy_level
        FROM employees
        START WITH manager_id IS NULL
        CONNECT BY PRIOR employee_id = manager_id
        """,
    ],
)
def test_unsupported_select_clauses(query: str) -> None:
    # ensure we're actively catching unsupported clauses
    with (
        pl.SQLContext() as ctx,
        pytest.raises(
            SQLInterfaceError,
            match=r"not.*supported",
        ),
    ):
        ctx.execute(query)