from __future__ import annotations
import contextlib
import sqlite3
from typing import TYPE_CHECKING, Any, Literal
import pytest
import polars as pl
from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES
from polars.testing import assert_frame_equal
if TYPE_CHECKING:
from collections.abc import Collection, Sequence
from polars.type_aliases import PolarsDataType
_POLARS_TO_SQLITE_: dict[PolarsDataType, str] = {
**dict.fromkeys(INTEGER_DTYPES, "INTEGER"),
**dict.fromkeys(FLOAT_DTYPES, "FLOAT"),
pl.Boolean: "INTEGER",
pl.String: "TEXT",
}
def _execute_with_sqlite(
frames: dict[str, pl.DataFrame | pl.LazyFrame],
query: str,
) -> pl.DataFrame:
"""Execute a SQL query against SQLite, returning a DataFrame."""
with contextlib.closing(sqlite3.connect(":memory:")) as conn:
cursor = conn.cursor()
for name, df in frames.items():
if isinstance(df, pl.LazyFrame):
df = df.collect()
frame_schema = df.schema
types = (_POLARS_TO_SQLITE_[frame_schema[col]] for col in df.columns)
schema = ", ".join(f"{col} {tp}" for col, tp in zip(df.columns, types))
cursor.execute(f"CREATE TABLE {name} ({schema})")
cursor.executemany(
f"INSERT INTO {name} VALUES ({','.join(['?'] * len(df.columns))})",
df.iter_rows(),
)
conn.commit()
cursor.execute(query)
return pl.DataFrame(
cursor.fetchall(),
schema=[desc[0] for desc in cursor.description],
orient="row",
)
def _execute_with_duckdb(
frames: dict[str, pl.DataFrame | pl.LazyFrame],
query: str,
) -> pl.DataFrame:
"""Execute a SQL query against DuckDB, returning a DataFrame."""
try:
import duckdb
except ImportError:
pytest.skip(
"""DuckDB not installed; required for `assert_sql_matches` with "compare_with='duckdb'"."""
)
with duckdb.connect(":memory:") as conn:
for name, df in frames.items():
conn.register(name, df)
return conn.execute(query).pl()
_COMPARISON_BACKENDS_ = {
"sqlite": _execute_with_sqlite,
"duckdb": _execute_with_duckdb,
}
def assert_sql_matches(
frames: pl.DataFrame | pl.LazyFrame | dict[str, pl.DataFrame | pl.LazyFrame],
*,
query: str,
compare_with: Literal["sqlite", "duckdb"] | Collection[Literal["sqlite", "duckdb"]],
check_dtypes: bool = False,
check_row_order: bool = True,
check_column_names: bool = True,
expected: pl.DataFrame | dict[str, Sequence[Any]] | None = None,
) -> bool:
"""
Assert that a Polars SQL query produces the same result as a reference backend.
This function executes the provided SQL query using both Polars and a reference
SQL engine (eg: SQLite or DuckDB), then asserts that the results match.
Parameters
----------
frames
Mapping of table names to DataFrame or LazyFrame; the query should reference
the table names as they appear in the dict keys. If passed a single frame,
"self" is assumed to be the name of the referenced table/frame.
query
SQL query string to test, referencing table names from `frames`.
compare_with
One or more named SQL engines to use as a reference for comparison.
- 'sqlite': Use Python's built-in `sqlite3` module.
- 'duckdb': Use DuckDB (requires `duckdb` to be installed separately).
check_dtypes
Require that the comparison frame dtypes match; defaults to False, as different
backends may use different type systems, and we care about the values.
check_row_order
Set False to ignore the row order in the Polars/comparison frame match.
check_column_names
Set False to ignore the column names in the Polars/comparison frame match
(but still compare each column in the same expected order).
expected
An optional DataFrame (or dictionary) containing the expected result;
with this we can confirm both that the result matches the reference
implementation *and* that those results match expectation.
Examples
--------
>>> import polars as pl
>>> from tests.unit.sql import assert_sql_matches
Confirm that a given SQL query against a single frame returns the same
result values when executed with Polars and executed with SQLite:
>>> lf = pl.LazyFrame({"lbl": ["xx", "yy", "zz"], "value": [-150, 325, 275]})
>>> query = "SELECT lbl, value * 2 AS doubled FROM demo WHERE id > 1 ORDER BY lbl"
>>> assert_sql_matches({"demo": lf}, query=query, compare_with="sqlite")
Check that a multi-frame JOIN produces the same result as DuckDB:
>>> users = pl.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]})
>>> orders = pl.DataFrame({"user_id": [1, 1, 2], "amount": [100, 200, 150]})
>>> assert_sql_matches(
... frames={"users": users, "orders": orders},
... query='''
... SELECT u.name, SUM(o.amount) as total
... FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.name
... ''',
... compare_with="duckdb",
... check_row_order=False,
... )
"""
if isinstance(frames, (pl.DataFrame, pl.LazyFrame)):
frames = {"self": frames}
with pl.SQLContext(frames=frames, eager=True) as ctx:
polars_result = ctx.execute(query=query, eager=True)
if isinstance(compare_with, str):
compare_with = [compare_with]
for comparison_backend in compare_with:
if (exec_comparison := _COMPARISON_BACKENDS_.get(comparison_backend)) is None:
valid_engines = ", ".join(repr(b) for b in sorted(_COMPARISON_BACKENDS_))
msg = (
f"invalid `compare_with` value: {comparison_backend!r}; "
f"expected one of {valid_engines}"
)
raise ValueError(msg)
comparison_result = exec_comparison(frames, query)
if not check_column_names:
n_comparison_cols = comparison_result.width
comparison_result.columns = polars_result.columns[:n_comparison_cols]
assert_frame_equal(
polars_result,
comparison_result,
check_dtypes=check_dtypes,
check_row_order=check_row_order,
)
if expected is not None:
if isinstance(expected, dict):
expected = pl.from_dict(
data=expected,
schema=polars_result.schema,
)
assert_frame_equal(
polars_result,
expected,
check_dtypes=check_dtypes,
check_row_order=check_row_order,
)
return True
__all__ = ["assert_sql_matches"]