Path: blob/main/py-polars/tests/unit/io/database/test_read.py
6939 views
from __future__ import annotations12import os3import sqlite34import sys5from contextlib import suppress6from datetime import date7from pathlib import Path8from types import GeneratorType9from typing import TYPE_CHECKING, Any, NamedTuple, cast10from unittest.mock import Mock, patch1112with suppress(ModuleNotFoundError): # not available on windows13import adbc_driver_sqlite.dbapi14import pyarrow as pa15import pytest16import sqlalchemy17from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text18from sqlalchemy.orm import sessionmaker19from sqlalchemy.sql.expression import cast as alchemy_cast2021import polars as pl22from polars._utils.various import parse_version23from polars.exceptions import DuplicateError, UnsuitableSQLError24from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY25from polars.testing import assert_frame_equal, assert_series_equal2627if TYPE_CHECKING:28from polars._typing import (29ConnectionOrCursor,30DbReadEngine,31SchemaDefinition,32SchemaDict,33)343536def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any:37args = tuple(str(a) if isinstance(a, Path) else a for a in args)38return adbc_driver_sqlite.dbapi.connect(*args, **kwargs)394041class MockConnection:42"""Mock connection class for databases we can't test in CI."""4344def __init__(45self,46driver: str,47batch_size: int | None,48exact_batch_size: bool,49test_data: pa.Table,50repeat_batch_calls: bool,51) -> None:52self.__class__.__module__ = driver53self._cursor = MockCursor(54repeat_batch_calls=repeat_batch_calls,55exact_batch_size=exact_batch_size,56batched=(batch_size is not None),57test_data=test_data,58)5960def close(self) -> None:61pass6263def cursor(self) -> Any:64return self._cursor656667class MockCursor:68"""Mock cursor class for databases we can't test in CI."""6970def __init__(71self,72batched: bool,73exact_batch_size: bool,74test_data: pa.Table,75repeat_batch_calls: bool,76) -> None:77self.resultset = 
MockResultSet(78test_data=test_data,79batched=batched,80exact_batch_size=exact_batch_size,81repeat_batch_calls=repeat_batch_calls,82)83self.exact_batch_size = exact_batch_size84self.called: list[str] = []85self.batched = batched86self.n_calls = 18788def __getattr__(self, name: str) -> Any:89if "fetch" in name:90self.called.append(name)91return self.resultset92super().__getattr__(name) # type: ignore[misc]9394def close(self) -> Any:95pass9697def execute(self, query: str) -> Any:98return self99100101class MockResultSet:102"""Mock resultset class for databases we can't test in CI."""103104def __init__(105self,106test_data: pa.Table,107batched: bool,108exact_batch_size: bool,109repeat_batch_calls: bool = False,110) -> None:111self.test_data = test_data112self.repeat_batched_calls = repeat_batch_calls113self.exact_batch_size = exact_batch_size114self.batched = batched115self.n_calls = 1116117def __call__(self, *args: Any, **kwargs: Any) -> Any:118if not self.exact_batch_size:119assert len(args) == 0120if self.repeat_batched_calls:121res = self.test_data[: None if self.n_calls else 0]122self.n_calls -= 1123else:124res = iter((self.test_data,))125return res126127128class DatabaseReadTestParams(NamedTuple):129"""Clarify read test params."""130131read_method: str132connect_using: Any133expected_dtypes: SchemaDefinition134expected_dates: list[date | str]135schema_overrides: SchemaDict | None = None136batch_size: int | None = None137138139class ExceptionTestParams(NamedTuple):140"""Clarify exception test params."""141142read_method: str143query: str | list[str]144protocol: Any145errclass: type[Exception]146errmsg: str147engine: str | None = None148execute_options: dict[str, Any] | None = None149pre_execution_query: str | list[str] | None = None150kwargs: dict[str, Any] | None = 
@pytest.mark.write_disk
@pytest.mark.parametrize(
    (
        "read_method",
        "connect_using",
        "expected_dtypes",
        "expected_dates",
        "schema_overrides",
        "batch_size",
    ),
    [
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database_uri",
                connect_using="connectorx",
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.UInt8},
            ),
            id="uri: connectorx",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database_uri",
                connect_using="adbc",
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                schema_overrides={"id": pl.UInt8},
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="uri: adbc",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.UInt8, "value": pl.Float32},
            ),
            id="conn: sqlite3",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.Int32,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.Int32, "value": pl.Float32},
                batch_size=1,
            ),
            id="conn: sqlite3",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: create_engine(
                    f"sqlite:///{path}",
                    connect_args={"detect_types": sqlite3.PARSE_DECLTYPES},
                ).connect(),
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
            ),
            id="conn: sqlalchemy",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (fetchall)",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                batch_size=1,
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (batched)",
        ),
    ],
)
def test_read_database(
    read_method: str,
    connect_using: Any,
    expected_dtypes: dict[str, pl.DataType],
    expected_dates: list[date | str],
    schema_overrides: SchemaDict | None,
    batch_size: int | None,
    tmp_sqlite_db: Path,
) -> None:
    if read_method == "read_database_uri":
        connect_using = cast("DbReadEngine", connect_using)
        # instantiate the connection ourselves, using connectorx/adbc
        df = pl.read_database_uri(
            uri=f"sqlite:///{tmp_sqlite_db}",
            query="SELECT * FROM test_data",
            engine=connect_using,
            schema_overrides=schema_overrides,
        )
        df_empty = pl.read_database_uri(
            uri=f"sqlite:///{tmp_sqlite_db}",
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            engine=connect_using,
            schema_overrides=schema_overrides,
        )
    elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
        # externally instantiated adbc connections
        with connect_using(tmp_sqlite_db) as conn:
            df = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data",
                schema_overrides=schema_overrides,
                batch_size=batch_size,
            )
            df_empty = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
                schema_overrides=schema_overrides,
                batch_size=batch_size,
            )
    else:
        # other user-supplied connections
        df = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'",
            schema_overrides=schema_overrides,
            batch_size=batch_size,
        )
        df_empty = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            schema_overrides=schema_overrides,
            batch_size=batch_size,
        )

    # validate the expected query return (data and schema)
    assert df.schema == expected_dtypes
    assert df.shape == (2, 4)
    assert df["date"].to_list() == expected_dates

    # note: 'cursor.description' is not reliable when no query
    # data is returned, so no point comparing expected dtypes
    assert df_empty.columns == ["id", "name", "value", "date"]
    assert df_empty.shape == (0, 4)
    assert df_empty["date"].to_list() == []
def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None:
    # various flavours of alchemy connection
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

    t = Table("test_data", MetaData(), autoload_with=alchemy_engine)

    # establish sqlalchemy "selectable" and validate usage
    selectable_query = select(
        alchemy_cast(func.strftime("%Y", t.c.date), Integer).label("year"),
        t.c.name,
        t.c.value,
    ).where(t.c.value < 0)

    expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn):
        assert_frame_equal(
            pl.read_database(selectable_query, connection=conn),
            expected,
        )

        batches = list(
            pl.read_database(
                selectable_query,
                connection=conn,
                iter_batches=True,
                batch_size=1,
            )
        )
        assert len(batches) == 1
        assert_frame_equal(batches[0], expected)
def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None:
    # various flavours of alchemy connection
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

    # establish sqlalchemy "textclause" and validate usage
    textclause_query = text(
        """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < 0
        """
    )

    expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn):
        assert_frame_equal(
            pl.read_database(textclause_query, connection=conn),
            expected,
        )

        batches = list(
            pl.read_database(
                textclause_query,
                connection=conn,
                iter_batches=True,
                batch_size=1,
            )
        )
        assert len(batches) == 1
        assert_frame_equal(batches[0], expected)


@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        (":n", {"n": 0}),
        ("?", (0,)),
        ("?", [0]),
    ],
)
def test_read_database_parameterised(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)

    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn):
        if conn is alchemy_session and param == "?":
            continue  # alchemy session.execute() doesn't support positional params
        if parse_version(sqlalchemy.__version__) < (2, 0) and param == ":n":
            continue  # skip for older sqlalchemy versions

        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(n=param),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )
@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        pytest.param(
            ":n",
            pa.Table.from_pydict({"n": [0]}),
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        pytest.param(
            ":n",
            {"n": 0},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262",
            ),
        ),
        ("?", pa.Table.from_pydict({"data": [0]})),
        ("?", pl.DataFrame({"data": [0]})),
        ("?", pl.Series([{"data": 0}])),
        ("?", (0,)),
        ("?", [0]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_parameterised_adbc(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    # ADBC will complain in pytest if the connection isn't closed
    with adbc_driver_sqlite.dbapi.connect(str(tmp_sqlite_db)) as conn:
        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(n=param),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )
@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        ([":lo", ":hi"], {"lo": 90, "hi": 100}),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
def test_read_database_parameterised_multiple(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    param_1, param_2 = params
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)
    for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn):
        if alchemy_session is conn and param_1 == "?":
            continue  # alchemy session.execute() doesn't support positional params
        if parse_version(sqlalchemy.__version__) < (2, 0) and isinstance(
            param_value, dict
        ):
            continue  # skip for older sqlalchemy versions

        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(param_1=param_1, param_2=param_2),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )


@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        pytest.param(
            [":lo", ":hi"],
            {"lo": 90, "hi": 100},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        (["?", "?"], pa.Table.from_pydict({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.DataFrame({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.Series([{"data_1": 90, "data_2": 100}])),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_parameterised_multiple_adbc(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    param_1, param_2 = params
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # ADBC will complain in pytest if the connection isn't closed
    with adbc_driver_sqlite.dbapi.connect(str(tmp_sqlite_db)) as conn:
        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(param_1=param_1, param_2=param_2),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )
@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        pytest.param(
            ":n",
            pa.Table.from_pydict({"n": [0]}),
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        pytest.param(
            ":n",
            {"n": 0},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262",
            ),
        ),
        ("?", pa.Table.from_pydict({"data": [0]})),
        ("?", pl.DataFrame({"data": [0]})),
        ("?", pl.Series([{"data": 0}])),
        ("?", (0,)),
        ("?", [0]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_uri_parameterised(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    uri = alchemy_engine.url.render_as_string(hide_password=False)
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    # test URI read method (adbc only)
    assert_frame_equal(
        expected_frame,
        pl.read_database_uri(
            query.format(n=param),
            uri=uri,
            engine="adbc",
            execute_options={"parameters": param_value},
        ),
    )

    # no connectorx support for execute_options
    with pytest.raises(
        ValueError,
        match="connectorx.*does not support.*execute_options",
    ):
        pl.read_database_uri(
            query.format(n=":n"),
            uri=uri,
            engine="connectorx",
            execute_options={"parameters": (":n", {"n": 0})},
        )
@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        pytest.param(
            [":lo", ":hi"],
            {"lo": 90, "hi": 100},
            marks=pytest.mark.xfail(
                reason="Named binding not supported. See https://github.com/apache/arrow-adbc/issues/3262",
                strict=True,
            ),
        ),
        (["?", "?"], pa.Table.from_pydict({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.DataFrame({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.Series([{"data_1": 90, "data_2": 100}])),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_uri_parameterised_multiple(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    param_1, param_2 = params
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    uri = alchemy_engine.url.render_as_string(hide_password=False)
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # test URI read method (ADBC only)
    assert_frame_equal(
        expected_frame,
        pl.read_database_uri(
            query.format(param_1=param_1, param_2=param_2),
            uri=uri,
            engine="adbc",
            execute_options={"parameters": param_value},
        ),
    )

    # no connectorx support for execute_options
    with pytest.raises(
        ValueError,
        match="connectorx.*does not support.*execute_options",
    ):
        pl.read_database_uri(
            query.format(param_1="?", param_2="?"),
            uri=uri,
            engine="connectorx",
            execute_options={"parameters": (90, 100)},
        )
@pytest.mark.parametrize(
    ("driver", "batch_size", "iter_batches", "expected_call"),
    [
        ("snowflake", None, False, "fetch_arrow_all"),
        ("snowflake", 10_000, False, "fetch_arrow_all"),
        ("snowflake", 10_000, True, "fetch_arrow_batches"),
        ("databricks", None, False, "fetchall_arrow"),
        ("databricks", 25_000, False, "fetchall_arrow"),
        ("databricks", 25_000, True, "fetchmany_arrow"),
        ("turbodbc", None, False, "fetchallarrow"),
        ("turbodbc", 50_000, False, "fetchallarrow"),
        ("turbodbc", 50_000, True, "fetcharrowbatches"),
        ("adbc_driver_postgresql", None, False, "fetch_arrow_table"),
        ("adbc_driver_postgresql", 75_000, False, "fetch_arrow_table"),
        ("adbc_driver_postgresql", 75_000, True, "fetch_arrow_table"),
    ],
)
def test_read_database_mocked(
    driver: str, batch_size: int | None, iter_batches: bool, expected_call: str
) -> None:
    # since we don't have access to snowflake/databricks/etc from CI we
    # mock them so we can check that we're calling the expected methods
    arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()

    reg = ARROW_DRIVER_REGISTRY.get(driver, {})  # type: ignore[var-annotated]
    exact_batch_size = reg.get("exact_batch_size", False)
    repeat_batch_calls = reg.get("repeat_batch_calls", False)

    mc = MockConnection(
        driver,
        batch_size,
        test_data=arrow,
        repeat_batch_calls=repeat_batch_calls,
        exact_batch_size=exact_batch_size,  # type: ignore[arg-type]
    )
    res = pl.read_database(
        query="SELECT * FROM test_data",
        connection=mc,
        iter_batches=iter_batches,
        batch_size=batch_size,
    )
    if iter_batches:
        assert isinstance(res, GeneratorType)
        res = pl.concat(res)

    res = cast(pl.DataFrame, res)
    assert expected_call in mc.cursor().called
    assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")]
"fetchallarrow"),731("turbodbc", 50_000, True, "fetcharrowbatches"),732("adbc_driver_postgresql", None, False, "fetch_arrow_table"),733("adbc_driver_postgresql", 75_000, False, "fetch_arrow_table"),734("adbc_driver_postgresql", 75_000, True, "fetch_arrow_table"),735],736)737def test_read_database_mocked(738driver: str, batch_size: int | None, iter_batches: bool, expected_call: str739) -> None:740# since we don't have access to snowflake/databricks/etc from CI we741# mock them so we can check that we're calling the expected methods742arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()743744reg = ARROW_DRIVER_REGISTRY.get(driver, {}) # type: ignore[var-annotated]745exact_batch_size = reg.get("exact_batch_size", False)746repeat_batch_calls = reg.get("repeat_batch_calls", False)747748mc = MockConnection(749driver,750batch_size,751test_data=arrow,752repeat_batch_calls=repeat_batch_calls,753exact_batch_size=exact_batch_size, # type: ignore[arg-type]754)755res = pl.read_database(756query="SELECT * FROM test_data",757connection=mc,758iter_batches=iter_batches,759batch_size=batch_size,760)761if iter_batches:762assert isinstance(res, GeneratorType)763res = pl.concat(res)764765res = cast(pl.DataFrame, res)766assert expected_call in mc.cursor().called767assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")]768769770@pytest.mark.parametrize(771(772"read_method",773"query",774"protocol",775"errclass",776"errmsg",777"engine",778"execute_options",779"pre_execution_query",780"kwargs",781),782[783pytest.param(784*ExceptionTestParams(785read_method="read_database_uri",786query="SELECT * FROM test_data",787protocol="sqlite",788errclass=ValueError,789errmsg="engine must be one of {'connectorx', 'adbc'}, got 'not_an_engine'",790engine="not_an_engine",791),792id="Not an available sql engine",793),794pytest.param(795*ExceptionTestParams(796read_method="read_database_uri",797query=["SELECT * FROM test_data", "SELECT * FROM 
test_data"],798protocol="sqlite",799errclass=ValueError,800errmsg="only a single SQL query string is accepted for adbc",801engine="adbc",802),803id="Unavailable list of queries for adbc",804),805pytest.param(806*ExceptionTestParams(807read_method="read_database_uri",808query="SELECT * FROM test_data",809protocol="mysql",810errclass=ModuleNotFoundError,811errmsg="ADBC 'adbc_driver_mysql' driver not detected.",812engine="adbc",813),814id="Unavailable adbc driver",815),816pytest.param(817*ExceptionTestParams(818read_method="read_database_uri",819query="SELECT * FROM test_data",820protocol=sqlite3.connect(":memory:"),821errclass=TypeError,822errmsg="expected connection to be a URI string",823engine="adbc",824),825id="Invalid connection URI",826),827pytest.param(828*ExceptionTestParams(829read_method="read_database",830query="SELECT * FROM imaginary_table",831protocol=sqlite3.connect(":memory:"),832errclass=sqlite3.OperationalError,833errmsg="no such table: imaginary_table",834),835id="Invalid query (unrecognised table name)",836),837pytest.param(838*ExceptionTestParams(839read_method="read_database",840query="SELECT * FROM imaginary_table",841protocol=sys.getsizeof, # not a connection842errclass=TypeError,843errmsg="Unrecognised connection .* no 'execute' or 'cursor' method",844),845id="Invalid read DB kwargs",846),847pytest.param(848*ExceptionTestParams(849read_method="read_database",850query="/* tag: misc */ INSERT INTO xyz VALUES ('polars')",851protocol=sqlite3.connect(":memory:"),852errclass=UnsuitableSQLError,853errmsg="INSERT statements are not valid 'read' queries",854),855id="Invalid statement type",856),857pytest.param(858*ExceptionTestParams(859read_method="read_database",860query="DELETE FROM xyz WHERE id = 'polars'",861protocol=sqlite3.connect(":memory:"),862errclass=UnsuitableSQLError,863errmsg="DELETE statements are not valid 'read' queries",864),865id="Invalid statement 
type",866),867pytest.param(868*ExceptionTestParams(869read_method="read_database",870query="SELECT * FROM sqlite_master",871protocol=sqlite3.connect(":memory:"),872errclass=ValueError,873kwargs={"iter_batches": True},874errmsg="Cannot set `iter_batches` without also setting a non-zero `batch_size`",875),876id="Invalid batch_size",877),878pytest.param(879*ExceptionTestParams(880read_method="read_database",881engine="adbc",882query="SELECT * FROM test_data",883protocol=sqlite3.connect(":memory:"),884errclass=TypeError,885errmsg=r"unexpected keyword argument 'partition_on'",886kwargs={"partition_on": "id"},887),888id="Invalid kwargs",889),890pytest.param(891*ExceptionTestParams(892read_method="read_database",893engine="adbc",894query="SELECT * FROM test_data",895protocol="{not:a, valid:odbc_string}",896errclass=ValueError,897errmsg=r"unable to identify string connection as valid ODBC",898),899id="Invalid ODBC string",900),901pytest.param(902*ExceptionTestParams(903read_method="read_database_uri",904query="SELECT * FROM test_data",905protocol="sqlite",906errclass=ValueError,907errmsg="the 'adbc' engine does not support use of `pre_execution_query`",908engine="adbc",909pre_execution_query="SET statement_timeout = 2151",910),911id="Unavailable `pre_execution_query` for adbc",912),913],914)915def test_read_database_exceptions(916read_method: str,917query: str,918protocol: Any,919errclass: type[Exception],920errmsg: str,921engine: DbReadEngine | None,922execute_options: dict[str, Any] | None,923pre_execution_query: str | list[str] | None,924kwargs: dict[str, Any] | None,925) -> None:926if read_method == "read_database_uri":927conn = f"{protocol}://test" if isinstance(protocol, str) else protocol928params = {929"uri": conn,930"query": query,931"engine": engine,932"pre_execution_query": pre_execution_query,933}934else:935params = {"connection": protocol, "query": query}936if execute_options:937params["execute_options"] = execute_options938if kwargs is not 
None:939params.update(kwargs)940941read_database = getattr(pl, read_method)942with pytest.raises(errclass, match=errmsg):943read_database(**params)944945946@pytest.mark.parametrize(947"query",948[949"SELECT 1, 1 FROM test_data",950'SELECT 1 AS "n", 2 AS "n" FROM test_data',951'SELECT name, value AS "name" FROM test_data',952],953)954def test_read_database_duplicate_column_error(tmp_sqlite_db: Path, query: str) -> None:955alchemy_conn = create_engine(f"sqlite:///{tmp_sqlite_db}").connect()956with pytest.raises(957DuplicateError,958match="column .+ appears more than once in the query/result cursor",959):960pl.read_database(query, connection=alchemy_conn)961962963@pytest.mark.parametrize(964"uri",965[966"fakedb://123:456@account/database/schema?warehouse=warehouse&role=role",967"fakedb://my#%us3r:p433w0rd@not_a_real_host:9999/database",968],969)970def test_read_database_cx_credentials(uri: str) -> None:971with pytest.raises(RuntimeError, match=r"Source.*not supported"):972pl.read_database_uri("SELECT * FROM data", uri=uri, engine="connectorx")973974975@pytest.mark.skipif(976sys.platform == "win32",977reason="kuzu segfaults on windows: https://github.com/pola-rs/polars/actions/runs/12502055945/job/34880479875?pr=20462",978)979@pytest.mark.write_disk980def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None:981import kuzu982983tmp_path.mkdir(exist_ok=True)984if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists():985kuzu_test_db.unlink()986987test_db = str(kuzu_test_db).replace("\\", "/")988989db = kuzu.Database(test_db)990conn = kuzu.Connection(db)991conn.execute("CREATE NODE TABLE User(name STRING, age UINT64, PRIMARY KEY (name))")992conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")993994users = str(io_files_path / "graph-data" / "user.csv").replace("\\", "/")995follows = str(io_files_path / "graph-data" / "follows.csv").replace("\\", "/")996997conn.execute(f'COPY User FROM "{users}"')998conn.execute(f'COPY Follows FROM 
"{follows}"')9991000# basic: single relation1001df1 = pl.read_database(1002query="MATCH (u:User) RETURN u.name, u.age",1003connection=conn,1004)1005assert_frame_equal(1006df1,1007pl.DataFrame(1008{1009"u.name": ["Adam", "Karissa", "Zhang", "Noura"],1010"u.age": [30, 40, 50, 25],1011},1012schema={"u.name": pl.Utf8, "u.age": pl.UInt64},1013),1014)10151016# join: connected edges/relations1017df2 = pl.read_database(1018query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name",1019connection=conn,1020schema_overrides={"f.since": pl.Int16},1021)1022assert_frame_equal(1023df2,1024pl.DataFrame(1025{1026"a.name": ["Adam", "Adam", "Karissa", "Zhang"],1027"f.since": [2020, 2020, 2021, 2022],1028"b.name": ["Karissa", "Zhang", "Zhang", "Noura"],1029},1030schema={"a.name": pl.Utf8, "f.since": pl.Int16, "b.name": pl.Utf8},1031),1032)10331034# empty: no results for the given query1035df3 = pl.read_database(1036query="MATCH (a:User)-[f:Follows]->(b:User) WHERE a.name = '🔎️' RETURN a.name, f.since, b.name",1037connection=conn,1038)1039assert_frame_equal(1040df3,1041pl.DataFrame(1042schema={"a.name": pl.Utf8, "f.since": pl.Int64, "b.name": pl.Utf8}1043),1044)104510461047def test_sqlalchemy_row_init(tmp_sqlite_db: Path) -> None:1048expected_frame = pl.DataFrame(1049{1050"id": [1, 2],1051"name": ["misc", "other"],1052"value": [100.0, -99.5],1053"date": ["2020-01-01", "2021-12-31"],1054}1055)1056alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")1057query = text("SELECT * FROM test_data ORDER BY name")10581059with alchemy_engine.connect() as conn:1060# note: sqlalchemy `Row` is a NamedTuple-like object; it additionally has1061# a `_mapping` attribute that returns a `RowMapping` dict-like object. 
@patch("polars.io.database._utils.from_arrow")
@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_query_success(
    import_mock: Mock, from_arrow_mock: Mock
) -> None:
    cx_mock = Mock()
    cx_mock.__version__ = "0.4.2"

    import_mock.return_value = cx_mock

    pre_execution_query = "SET statement_timeout = 2151"

    pl.read_database_uri(
        query="SELECT 1",
        uri="mysql://test",
        engine="connectorx",
        pre_execution_query=pre_execution_query,
    )

    assert (
        cx_mock.read_sql.call_args.kwargs["pre_execution_query"] == pre_execution_query
    )


@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_not_supported_exception(
    import_mock: Mock,
) -> None:
    cx_mock = Mock()
    cx_mock.__version__ = "0.4.0"

    import_mock.return_value = cx_mock

    with (
        pytest.raises(
            ValueError,
            match="'pre_execution_query' is only supported in connectorx version 0.4.2 or later",
        ),
    ):
        pl.read_database_uri(
            query="SELECT 1",
            uri="mysql://test",
            engine="connectorx",
            pre_execution_query="SET statement_timeout = 2151",
        )


@patch("polars.io.database._utils.from_arrow")
@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_query_not_supported_success(
    import_mock: Mock, from_arrow_mock: Mock
) -> None:
    cx_mock = Mock()
    cx_mock.__version__ = "0.4.0"

    import_mock.return_value = cx_mock

    pl.read_database_uri(
        query="SELECT 1",
        uri="mysql://test",
        engine="connectorx",
    )

    assert cx_mock.read_sql.call_args.kwargs.get("pre_execution_query") is None