# Path: py-polars/tests/unit/sql/test_table_operations.py
from __future__ import annotations12import re3from datetime import date4from typing import TYPE_CHECKING56import pytest78import polars as pl9from polars.exceptions import SQLInterfaceError10from polars.testing import assert_frame_equal1112if TYPE_CHECKING:13from pathlib import Path141516@pytest.fixture17def test_frame() -> pl.LazyFrame:18return pl.LazyFrame(19{20"x": [1, 2, 3],21"y": ["aaa", "bbb", "ccc"],22"z": [date(2000, 12, 31), date(1978, 11, 15), date(2077, 10, 20)],23},24schema_overrides={"x": pl.UInt8},25)262728def test_create_table() -> None:29with pl.SQLContext() as ctx:30# test all three ways of creating a new table31ctx.execute("CREATE TABLE tbl1(colx VARCHAR, coly DATE, colz ARRAY<DOUBLE>)")32ctx.execute("CREATE TABLE tbl2 AS SELECT * FROM tbl1")33ctx.execute("CREATE TABLE tbl3 LIKE tbl2")34df = ctx.execute("SELECT * FROM tbl3", eager=True)3536df_expected = pl.DataFrame(37schema={38"colx": pl.String,39"coly": pl.Date,40"colz": pl.List(pl.Float64),41}42)43assert_frame_equal(df_expected, df)444546def test_create_table_from_file_io(io_files_path: Path) -> None:47foods_csv = io_files_path / "foods*.csv"48with pl.SQLContext() as ctx:49ctx.execute(50query=f"""51CREATE TABLE foods AS52SELECT * FROM READ_CSV('{foods_csv}')53""",54eager=True,55)56df = ctx.execute("SELECT * FROM foods", eager=True)57assert df.schema == {58"category": pl.String,59"calories": pl.Int64,60"fats_g": pl.Float64,61"sugars_g": pl.Int64,62}63assert df.shape == (135, 4)646566@pytest.mark.parametrize(67("delete_constraint", "expected_ids"),68[69# basic constraints70("WHERE id = 200", {100, 300}),71("WHERE id = 200 OR id = 300", {100}),72("WHERE id IN (200, 300, 400)", {100}),73("WHERE id NOT IN (200, 300, 400)", {200, 300}),74# more involved constraints75("WHERE EXTRACT(year FROM dt) >= 2000", {200}),76# null-handling (in the data)77("WHERE v1 < 0", {100, 300}),78("WHERE v1 > 0", {200, 300}),79# null handling (in the constraint)80("WHERE v1 IS NULL", {100, 200}),81("WHERE v1 IS NOT NULL", 
{300}),82# boolean handling (delete all/none)83("WHERE FALSE", {100, 200, 300}),84("WHERE TRUE", set()),85# no constraint; equivalent to TRUNCATE (drop all rows)86("", set()),87],88)89def test_delete_clause(delete_constraint: str, expected_ids: set[int]) -> None:90df = pl.DataFrame(91{92"id": [100, 200, 300],93"dt": [date(2020, 10, 10), date(1999, 1, 2), date(2001, 7, 5)],94"v1": [3.5, -4.0, None],95"v2": [10.0, 2.5, -1.5],96}97)98res = df.sql(f"DELETE FROM self {delete_constraint}")99assert set(res["id"]) == expected_ids100101102def test_drop_table(test_frame: pl.LazyFrame) -> None:103# 'drop' completely removes the table from sql context104expected = pl.DataFrame()105106with pl.SQLContext(frame=test_frame, eager=True) as ctx:107res = ctx.execute("DROP TABLE frame")108assert_frame_equal(res, expected)109110with pytest.raises(SQLInterfaceError, match="'frame' was not found"):111ctx.execute("SELECT * FROM frame")112113114def test_explain_query(test_frame: pl.LazyFrame) -> None:115# 'explain' returns the query plan for the given sql116with pl.SQLContext(frame=test_frame) as ctx:117plan = (118ctx.execute("EXPLAIN SELECT * FROM frame")119.select(pl.col("Logical Plan").str.join())120.collect()121.item()122)123assert (124re.search(125pattern=r"PROJECT.+?COLUMNS",126string=plan,127flags=re.IGNORECASE,128)129is not None130)131132133def test_show_tables(test_frame: pl.LazyFrame) -> None:134# 'show tables' lists all tables registered with the sql context in sorted order135with pl.SQLContext(136tbl3=test_frame,137tbl2=test_frame,138tbl1=test_frame,139) as ctx:140res = ctx.execute("SHOW TABLES").collect()141assert_frame_equal(res, pl.DataFrame({"name": ["tbl1", "tbl2", "tbl3"]}))142143144@pytest.mark.parametrize(145"truncate_sql",146[147"TRUNCATE TABLE frame",148"TRUNCATE frame",149],150)151def test_truncate_table(truncate_sql: str, test_frame: pl.LazyFrame) -> None:152# 'truncate' preserves the table, but optimally drops all rows within it153expected = 
pl.DataFrame(schema=test_frame.collect_schema())154155with pl.SQLContext(frame=test_frame, eager=True) as ctx:156res = ctx.execute(truncate_sql)157assert_frame_equal(res, expected)158159res = ctx.execute("SELECT * FROM frame")160assert_frame_equal(res, expected)161162163