# Path: py-polars/tests/unit/io/database/test_read.py
"""Tests for `pl.read_database` and `pl.read_database_uri`.

Exercises reads over sqlite via connectorx, ADBC, sqlite3, and SQLAlchemy
connections (including batched/iterator reads, parameterised queries, and
error conditions), plus mocked drivers (snowflake/databricks/turbodbc/etc)
that are not available in CI.
"""

from __future__ import annotations

import os
import sqlite3
import sys
from contextlib import suppress
from datetime import date
from pathlib import Path
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, cast
from unittest.mock import Mock, patch

with suppress(ModuleNotFoundError):  # not available on windows
    import adbc_driver_sqlite.dbapi

import pyarrow as pa
import pytest
import sqlalchemy
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import cast as alchemy_cast

import polars as pl
from polars._utils.various import parse_version
from polars.exceptions import DuplicateError, UnsuitableSQLError
from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import (
        ConnectionOrCursor,
        DbReadEngine,
        SchemaDefinition,
        SchemaDict,
    )


def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any:
    """Open an ADBC sqlite connection, coercing any Path args to str."""
    args = tuple(str(a) if isinstance(a, Path) else a for a in args)
    return adbc_driver_sqlite.dbapi.connect(*args, **kwargs)


class MockConnection:
    """Mock connection class for databases we can't test in CI."""

    def __init__(
        self,
        driver: str,
        batch_size: int | None,
        exact_batch_size: bool,
        test_data: pa.Table,
        repeat_batch_calls: bool,
    ) -> None:
        # spoof the module name so driver detection keys off `driver`
        self.__class__.__module__ = driver
        self._cursor = MockCursor(
            repeat_batch_calls=repeat_batch_calls,
            exact_batch_size=exact_batch_size,
            batched=(batch_size is not None),
            test_data=test_data,
        )

    def close(self) -> None:
        pass

    def cursor(self) -> Any:
        return self._cursor


class MockCursor:
    """Mock cursor class for databases we can't test in CI."""

    def __init__(
        self,
        batched: bool,
        exact_batch_size: bool,
        test_data: pa.Table,
        repeat_batch_calls: bool,
    ) -> None:
        self.resultset = MockResultSet(
            test_data=test_data,
            batched=batched,
            exact_batch_size=exact_batch_size,
            repeat_batch_calls=repeat_batch_calls,
        )
        self.exact_batch_size = exact_batch_size
        # records the names of the fetch* methods that were requested
        self.called: list[str] = []
        self.batched = batched
        self.n_calls = 1

    def __getattr__(self, name: str) -> Any:
        # intercept any fetch* access, log it, and hand back the resultset
        if "fetch" in name:
            self.called.append(name)
            return self.resultset
        super().__getattr__(name)  # type: ignore[misc]

    def close(self) -> Any:
        pass

    def execute(self, query: str) -> Any:
        return self


class MockResultSet:
    """Mock resultset class for databases we can't test in CI."""

    def __init__(
        self,
        test_data: pa.Table,
        batched: bool,
        exact_batch_size: bool,
        repeat_batch_calls: bool = False,
    ) -> None:
        self.test_data = test_data
        self.repeat_batched_calls = repeat_batch_calls
        self.exact_batch_size = exact_batch_size
        self.batched = batched
        self.n_calls = 1

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # drivers without an exact batch size take no positional args
        if not self.exact_batch_size:
            assert len(args) == 0
        if self.repeat_batched_calls:
            # return the full table first, then an empty slice on later calls
            res = self.test_data[: None if self.n_calls else 0]
            self.n_calls -= 1
        else:
            res = iter((self.test_data,))
        return res


class DatabaseReadTestParams(NamedTuple):
    """Clarify read test params."""

    read_method: Literal["read_database", "read_database_uri"]
    connect_using: Any
    expected_dtypes: SchemaDefinition
    expected_dates: list[date | str]
    schema_overrides: SchemaDict | None = None
    batch_size: int | None = None


class ExceptionTestParams(NamedTuple):
    """Clarify exception test params."""

    read_method: str
    query: str | list[str]
    protocol: Any
    errclass: type[Exception]
    errmsg: str
    engine: str | None = None
    execute_options: dict[str, Any] | None = None
    pre_execution_query: str | list[str] | None = None
    kwargs: dict[str, Any] | None = None


@pytest.mark.write_disk
@pytest.mark.parametrize(
    (
        "read_method",
        "connect_using",
        "expected_dtypes",
        "expected_dates",
        "schema_overrides",
        "batch_size",
    ),
    [
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database_uri",
                connect_using="connectorx",
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.UInt8},
            ),
            id="uri: connectorx",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database_uri",
                connect_using="adbc",
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                schema_overrides={"id": pl.UInt8},
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="uri: adbc",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.UInt8,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.UInt8, "value": pl.Float32},
            ),
            id="conn: sqlite3",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.Int32,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.Int32, "value": pl.Float32},
                batch_size=1,
            ),
            id="conn: sqlite3 (batched)",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: create_engine(
                    f"sqlite:///{path}",
                    connect_args={"detect_types": sqlite3.PARSE_DECLTYPES},
                ).connect(),
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
            ),
            id="conn: sqlalchemy",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (fetchall)",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                batch_size=1,
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (batched)",
        ),
    ],
)
def test_read_database(
    read_method: Literal["read_database", "read_database_uri"],
    connect_using: Any,
    expected_dtypes: dict[str, pl.DataType],
    expected_dates: list[date | str],
    schema_overrides: SchemaDict | None,
    batch_size: int | None,
    tmp_sqlite_db: Path,
) -> None:
    if read_method == "read_database_uri":
        connect_using = cast("DbReadEngine", connect_using)
        # instantiate the connection ourselves, using connectorx/adbc
        df = pl.read_database_uri(
            uri=f"sqlite:///{tmp_sqlite_db}",
            query="SELECT * FROM test_data",
            engine=connect_using,
            schema_overrides=schema_overrides,
        )
        df_empty = pl.read_database_uri(
            uri=f"sqlite:///{tmp_sqlite_db}",
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            engine=connect_using,
            schema_overrides=schema_overrides,
        )
    elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
        # externally instantiated adbc connections
        with connect_using(tmp_sqlite_db) as conn:
            df = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data",
                schema_overrides=schema_overrides,
                batch_size=batch_size,
            )
            df_empty = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
                schema_overrides=schema_overrides,
                batch_size=batch_size,
            )
    else:
        # other user-supplied connections
        df = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'",
            schema_overrides=schema_overrides,
            batch_size=batch_size,
        )
        df_empty = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            schema_overrides=schema_overrides,
            batch_size=batch_size,
        )

    # validate the expected query return (data and schema)
    assert df.schema == expected_dtypes
    assert df.shape == (2, 4)
    assert df["date"].to_list() == expected_dates

    # note: 'cursor.description' is not reliable when no query
    # data is returned, so no point comparing expected dtypes
    assert df_empty.columns == ["id", "name", "value", "date"]
    assert df_empty.shape == (0, 4)
    assert df_empty["date"].to_list() == []


@pytest.mark.write_disk
@pytest.mark.parametrize(
    (
        "read_method",
        "connect_using",
        "expected_dtypes",
        "expected_dates",
        "schema_overrides",
        "batch_size",
    ),
    [
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: sqlite3.connect(path, detect_types=True),
                expected_dtypes={
                    "id": pl.Int32,
                    "name": pl.String,
                    "value": pl.Float32,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                schema_overrides={"id": pl.Int32, "value": pl.Float32},
                batch_size=1,
            ),
            id="conn: sqlite3",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=lambda path: create_engine(
                    f"sqlite:///{path}",
                    connect_args={"detect_types": sqlite3.PARSE_DECLTYPES},
                ).connect(),
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.Date,
                },
                expected_dates=[date(2020, 1, 1), date(2021, 12, 31)],
                batch_size=1,
            ),
            id="conn: sqlalchemy",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc",
        ),
        pytest.param(
            *DatabaseReadTestParams(
                read_method="read_database",
                connect_using=adbc_sqlite_connect,
                expected_dtypes={
                    "id": pl.Int64,
                    "name": pl.String,
                    "value": pl.Float64,
                    "date": pl.String,
                },
                expected_dates=["2020-01-01", "2021-12-31"],
                batch_size=1,
            ),
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_sqlite not available on Windows",
            ),
            id="conn: adbc (ignore batch_size)",
        ),
    ],
)
def test_read_database_iter_batches(
    read_method: Literal["read_database"],
    connect_using: Any,
    expected_dtypes: dict[str, pl.DataType],
    expected_dates: list[date | str],
    schema_overrides: SchemaDict | None,
    batch_size: int | None,
    tmp_sqlite_db: Path,
) -> None:
    if "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
        # externally instantiated adbc connections
        with connect_using(tmp_sqlite_db) as conn:
            dfs = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data",
                schema_overrides=schema_overrides,
                iter_batches=True,
                batch_size=batch_size,
            )
            empty_dfs = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
                schema_overrides=schema_overrides,
                iter_batches=True,
                batch_size=batch_size,
            )
            # must consume the iterators while the connection is open
            dfs = iter(list(dfs))
            empty_dfs = iter(list(empty_dfs))
    else:
        # other user-supplied connections
        dfs = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'",
            schema_overrides=schema_overrides,
            iter_batches=True,
            batch_size=batch_size,
        )
        empty_dfs = pl.read_database(
            connection=connect_using(tmp_sqlite_db),
            query="SELECT * FROM test_data WHERE name LIKE '%polars%'",
            schema_overrides=schema_overrides,
            iter_batches=True,
            batch_size=batch_size,
        )

    df: pl.DataFrame = pl.concat(dfs)
    # validate the expected query return (data and schema)
    assert df.schema == expected_dtypes
    assert df.shape == (2, 4)
    assert df["date"].to_list() == expected_dates

    # some drivers return an empty iterator when there is no result
    try:
        df_empty: pl.DataFrame = pl.concat(empty_dfs)
    except ValueError:
        return
    # note: 'cursor.description' is not reliable when no query
    # data is returned, so no point comparing expected dtypes
    assert df_empty.columns == ["id", "name", "value", "date"]
    assert df_empty.shape == (0, 4)
    assert df_empty["date"].to_list() == []


def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None:
    # various flavours of alchemy connection
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

    t = Table("test_data", MetaData(), autoload_with=alchemy_engine)

    # establish sqlalchemy "selectable" and validate usage
    selectable_query = select(
        alchemy_cast(func.strftime("%Y", t.c.date), Integer).label("year"),
        t.c.name,
        t.c.value,
    ).where(t.c.value < 0)

    expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn):
        assert_frame_equal(
            pl.read_database(selectable_query, connection=conn),
            expected,
        )

        batches = list(
            pl.read_database(
                selectable_query,
                connection=conn,
                iter_batches=True,
                batch_size=1,
            )
        )
        assert len(batches) == 1
        assert_frame_equal(batches[0], expected)


def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None:
    # various flavours of alchemy connection
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

    # establish sqlalchemy "textclause" and validate usage
    textclause_query = text(
        """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < 0
        """
    )

    expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn):
        assert_frame_equal(
            pl.read_database(textclause_query, connection=conn),
            expected,
        )

        batches = list(
            pl.read_database(
                textclause_query,
                connection=conn,
                iter_batches=True,
                batch_size=1,
            )
        )
        assert len(batches) == 1
        assert_frame_equal(batches[0], expected)


@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        (":n", {"n": 0}),
        ("?", (0,)),
        ("?", [0]),
    ],
)
def test_read_database_parameterised(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)

    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn):
        if conn is alchemy_session and param == "?":
            continue  # alchemy session.execute() doesn't support positional params
        if parse_version(sqlalchemy.__version__) < (2, 0) and param == ":n":
            continue  # skip for older sqlalchemy versions

        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(n=param),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )


@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        pytest.param(
            ":n",
            pa.Table.from_pydict({"n": [0]}),
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        pytest.param(
            ":n",
            {"n": 0},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262",
            ),
        ),
        ("?", pa.Table.from_pydict({"data": [0]})),
        ("?", pl.DataFrame({"data": [0]})),
        ("?", pl.Series([{"data": 0}])),
        ("?", (0,)),
        ("?", [0]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_parameterised_adbc(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    # ADBC will complain in pytest if the connection isn't closed
    with adbc_driver_sqlite.dbapi.connect(str(tmp_sqlite_db)) as conn:
        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(n=param),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )


@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        ([":lo", ":hi"], {"lo": 90, "hi": 100}),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
def test_read_database_parameterised_multiple(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    param_1, param_2 = params
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
    alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
    raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)
    for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn):
        if alchemy_session is conn and param_1 == "?":
            continue  # alchemy session.execute() doesn't support positional params
        if parse_version(sqlalchemy.__version__) < (2, 0) and isinstance(
            param_value, dict
        ):
            continue  # skip for older sqlalchemy versions

        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(param_1=param_1, param_2=param_2),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )


@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        pytest.param(
            [":lo", ":hi"],
            {"lo": 90, "hi": 100},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        (["?", "?"], pa.Table.from_pydict({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.DataFrame({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.Series([{"data_1": 90, "data_2": 100}])),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_parameterised_multiple_adbc(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    param_1, param_2 = params
    # establish parameterised queries and validate usage
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # ADBC will complain in pytest if the connection isn't closed
    with adbc_driver_sqlite.dbapi.connect(str(tmp_sqlite_db)) as conn:
        assert_frame_equal(
            expected_frame,
            pl.read_database(
                query.format(param_1=param_1, param_2=param_2),
                connection=conn,
                execute_options={"parameters": param_value},
            ),
        )


@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        pytest.param(
            ":n",
            pa.Table.from_pydict({"n": [0]}),
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262"
            ),
        ),
        pytest.param(
            ":n",
            {"n": 0},
            marks=pytest.mark.skip(
                reason="Named binding not currently supported. See https://github.com/apache/arrow-adbc/issues/3262",
            ),
        ),
        ("?", pa.Table.from_pydict({"data": [0]})),
        ("?", pl.DataFrame({"data": [0]})),
        ("?", pl.Series([{"data": 0}])),
        ("?", (0,)),
        ("?", [0]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_uri_parameterised(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    uri = alchemy_engine.url.render_as_string(hide_password=False)
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    # test URI read method (adbc only)
    assert_frame_equal(
        expected_frame,
        pl.read_database_uri(
            query.format(n=param),
            uri=uri,
            engine="adbc",
            execute_options={"parameters": param_value},
        ),
    )

    # no connectorx support for execute_options
    with pytest.raises(
        ValueError,
        match=r"connectorx.*does not support.*execute_options",
    ):
        pl.read_database_uri(
            query.format(n=":n"),
            uri=uri,
            engine="connectorx",
            execute_options={"parameters": (":n", {"n": 0})},
        )


@pytest.mark.parametrize(
    ("params", "param_value"),
    [
        pytest.param(
            [":lo", ":hi"],
            {"lo": 90, "hi": 100},
            marks=pytest.mark.xfail(
                reason="Named binding not supported. See https://github.com/apache/arrow-adbc/issues/3262",
                strict=True,
            ),
        ),
        (["?", "?"], pa.Table.from_pydict({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.DataFrame({"data_1": [90], "data_2": [100]})),
        (["?", "?"], pl.Series([{"data_1": 90, "data_2": 100}])),
        (["?", "?"], (90, 100)),
        (["?", "?"], [90, 100]),
    ],
)
@pytest.mark.skipif(
    sys.platform == "win32", reason="adbc_driver_sqlite not available on Windows"
)
def test_read_database_uri_parameterised_multiple(
    params: list[str], param_value: Any, tmp_sqlite_db: Path
) -> None:
    param_1, param_2 = params
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    uri = alchemy_engine.url.render_as_string(hide_password=False)
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value BETWEEN {param_1} AND {param_2}
    """
    expected_frame = pl.DataFrame({"year": [2020], "name": ["misc"], "value": [100.0]})

    # test URI read method (ADBC only)
    assert_frame_equal(
        expected_frame,
        pl.read_database_uri(
            query.format(param_1=param_1, param_2=param_2),
            uri=uri,
            engine="adbc",
            execute_options={"parameters": param_value},
        ),
    )

    # no connectorx support for execute_options
    with pytest.raises(
        ValueError,
        match=r"connectorx.*does not support.*execute_options",
    ):
        pl.read_database_uri(
            query.format(param_1="?", param_2="?"),
            uri=uri,
            engine="connectorx",
            execute_options={"parameters": (90, 100)},
        )


@pytest.mark.parametrize(
    ("driver", "batch_size", "iter_batches", "expected_call"),
    [
        ("snowflake", None, False, "fetch_arrow_all"),
        ("snowflake", 10_000, False, "fetch_arrow_all"),
        ("snowflake", 10_000, True, "fetch_arrow_batches"),
        ("databricks", None, False, "fetchall_arrow"),
        ("databricks", 25_000, False, "fetchall_arrow"),
        ("databricks", 25_000, True, "fetchmany_arrow"),
        ("turbodbc", None, False, "fetchallarrow"),
        ("turbodbc", 50_000, False, "fetchallarrow"),
        ("turbodbc", 50_000, True, "fetcharrowbatches"),
        pytest.param(
            "adbc_driver_postgresql",
            None,
            False,
            "fetch_arrow",
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_postgresql not available on Windows",
            ),
        ),
        pytest.param(
            "adbc_driver_postgresql",
            75_000,
            False,
            "fetch_arrow",
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_postgresql not available on Windows",
            ),
        ),
        pytest.param(
            "adbc_driver_postgresql",
            75_000,
            True,
            "fetch_record_batch",
            marks=pytest.mark.skipif(
                sys.platform == "win32",
                reason="adbc_driver_postgresql not available on Windows",
            ),
        ),
    ],
)
def test_read_database_mocked(
    driver: str, batch_size: int | None, iter_batches: bool, expected_call: str
) -> None:
    # since we don't have access to snowflake/databricks/etc from CI we
    # mock them so we can check that we're calling the expected methods
    arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow()

    reg = ARROW_DRIVER_REGISTRY.get(driver, [{}])[0]  # type: ignore[var-annotated]
    exact_batch_size = reg.get("exact_batch_size", False)
    repeat_batch_calls = reg.get("repeat_batch_calls", False)

    mc = MockConnection(
        driver,
        batch_size,
        test_data=arrow,
        repeat_batch_calls=repeat_batch_calls,
        exact_batch_size=exact_batch_size,  # type: ignore[arg-type]
    )
    res = pl.read_database(
        query="SELECT * FROM test_data",
        connection=mc,
        iter_batches=iter_batches,
        batch_size=batch_size,
    )
    if iter_batches:
        assert isinstance(res, GeneratorType)
        res = pl.concat(res)

    res = cast("pl.DataFrame", res)
    assert expected_call in mc.cursor().called
    assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")]


@pytest.mark.parametrize(
    (
        "read_method",
        "query",
        "protocol",
        "errclass",
        "errmsg",
        "engine",
        "execute_options",
        "pre_execution_query",
        "kwargs",
    ),
    [
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol="sqlite",
                errclass=ValueError,
                errmsg="engine must be one of {'connectorx', 'adbc'}, got 'not_an_engine'",
                engine="not_an_engine",
            ),
            id="Not an available sql engine",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query=["SELECT * FROM test_data", "SELECT * FROM test_data"],
                protocol="sqlite",
                errclass=ValueError,
                errmsg="only a single SQL query string is accepted for adbc, got a 'list' type",
                engine="adbc",
            ),
            id="Unavailable list of queries for adbc",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol="mysql",
                errclass=ModuleNotFoundError,
                errmsg="ADBC 'adbc_driver_mysql' driver not detected.",
                engine="adbc",
            ),
            id="Unavailable adbc driver",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol=sqlite3.connect(":memory:"),
                errclass=TypeError,
                errmsg="expected connection to be a URI string",
                engine="adbc",
            ),
            id="Invalid connection URI",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="SELECT * FROM imaginary_table",
                protocol=sqlite3.connect(":memory:"),
                errclass=sqlite3.OperationalError,
                errmsg="no such table: imaginary_table",
            ),
            id="Invalid query (unrecognised table name)",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="SELECT * FROM imaginary_table",
                protocol=sys.getsizeof,  # not a connection
                errclass=TypeError,
                errmsg="Unrecognised connection .* no 'execute' or 'cursor' method",
            ),
            id="Invalid read DB kwargs",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="/* tag: misc */ INSERT INTO xyz VALUES ('polars')",
                protocol=sqlite3.connect(":memory:"),
                errclass=UnsuitableSQLError,
                errmsg="INSERT statements are not valid 'read' queries",
            ),
            id="Invalid statement type",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="DELETE FROM xyz WHERE id = 'polars'",
                protocol=sqlite3.connect(":memory:"),
                errclass=UnsuitableSQLError,
                errmsg="DELETE statements are not valid 'read' queries",
            ),
            id="Invalid statement type",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                query="SELECT * FROM sqlite_master",
                protocol=sqlite3.connect(":memory:"),
                errclass=ValueError,
                kwargs={"iter_batches": True},
                errmsg="Cannot set `iter_batches` without also setting a non-zero `batch_size`",
            ),
            id="Invalid batch_size",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                engine="adbc",
                query="SELECT * FROM test_data",
                protocol=sqlite3.connect(":memory:"),
                errclass=TypeError,
                errmsg=r"unexpected keyword argument 'partition_on'",
                kwargs={"partition_on": "id"},
            ),
            id="Invalid kwargs",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database",
                engine="adbc",
                query="SELECT * FROM test_data",
                protocol="{not:a, valid:odbc_string}",
                errclass=ValueError,
                errmsg=r"unable to identify string connection as valid ODBC",
            ),
            id="Invalid ODBC string",
        ),
        pytest.param(
            *ExceptionTestParams(
                read_method="read_database_uri",
                query="SELECT * FROM test_data",
                protocol="sqlite",
                errclass=ValueError,
                errmsg="the 'adbc' engine does not support use of `pre_execution_query`",
                engine="adbc",
                pre_execution_query="SET statement_timeout = 2151",
            ),
            id="Unavailable `pre_execution_query` for adbc",
        ),
    ],
)
def test_read_database_exceptions(
    read_method: str,
    query: str,
    protocol: Any,
    errclass: type[Exception],
    errmsg: str,
    engine: DbReadEngine | None,
    execute_options: dict[str, Any] | None,
    pre_execution_query: str | list[str] | None,
    kwargs: dict[str, Any] | None,
) -> None:
    if read_method == "read_database_uri":
        conn = f"{protocol}://test" if isinstance(protocol, str) else protocol
        params = {
            "uri": conn,
            "query": query,
            "engine": engine,
            "pre_execution_query": pre_execution_query,
        }
    else:
        params = {"connection": protocol, "query": query}
        if execute_options:
            params["execute_options"] = execute_options
        if kwargs is not None:
            params.update(kwargs)

    read_database = getattr(pl, read_method)
    with pytest.raises(errclass, match=errmsg):
        read_database(**params)


@pytest.mark.parametrize(
    "query",
    [
        "SELECT 1, 1 FROM test_data",
        'SELECT 1 AS "n", 2 AS "n" FROM test_data',
        'SELECT name, value AS "name" FROM test_data',
    ],
)
def test_read_database_duplicate_column_error(tmp_sqlite_db: Path, query: str) -> None:
    alchemy_conn = create_engine(f"sqlite:///{tmp_sqlite_db}").connect()
    with pytest.raises(
        DuplicateError,
        match=r"column .+ appears more than once in the query/result cursor",
    ):
        pl.read_database(query, connection=alchemy_conn)


@pytest.mark.parametrize(
    "uri",
    [
        "fakedb://123:456@account/database/schema?warehouse=warehouse&role=role",
        "fakedb://my#%us3r:p433w0rd@not_a_real_host:9999/database",
    ],
)
def test_read_database_cx_credentials(uri: str) -> None:
    with pytest.raises(RuntimeError, match=r"Source.*not supported"):
        pl.read_database_uri("SELECT * FROM data", uri=uri, engine="connectorx")


def test_sqlalchemy_row_init(tmp_sqlite_db: Path) -> None:
    expected_frame = pl.DataFrame(
        {
            "id": [1, 2],
            "name": ["misc", "other"],
            "value": [100.0, -99.5],
            "date": ["2020-01-01", "2021-12-31"],
        }
    )
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    query = text("SELECT * FROM test_data ORDER BY name")

    with alchemy_engine.connect() as conn:
        # note: sqlalchemy `Row` is a NamedTuple-like object; it additionally has
        # a `_mapping` attribute that returns a `RowMapping` dict-like object. we
        # validate frame/series init from each flavour of query result.
        query_result = list(conn.execute(query))
        for df in (
            pl.DataFrame(query_result),
            pl.DataFrame([row._mapping for row in query_result]),
            pl.from_records([row._mapping for row in query_result]),
        ):
            assert_frame_equal(expected_frame, df)

        expected_series = expected_frame.to_struct()
        for s in (
            pl.Series(query_result),
            pl.Series([row._mapping for row in query_result]),
        ):
            assert_series_equal(expected_series, s)


@patch("polars.io.database._utils.from_arrow")
@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_query_success(
    import_mock: Mock, from_arrow_mock: Mock
) -> None:
    cx_mock = Mock()
    cx_mock.__version__ = "0.4.2"

    import_mock.return_value = cx_mock

    pre_execution_query = "SET statement_timeout = 2151"

    pl.read_database_uri(
        query="SELECT 1",
        uri="mysql://test",
        engine="connectorx",
        pre_execution_query=pre_execution_query,
    )

    assert (
        cx_mock.read_sql.call_args.kwargs["pre_execution_query"] == pre_execution_query
    )


@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_not_supported_exception(
    import_mock: Mock,
) -> None:
    cx_mock = Mock()
    cx_mock.__version__ = "0.4.0"

    import_mock.return_value = cx_mock

    with (
        pytest.raises(
            ValueError,
            match=r"'pre_execution_query' is only supported in connectorx version 0\.4\.2 or later",
        ),
    ):
        pl.read_database_uri(
            query="SELECT 1",
            uri="mysql://test",
            engine="connectorx",
            pre_execution_query="SET statement_timeout = 2151",
        )


@patch("polars.io.database._utils.from_arrow")
@patch("polars.io.database._utils.import_optional")
def test_read_database_uri_pre_execution_query_not_supported_success(
    import_mock: Mock, from_arrow_mock: Mock
) -> None:
    cx_mock = Mock()
    cx_mock.__version__ = "0.4.0"

    import_mock.return_value = cx_mock

    pl.read_database_uri(
        query="SELECT 1",
        uri="mysql://test",
        engine="connectorx",
    )

    assert cx_mock.read_sql.call_args.kwargs.get("pre_execution_query") is None