Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/asserts.py
7884 views
1
from __future__ import annotations
2
3
import contextlib
4
import sqlite3
5
from typing import TYPE_CHECKING, Any, Literal
6
7
import pytest
8
9
import polars as pl
10
from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES
11
from polars.testing import assert_frame_equal
12
13
if TYPE_CHECKING:
14
from collections.abc import Collection, Sequence
15
16
from polars.type_aliases import PolarsDataType
17
18
# SQLite has limited type support (primitive scalar types only), so every
# supported Polars dtype is mapped onto one of a few SQLite column types.
_POLARS_TO_SQLITE_: dict[PolarsDataType, str] = {
    **{int_dtype: "INTEGER" for int_dtype in INTEGER_DTYPES},
    **{float_dtype: "FLOAT" for float_dtype in FLOAT_DTYPES},
    pl.Boolean: "INTEGER",
    pl.String: "TEXT",
}
25
26
27
def _quote_sqlite_identifier(identifier: str) -> str:
    """Return `identifier` double-quoted for SQLite, doubling any embedded quotes."""
    return '"' + identifier.replace('"', '""') + '"'


def _execute_with_sqlite(
    frames: dict[str, pl.DataFrame | pl.LazyFrame],
    query: str,
) -> pl.DataFrame:
    """
    Execute a SQL query against SQLite, returning a DataFrame.

    Each frame is materialised as a table in a fresh in-memory SQLite database
    (LazyFrames are collected first), the query is run against those tables,
    and the fetched rows are returned as a row-oriented DataFrame whose column
    names are taken from the cursor description.

    Parameters
    ----------
    frames
        Mapping of table name to the frame that provides that table's data.
    query
        SQL query to execute once all of the tables have been loaded.

    Raises
    ------
    KeyError
        If a frame contains a column whose dtype has no entry in the
        `_POLARS_TO_SQLITE_` mapping.
    """
    with contextlib.closing(sqlite3.connect(":memory:")) as conn:
        cursor = conn.cursor()
        for name, df in frames.items():
            if isinstance(df, pl.LazyFrame):
                df = df.collect()

            frame_schema = df.schema
            # Quote the table/column identifiers so that names which are SQL
            # keywords, or that contain spaces/punctuation, still produce valid
            # DDL; unquoted references in the query continue to resolve.
            table = _quote_sqlite_identifier(name)
            column_defs = ", ".join(
                f"{_quote_sqlite_identifier(col)} {_POLARS_TO_SQLITE_[frame_schema[col]]}"
                for col in df.columns
            )
            placeholders = ",".join(["?"] * len(df.columns))

            cursor.execute(f"CREATE TABLE {table} ({column_defs})")
            cursor.executemany(
                f"INSERT INTO {table} VALUES ({placeholders})",
                df.iter_rows(),
            )

        conn.commit()
        cursor.execute(query)

        return pl.DataFrame(
            cursor.fetchall(),
            schema=[desc[0] for desc in cursor.description],
            orient="row",
        )
55
56
57
def _execute_with_duckdb(
    frames: dict[str, pl.DataFrame | pl.LazyFrame],
    query: str,
) -> pl.DataFrame:
    """
    Run `query` on an in-memory DuckDB connection and return the result frame.

    Every entry in `frames` is registered on the connection under its dict key
    before the query is executed. If `duckdb` cannot be imported, `pytest.skip`
    is called instead of running the query.
    """
    try:
        import duckdb
    except ImportError:
        # if not available locally, skip (will always be run on CI)
        pytest.skip(
            """DuckDB not installed; required for `assert_sql_matches` with "compare_with='duckdb'"."""
        )

    with duckdb.connect(":memory:") as connection:
        for table_name, frame in frames.items():
            connection.register(table_name, frame)

        relation = connection.execute(query)
        return relation.pl()  # type: ignore[no-any-return]
73
74
75
# Registry of reference engines: backend name -> function that executes a
# query against that engine and returns the result as a DataFrame.
_COMPARISON_BACKENDS_ = dict(
    sqlite=_execute_with_sqlite,
    duckdb=_execute_with_duckdb,
)
79
80
81
def assert_sql_matches(
    frames: pl.DataFrame | pl.LazyFrame | dict[str, pl.DataFrame | pl.LazyFrame],
    *,
    query: str,
    compare_with: Literal["sqlite", "duckdb"] | Collection[Literal["sqlite", "duckdb"]],
    check_dtypes: bool = False,
    check_row_order: bool = True,
    check_column_names: bool = True,
    expected: pl.DataFrame | dict[str, Sequence[Any]] | None = None,
) -> bool:
    """
    Assert that a Polars SQL query produces the same result as a reference backend.

    This function executes the provided SQL query using both Polars and a reference
    SQL engine (eg: SQLite or DuckDB), then asserts that the results match.

    Parameters
    ----------
    frames
        Mapping of table names to DataFrame or LazyFrame; the query should reference
        the table names as they appear in the dict keys. If passed a single frame,
        "self" is assumed to be the name of the referenced table/frame.
    query
        SQL query string to test, referencing table names from `frames`.
    compare_with
        One or more named SQL engines to use as a reference for comparison.
        - 'sqlite': Use Python's built-in `sqlite3` module.
        - 'duckdb': Use DuckDB (requires `duckdb` to be installed separately).
    check_dtypes
        Require that the comparison frame dtypes match; defaults to False, as different
        backends may use different type systems, and we care about the values.
    check_row_order
        Set False to ignore the row order in the Polars/comparison frame match.
    check_column_names
        Set False to ignore the column names in the Polars/comparison frame match
        (but still compare each column in the same expected order).
    expected
        An optional DataFrame (or dictionary) containing the expected result;
        with this we can confirm both that the result matches the reference
        implementation *and* that those results match expectation. A dictionary
        is converted to a DataFrame using the schema of the Polars result.

    Returns
    -------
    bool
        Always True; a mismatch is reported by raising rather than returning False.

    Raises
    ------
    ValueError
        If `compare_with` names an engine that is not a registered backend.
    AssertionError
        If the Polars result differs from a reference result or from `expected`.

    Examples
    --------
    >>> import polars as pl
    >>> from tests.unit.sql import assert_sql_matches

    Confirm that a given SQL query against a single frame returns the same
    result values when executed with Polars and executed with SQLite:

    >>> lf = pl.LazyFrame({"lbl": ["xx", "yy", "zz"], "value": [-150, 325, 275]})
    >>> query = '''
    ...     SELECT lbl, value * 2 AS doubled
    ...     FROM demo WHERE value > 0 ORDER BY lbl
    ... '''
    >>> assert_sql_matches({"demo": lf}, query=query, compare_with="sqlite")

    Check that a multi-frame JOIN produces the same result as DuckDB:

    >>> users = pl.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]})
    >>> orders = pl.DataFrame({"user_id": [1, 1, 2], "amount": [100, 200, 150]})
    >>> assert_sql_matches(
    ...     frames={"users": users, "orders": orders},
    ...     query='''
    ...         SELECT u.name, SUM(o.amount) as total
    ...         FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.name
    ...     ''',
    ...     compare_with="duckdb",
    ...     check_row_order=False,
    ... )
    """
    if isinstance(frames, (pl.DataFrame, pl.LazyFrame)):
        frames = {"self": frames}

    with pl.SQLContext(frames=frames, eager=True) as ctx:
        polars_result = ctx.execute(query=query, eager=True)

    # a single engine name is treated as a one-element collection
    if isinstance(compare_with, str):
        compare_with = [compare_with]

    for comparison_backend in compare_with:
        if (exec_comparison := _COMPARISON_BACKENDS_.get(comparison_backend)) is None:
            valid_engines = ", ".join(repr(b) for b in sorted(_COMPARISON_BACKENDS_))
            msg = (
                f"invalid `compare_with` value: {comparison_backend!r}; "
                f"expected one of {valid_engines}"
            )
            raise ValueError(msg)

        comparison_result = exec_comparison(frames, query)
        if not check_column_names:
            # positionally adopt the Polars column names so that only the
            # values (in column order) take part in the comparison
            n_comparison_cols = comparison_result.width
            comparison_result.columns = polars_result.columns[:n_comparison_cols]

        # validate against the reference engine/backend
        assert_frame_equal(
            polars_result,
            comparison_result,
            check_dtypes=check_dtypes,
            check_row_order=check_row_order,
        )

    # confirm that these values are not just consistent but also match a
    # specific/expected result; this is checked once, outside the per-backend
    # loop, so that it still applies when `compare_with` is empty
    if expected is not None:
        if isinstance(expected, dict):
            expected = pl.from_dict(
                data=expected,
                schema=polars_result.schema,
            )

        assert_frame_equal(
            polars_result,
            expected,
            check_dtypes=check_dtypes,
            check_row_order=check_row_order,
        )

    return True
196
197
198
# Public API of this module: only the assertion helper is exported.
__all__ = ["assert_sql_matches"]
199
200