# Path: blob/main/py-polars/tests/unit/sql/test_miscellaneous.py
# (scraped page metadata: 6939 views)
"""Miscellaneous tests for the polars SQL interface (``pl.sql``, ``.sql``, ``SQLContext``)."""

from __future__ import annotations

from datetime import date
from pathlib import Path
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder

if TYPE_CHECKING:
    from polars.datatypes import DataType


@pytest.fixture
def foods_ipc_path() -> Path:
    """Return the path of ``foods1.ipc`` in the sibling ``io/files`` test-data dir."""
    # NOTE(review): no test in this module requests this fixture; candidate for removal.
    return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc"


def test_any_all() -> None:
    """Compare a column against ``ALL(...)`` / ``ANY(...)`` of another frame's column."""
    df = pl.DataFrame(  # noqa: F841
        {
            "x": [-1, 0, 1, 2, 3, 4],
            "y": [1, 0, 0, 1, 2, 3],
        }
    )
    res = pl.sql(
        """
        SELECT
          x >= ALL(df.y) AS "All Geq",
          x > ALL(df.y) AS "All G",
          x < ALL(df.y) AS "All L",
          x <= ALL(df.y) AS "All Leq",
          x >= ANY(df.y) AS "Any Geq",
          x > ANY(df.y) AS "Any G",
          x < ANY(df.y) AS "Any L",
          x <= ANY(df.y) AS "Any Leq",
          x == ANY(df.y) AS "Any eq",
          x != ANY(df.y) AS "Any Neq",
        FROM df
        """,
    ).collect()

    assert res.to_dict(as_series=False) == {
        "All Geq": [0, 0, 0, 0, 1, 1],
        "All G": [0, 0, 0, 0, 0, 1],
        "All L": [1, 0, 0, 0, 0, 0],
        "All Leq": [1, 1, 0, 0, 0, 0],
        "Any Geq": [0, 1, 1, 1, 1, 1],
        "Any G": [0, 0, 1, 1, 1, 1],
        "Any L": [1, 1, 1, 1, 0, 0],
        "Any Leq": [1, 1, 1, 1, 1, 0],
        "Any eq": [0, 1, 1, 1, 1, 0],
        # NOTE: this pins current polars behaviour, where `x != ANY(y)` acts like
        # `x NOT IN y`; under standard SQL semantics it would be true for every
        # row here, since `y` always contains some value different from `x`.
        "Any Neq": [1, 0, 0, 0, 0, 1],
    }


@pytest.mark.parametrize(
    ("data", "schema"),
    [
        ({"x": [1, 2, 3, 4]}, None),
        ({"x": [9, 8, 7, 6]}, {"x": pl.Int8}),
        ({"x": ["aa", "bb"]}, {"x": pl.Struct}),
        ({"x": [None, None], "y": [None, None]}, {"x": pl.Date, "y": pl.Float64}),
    ],
)
def test_boolean_where_clauses(
    data: dict[str, Any], schema: dict[str, DataType] | None
) -> None:
    """Always-true WHERE clauses keep every row; always-false ones match ``df.clear()``."""
    df = pl.DataFrame(data=data, schema=schema)
    empty_df = df.clear()

    for true in ("TRUE", "1=1", "2 == 2", "'xx' = 'xx'", "TRUE AND 1=1"):
        assert_frame_equal(df, df.sql(f"SELECT * FROM self WHERE {true}"))

    for false in ("false", "1!=1", "2 != 2", "'xx' != 'xx'", "FALSE OR 1!=1"):
        assert_frame_equal(empty_df, df.sql(f"SELECT * FROM self WHERE {false}"))


def test_count() -> None:
    """Check ``COUNT`` / ``COUNT(DISTINCT ...)`` over columns, ``*`` and ``NULL``."""
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1, 1, 22, 22, 333],
            "c": [1, 1, None, None, 2],
        }
    )
    res = df.sql(
        """
        SELECT
          -- count
          COUNT(a) AS count_a,
          COUNT(b) AS count_b,
          COUNT(c) AS count_c,
          COUNT(*) AS count_star,
          COUNT(NULL) AS count_null,
          -- count distinct
          COUNT(DISTINCT a) AS count_unique_a,
          COUNT(DISTINCT b) AS count_unique_b,
          COUNT(DISTINCT c) AS count_unique_c,
          COUNT(DISTINCT NULL) AS count_unique_null,
        FROM self
        """,
    )
    assert res.to_dict(as_series=False) == {
        "count_a": [5],
        "count_b": [5],
        "count_c": [3],
        "count_star": [5],
        "count_null": [0],
        "count_unique_a": [5],
        "count_unique_b": [3],
        "count_unique_c": [2],
        "count_unique_null": [0],
    }

    # an all-null column: COUNT(*) still counts rows, COUNT(x) counts non-nulls
    df = pl.DataFrame({"x": [None, None, None]})
    res = df.sql(
        """
        SELECT
          COUNT(x) AS count_x,
          COUNT(*) AS count_star,
          COUNT(DISTINCT x) AS count_unique_x
        FROM self
        """
    )
    assert res.to_dict(as_series=False) == {
        "count_x": [0],
        "count_star": [3],
        "count_unique_x": [0],
    }


def test_distinct() -> None:
    """Check ``SELECT DISTINCT`` (plain and on computed columns) and unregistration."""
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 2, 2, 3],
            "b": [1, 2, 3, 4, 5, 6],
        }
    )
    ctx = pl.SQLContext(register_globals=True, eager=True)
    res1 = ctx.execute("SELECT DISTINCT a FROM df ORDER BY a DESC")
    assert_frame_equal(
        left=df.select("a").unique().sort(by="a", descending=True),
        right=res1,
    )

    # the expected `half_b` values imply truncating division of the integer
    # column (only two distinct rows survive for `a == 1`, i.e. b in 1..3)
    res2 = ctx.execute(
        """
        SELECT DISTINCT
          a * 2 AS two_a,
          b / 2 AS half_b
        FROM df
        ORDER BY two_a ASC, half_b DESC
        """,
    )
    assert res2.to_dict(as_series=False) == {
        "two_a": [2, 2, 4, 6],
        "half_b": [1, 0, 2, 3],
    }

    # test unregistration
    ctx.unregister("df")
    with pytest.raises(SQLInterfaceError, match="relation 'df' was not found"):
        ctx.execute("SELECT * FROM df")


def test_frame_sql_globals_error() -> None:
    """``DataFrame.sql`` does not resolve other frames by name; ``pl.sql`` does."""
    df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df2 = pl.DataFrame({"a": [2, 3, 4], "b": [7, 6, 5]})  # noqa: F841

    query = """
        SELECT df1.a, df2.b
        FROM df2 JOIN df1 ON df1.a = df2.a
        ORDER BY b DESC
    """
    with pytest.raises(SQLInterfaceError, match="relation.*not found.*"):
        df1.sql(query=query)

    res = pl.sql(query=query, eager=True)
    assert res.to_dict(as_series=False) == {"a": [2, 3], "b": [7, 6]}


def test_in_no_ops_11946() -> None:
    """``IN (...)`` filter on a LazyFrame registered under a custom ``table_name``."""
    lf = pl.LazyFrame(
        [
            {"i1": 1},
            {"i1": 2},
            {"i1": 3},
        ]
    )
    out = lf.sql(
        query="SELECT * FROM frame_data WHERE i1 in (1, 3)",
        table_name="frame_data",
    ).collect()
    assert out.to_dict(as_series=False) == {"i1": [1, 3]}


def test_limit_offset() -> None:
    """``LIMIT``/``OFFSET`` match ``LazyFrame.slice``, including offsets at the end."""
    n_values = 11
    lf = pl.LazyFrame({"a": range(n_values), "b": reversed(range(n_values))})
    ctx = pl.SQLContext(tbl=lf)

    assert ctx.execute("SELECT * FROM tbl LIMIT 3 OFFSET 4", eager=True).rows() == [
        (4, 6),
        (5, 5),
        (6, 4),
    ]
    for offset, limit in [(0, 3), (1, n_values), (2, 3), (5, 3), (8, 5), (n_values, 1)]:
        out = ctx.execute(
            f"SELECT * FROM tbl LIMIT {limit} OFFSET {offset}", eager=True
        )
        assert_frame_equal(out, lf.slice(offset, limit).collect())
        assert len(out) == min(limit, n_values - offset)


def test_register_context() -> None:
    """Nested ``SQLContext`` scopes unregister the tables registered within them."""
    # use as context manager unregisters tables created within each scope
    # on exit from that scope; arbitrary levels of nesting are supported.
    with pl.SQLContext() as ctx:
        _lf1 = pl.LazyFrame({"a": [1, 2, 3], "b": ["m", "n", "o"]})
        _lf2 = pl.LazyFrame({"a": [2, 3, 4], "c": ["p", "q", "r"]})
        ctx.register_globals()
        assert ctx.tables() == ["_lf1", "_lf2"]

        with ctx:
            _lf3 = pl.LazyFrame({"a": [3, 4, 5], "b": ["s", "t", "u"]})
            _lf4 = pl.LazyFrame({"a": [4, 5, 6], "c": ["v", "w", "x"]})
            ctx.register_globals(n=2)
            assert ctx.tables() == ["_lf1", "_lf2", "_lf3", "_lf4"]

        assert ctx.tables() == ["_lf1", "_lf2"]

    assert ctx.tables() == []


def test_sql_on_compatible_frame_types() -> None:
    """Query polars, pandas, pyarrow and PyCapsule-backed frames (and Series) via SQL."""
    df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # create various different frame types
    dfp = df.to_pandas()
    dfa = df.to_arrow()
    dfb = dfa.to_batches()[0]  # noqa: F841
    dfo = PyCapsuleStreamHolder(df)  # noqa: F841

    # run polars sql query against all frame types
    for dfs in (  # noqa: B007
        (df["a"] * 2).rename("c"),  # polars series
        (dfp["a"] * 2).rename("c"),  # pandas series
    ):
        res = pl.sql(
            """
            SELECT a, b, SUM(c) AS cc FROM (
              SELECT * FROM df               -- polars frame
                UNION ALL SELECT * FROM dfp  -- pandas frame
                UNION ALL SELECT * FROM dfa  -- pyarrow table
                UNION ALL SELECT * FROM dfb  -- pyarrow record batch
                UNION ALL SELECT * FROM dfo  -- arbitrary pycapsule object
            ) tbl
            INNER JOIN dfs ON dfs.c == tbl.b -- join on pandas/polars series
            GROUP BY "a", "b"
            ORDER BY "a", "b"
            """
        ).collect()

        # five unioned copies joined on c == b (c = 2*a): b=4 and b=6 match
        expected = pl.DataFrame({"a": [1, 3], "b": [4, 6], "cc": [20, 30]})
        assert_frame_equal(left=expected, right=res)

    # register and operate on non-polars frames
    for obj in (dfa, dfp):
        with pl.SQLContext(obj=obj) as ctx:
            res = ctx.execute("SELECT * FROM obj", eager=True)
            assert_frame_equal(df, res)

    # don't register all compatible objects
    with pytest.raises(SQLInterfaceError, match="relation 'dfp' was not found"):
        pl.SQLContext(register_globals=True).execute("SELECT * FROM dfp")


def test_nested_cte_column_aliasing() -> None:
    """Resolve column aliases through nested CTEs and derived tables."""
    # trace through nested CTEs with multiple levels of column & table aliasing
    df = pl.sql(
        """
        WITH
          x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)),
          y (m, n) AS (
            WITH z(c, d) AS (SELECT a, b FROM x)
              SELECT d*2 AS d2, c*3 AS c3 FROM z
          )
        SELECT n, m FROM y
        """,
        eager=True,
    )
    assert df.to_dict(as_series=False) == {
        "n": [3, 9],
        "m": [4, 8],
    }


def test_invalid_derived_table_column_aliases() -> None:
    """A derived-table alias with the wrong column count raises; a bare alias works."""
    values_query = "SELECT * FROM (VALUES (1,2), (3,4))"

    with pytest.raises(
        SQLSyntaxError,
        match=r"columns \(5\) in alias 'tbl' does not match .* the table/query \(2\)",
    ):
        pl.sql(f"{values_query} AS tbl(a, b, c, d, e)")

    assert pl.sql(f"{values_query} tbl", eager=True).rows() == [(1, 2), (3, 4)]


def test_values_clause_table_registration() -> None:
    """An aliased ``VALUES`` derived table is registered in the context post-query."""
    with pl.SQLContext(frames=None, eager=True) as ctx:
        # initially no tables are registered
        assert ctx.tables() == []

        # confirm that VALUES clause derived table is registered, post-query
        res1 = ctx.execute("SELECT * FROM (VALUES (-1,1)) AS tbl(x, y)")
        assert ctx.tables() == ["tbl"]

        # and confirm that we can select from it by the registered name
        res2 = ctx.execute("SELECT x, y FROM tbl")
        for res in (res1, res2):
            assert res.to_dict(as_series=False) == {"x": [-1], "y": [1]}


def test_read_csv(tmp_path: Path) -> None:
    """``read_csv`` round-trips a written CSV and rejects multiple path arguments."""
    # check empty string vs null, parsing of dates, etc
    df = pl.DataFrame(
        {
            "label": ["lorem", None, "", "ipsum"],
            "num": [-1, None, 0, 1],
            "dt": [
                date(1969, 7, 5),
                date(1999, 12, 31),
                date(2077, 10, 10),
                None,
            ],
        }
    )
    csv_target = tmp_path / "test_sql_read.csv"
    df.write_csv(csv_target)

    res = pl.sql(f"SELECT * FROM read_csv('{csv_target}')").collect()
    assert_frame_equal(df, res)

    with pytest.raises(
        SQLSyntaxError,
        match="`read_csv` expects a single file path; found 3 arguments",
    ):
        pl.sql("SELECT * FROM read_csv('a','b','c')")


def test_global_variable_inference_17398() -> None:
    """A frame referenced only inside a CTE is still resolved from the calling scope."""
    users = pl.DataFrame({"id": "1"})

    res = pl.sql(
        query="""
            WITH user_by_email AS (SELECT id FROM users)
            SELECT * FROM user_by_email
        """,
        eager=True,
    )
    assert_frame_equal(res, users)


@pytest.mark.parametrize(
    "query",
    [
        "SELECT invalid_column FROM self",
        "SELECT key, invalid_column FROM self",
        "SELECT invalid_column * 2 FROM self",
        "SELECT * FROM self ORDER BY invalid_column",
        "SELECT * FROM self WHERE invalid_column = 200",
        "SELECT * FROM self WHERE invalid_column = '200'",
        "SELECT key, SUM(n) AS sum_n FROM self GROUP BY invalid_column",
    ],
)
def test_invalid_cols(query: str) -> None:
    """Referencing a missing column in any clause raises ``ColumnNotFoundError``."""
    df = pl.DataFrame(
        {
            "key": ["xx", "xx", "yy"],
            "n": ["100", "200", "300"],
        }
    )
    with pytest.raises(ColumnNotFoundError, match="invalid_column"):
        df.sql(query)


@pytest.mark.parametrize("filter_expr", ["", "WHERE 1 = 1", "WHERE a == 1 OR a != 1"])
@pytest.mark.parametrize("order_expr", ["", "ORDER BY 1", "ORDER BY a"])
def test_select_output_heights_20058_21084(filter_expr: str, order_expr: str) -> None:
    """Literal-only projections keep the frame height; aggregates yield one row."""
    df = pl.DataFrame({"a": [1, 2, 3]})

    # Queries that maintain original height

    assert_frame_equal(
        df.sql(f"SELECT 1 as a FROM self {filter_expr} {order_expr}").cast(pl.Int64),
        pl.select(a=pl.Series([1, 1, 1])),
    )

    assert_frame_equal(
        df.sql(f"SELECT 1 + 1 as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
            pl.Int64
        ),
        pl.DataFrame({"a": [2, 2, 2], "b": [1, 1, 1]}),
    )

    # Queries that aggregate to unit height

    assert_frame_equal(
        df.sql(f"SELECT COUNT(*) as a FROM self {filter_expr} {order_expr}").cast(
            pl.Int64
        ),
        pl.DataFrame({"a": 3}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT COUNT(*) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 3, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(a) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 1, "b": 1}),
    )

    assert_frame_equal(
        df.sql(f"SELECT SUM(a) as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
            pl.Int64
        ),
        pl.DataFrame({"a": 6, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(1) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 1, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(1) + 1 as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 2, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(1 + 1) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 2, "b": 1}),
    )


def test_select_explode_height_filter_order_by() -> None:
    """``UNNEST`` output height combined with ``ORDER BY``, ``WHERE`` and literals."""
    # Note: `unnest()` from SQL equates to `Expr.explode()`
    df = pl.DataFrame(
        {
            "list_long": [[1, 2, 3], [4, 5, 6]],
            "sort_key": [2, 1],
            "filter_mask": [False, True],
            "filter_mask_all_true": True,
        }
    )

    # Height of unnest is larger than height of sort_key, the sort_key is
    # extended with NULLs.

    assert_frame_equal(
        df.sql("SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key"),
        pl.Series("list", [2, 1, 3, 4, 5, 6]).to_frame(),
    )

    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key NULLS FIRST"
        ),
        pl.Series("list", [3, 4, 5, 6, 2, 1]).to_frame(),
    )

    # Literals are broadcasted to output height of UNNEST:
    assert_frame_equal(
        df.sql("SELECT UNNEST(list_long) as list, 1 as x FROM self ORDER BY sort_key"),
        pl.select(pl.Series("list", [2, 1, 3, 4, 5, 6]), x=1),
    )

    # Note: Filter applies before projections in SQL
    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask ORDER BY sort_key"
        ),
        pl.Series("list", [4, 5, 6]).to_frame(),
    )

    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask_all_true ORDER BY sort_key"
        ),
        pl.Series("list", [2, 1, 3, 4, 5, 6]).to_frame(),
    )


@pytest.mark.parametrize(
    ("query", "result"),
    [
        (
            """SELECT a, COUNT(*) OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 3, 3, 3],
        ),
        (
            """SELECT a, COUNT() OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 3, 3, 3],
        ),
        (
            """SELECT a, COUNT(i) OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 1, 1, 1],
        ),
        (
            """SELECT a, COUNT(DISTINCT i) OVER (PARTITION BY a) AS b FROM self""",
            [2, 2, 2, 1, 1, 1, 1],
        ),
    ],
)
def test_count_partition_22665(query: str, result: list[Any]) -> None:
    """``COUNT`` variants used as window functions over ``PARTITION BY``."""
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 2, 3, 3, 3],
            "i": [0, 0, 1, 2, 3, None, None],
        }
    )
    out = df.sql(query).select("b")
    expected = pl.DataFrame({"b": result}).cast({"b": pl.UInt32})
    assert_frame_equal(out, expected)