# Path: blob/main/py-polars/tests/unit/io/database/test_write.py
# 6939 views
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.pool import NullPool

import polars as pl
from polars.io.database._utils import _open_adbc_connection
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from pathlib import Path

    from polars._typing import DbWriteEngine


# Single definition of the marker used by every ADBC-backed case below.
_skip_adbc_on_windows = pytest.mark.skipif(
    sys.platform == "win32",
    reason="adbc not available on Windows",
)


def _release_connection(conn: Any) -> None:
    """
    Release whatever resources `conn` holds, if any.

    Objects exposing `close()` are closed; objects exposing only `dispose()`
    (such as the engines returned by `create_engine`) are disposed. Anything
    else (e.g. a plain URI string) is left untouched.
    """
    for method_name in ("close", "dispose"):
        release = getattr(conn, method_name, None)
        if callable(release):
            release()
            return


def _read_table(uri: str, query: str) -> pl.DataFrame:
    """Run `query` against the database at `uri` via a short-lived engine."""
    read_engine = create_engine(uri)
    try:
        return pl.read_database(query=query, connection=read_engine)
    finally:
        # don't leave an engine behind for every verification read
        read_engine.dispose()


@pytest.mark.write_disk
@pytest.mark.parametrize(
    ("engine", "uri_connection"),
    [
        ("sqlalchemy", True),
        ("sqlalchemy", False),
        pytest.param("adbc", True, marks=_skip_adbc_on_windows),
        pytest.param("adbc", False, marks=_skip_adbc_on_windows),
    ],
)
class TestWriteDatabase:
    """Database write tests that share common pytest/parametrize options."""

    @staticmethod
    def _get_connection(
        uri: str, engine: DbWriteEngine, uri_connection: bool
    ) -> Any:
        """
        Build the `connection` argument for the current parametrisation.

        Returns the URI string itself when `uri_connection` is set; otherwise
        a SQLAlchemy engine or an ADBC connection, depending on `engine`.
        """
        if uri_connection:
            return uri
        elif engine == "sqlalchemy":
            return create_engine(uri)
        else:
            return _open_adbc_connection(uri)

    def test_write_database_create(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Test basic database table creation."""
        df = pl.DataFrame(
            {
                "id": [1234, 5678],
                "name": ["misc", "other"],
                "value": [1000.0, -9999.0],
            }
        )
        test_db_uri = f"sqlite:///{tmp_path}/test_create_{int(uri_connection)}.db"

        table_name = "test_create"
        conn = self._get_connection(test_db_uri, engine, uri_connection)
        try:
            assert (
                df.write_database(
                    table_name=table_name,
                    connection=conn,
                    engine=engine,
                )
                == 2
            )
            result = _read_table(test_db_uri, f"SELECT * FROM {table_name}")
            assert_frame_equal(result, df)
        finally:
            # release the connection even when an assertion above fails
            _release_connection(conn)

    def test_write_database_append_replace(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Test append/replace ops against existing database table."""
        df = pl.DataFrame(
            {
                "key": ["xx", "yy", "zz"],
                "value": [123, None, 789],
                "other": [5.5, 7.0, None],
            }
        )
        test_db_uri = f"sqlite:///{tmp_path}/test_append_{int(uri_connection)}.db"

        table_name = "test_append"
        conn = self._get_connection(test_db_uri, engine, uri_connection)
        try:
            assert (
                df.write_database(
                    table_name=table_name,
                    connection=conn,
                    engine=engine,
                )
                == 3
            )
            # the table now exists, so "fail" mode must raise
            with pytest.raises(Exception):  # noqa: B017
                df.write_database(
                    table_name=table_name,
                    connection=conn,
                    if_table_exists="fail",
                    engine=engine,
                )

            assert (
                df.write_database(
                    table_name=table_name,
                    connection=conn,
                    if_table_exists="replace",
                    engine=engine,
                )
                == 3
            )
            result = _read_table(test_db_uri, f"SELECT * FROM {table_name}")
            assert_frame_equal(result, df)

            assert (
                df[:2].write_database(
                    table_name=table_name,
                    connection=conn,
                    if_table_exists="append",
                    engine=engine,
                )
                == 2
            )
            result = _read_table(test_db_uri, f"SELECT * FROM {table_name}")
            assert_frame_equal(result, pl.concat([df, df[:2]]))

            # a caller-supplied ADBC connection must still be open afterwards
            if engine == "adbc" and not uri_connection:
                assert conn._closed is False
        finally:
            _release_connection(conn)

    def test_write_database_create_quoted_tablename(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Test parsing/handling of quoted database table names."""
        df = pl.DataFrame(
            {
                "col x": [100, 200, 300],
                "col y": ["a", "b", "c"],
            }
        )
        test_db_uri = f"sqlite:///{tmp_path}/test_create_quoted.db"

        # table name has some special chars, so requires quoting, and
        # is explicitly qualified with the sqlite 'main' schema
        qualified_table_name = f'main."test-append-{engine}-{int(uri_connection)}"'
        conn = self._get_connection(test_db_uri, engine, uri_connection)
        try:
            assert (
                df.write_database(
                    table_name=qualified_table_name,
                    connection=conn,
                    engine=engine,
                )
                == 3
            )
            assert (
                df.write_database(
                    table_name=qualified_table_name,
                    connection=conn,
                    if_table_exists="replace",
                    engine=engine,
                )
                == 3
            )
            result = _read_table(
                test_db_uri, f"SELECT * FROM {qualified_table_name}"
            )
            assert_frame_equal(result, df)

            # a caller-supplied ADBC connection must still be open afterwards
            if engine == "adbc" and not uri_connection:
                assert conn._closed is False
        finally:
            _release_connection(conn)

    def test_write_database_errors(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Confirm that expected errors are raised."""
        df = pl.DataFrame({"colx": [1, 2, 3]})

        with pytest.raises(
            ValueError, match="`table_name` appears to be invalid: 'w.x.y.z'"
        ):
            df.write_database(
                connection="sqlite:///:memory:",
                table_name="w.x.y.z",
                engine=engine,
            )

        with pytest.raises(
            ValueError,
            match="`if_table_exists` must be one of .* got 'do_something'",
        ):
            df.write_database(
                connection="sqlite:///:memory:",
                table_name="main.test_errs",
                if_table_exists="do_something",  # type: ignore[arg-type]
                engine=engine,
            )

        with pytest.raises(
            TypeError,
            match="unrecognised connection type.*",
        ):
            df.write_database(connection=True, table_name="misc", engine=engine)  # type: ignore[arg-type]

    def test_write_database_adbc_missing_driver_error(
        self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
    ) -> None:
        """Confirm the error raised for a URI whose ADBC driver is not detected."""
        if engine == "sqlalchemy":
            # report as skipped rather than silently passing
            pytest.skip("ADBC-specific check; not applicable to sqlalchemy")

        df = pl.DataFrame({"colx": [1, 2, 3]})
        with pytest.raises(
            ModuleNotFoundError, match="ADBC 'adbc_driver_mysql' driver not detected."
        ):
            df.write_database(
                table_name="my_schema.my_table",
                connection="mysql:///:memory:",
                engine=engine,
            )


@pytest.mark.write_disk
def test_write_database_using_sa_session(tmp_path: Path) -> None:
    """Write through a SQLAlchemy session, commit, then read the rows back."""
    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "zz"],
            "value": [123, None, 789],
            "other": [5.5, 7.0, None],
        }
    )
    table_name = "test_sa_session"
    test_db_uri = f"sqlite:///{tmp_path}/test_sa_session.db"
    engine = create_engine(test_db_uri, poolclass=NullPool)
    with Session(engine) as session:
        df.write_database(table_name, session)
        session.commit()

    with Session(engine) as session:
        result = pl.read_database(
            query=f"select * from {table_name}", connection=session
        )

    assert_frame_equal(result, df)


@pytest.mark.write_disk
@pytest.mark.parametrize("pass_connection", [True, False])
def test_write_database_sa_rollback(tmp_path: Path, pass_connection: bool) -> None:
    """Rows written via a session (or its connection) vanish after rollback."""
    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "zz"],
            "value": [123, None, 789],
            "other": [5.5, 7.0, None],
        }
    )
    table_name = "test_sa_rollback"
    test_db_uri = f"sqlite:///{tmp_path}/test_sa_rollback.db"
    engine = create_engine(test_db_uri, poolclass=NullPool)
    with Session(engine) as session:
        if pass_connection:
            conn = session.connection()
            df.write_database(table_name, conn)
        else:
            df.write_database(table_name, session)
        session.rollback()

    with Session(engine) as session:
        count = pl.read_database(
            query=f"select count(*) from {table_name}", connection=session
        ).item(0, 0)

    assert isinstance(count, int)
    assert count == 0


@pytest.mark.write_disk
@pytest.mark.parametrize("pass_connection", [True, False])
def test_write_database_sa_commit(tmp_path: Path, pass_connection: bool) -> None:
    """Rows written via a session (or its connection) persist after commit."""
    df = pl.DataFrame(
        {
            "key": ["xx", "yy", "zz"],
            "value": [123, None, 789],
            "other": [5.5, 7.0, None],
        }
    )
    table_name = "test_sa_commit"
    test_db_uri = f"sqlite:///{tmp_path}/test_sa_commit.db"
    engine = create_engine(test_db_uri, poolclass=NullPool)
    with Session(engine) as session:
        if pass_connection:
            conn = session.connection()
            df.write_database(table_name, conn)
        else:
            df.write_database(table_name, session)
        session.commit()

    with Session(engine) as session:
        result = pl.read_database(
            query=f"select * from {table_name}", connection=session
        )

    assert_frame_equal(result, df)


@_skip_adbc_on_windows
def test_write_database_adbc_temporary_table() -> None:
    """Confirm that engine_options are passed along to create temporary tables."""
    df = pl.DataFrame({"colx": [1, 2, 3]})
    temp_tbl_name = "should_be_temptable"
    expected_temp_table_create_sql = (
        """CREATE TABLE "should_be_temptable" ("colx" INTEGER)"""
    )

    # test with sqlite in memory
    conn = _open_adbc_connection("sqlite:///:memory:")
    try:
        assert (
            df.write_database(
                temp_tbl_name,
                connection=conn,
                if_table_exists="fail",
                engine_options={"temporary": True},
            )
            == 3
        )
        # the table must be listed in sqlite's temp-table catalogue
        temp_tbl_sql_df = pl.read_database(
            "select sql from sqlite_temp_master where type='table' and tbl_name = ?",
            connection=conn,
            execute_options={"parameters": [temp_tbl_name]},
        )
        assert temp_tbl_sql_df.shape[0] == 1, "no temp table created"
        actual_temp_table_create_sql = temp_tbl_sql_df["sql"][0]
        assert expected_temp_table_create_sql == actual_temp_table_create_sql
    finally:
        # release the connection even when an assertion above fails
        _release_connection(conn)