Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/database/test_inference.py
6939 views
1
from __future__ import annotations
2
3
import sqlite3
4
from typing import TYPE_CHECKING
5
6
import pytest
7
8
import polars as pl
9
from polars.exceptions import ComputeError
10
from polars.io.database._inference import dtype_from_database_typename
11
12
if TYPE_CHECKING:
13
from pathlib import Path
14
15
from polars._typing import PolarsDataType
16
17
18
@pytest.mark.parametrize(
    "value,expected_dtype",
    [
        # --- string-like type names ---
        ("UTF16", pl.String),
        ("char(8)", pl.String),
        ("BPCHAR", pl.String),
        ("nchar[128]", pl.String),
        ("varchar", pl.String),
        ("CHARACTER VARYING(64)", pl.String),
        ("nvarchar(32)", pl.String),
        ("TEXT", pl.String),
        # --- array / list type names (including nested arrays) ---
        ("float32[]", pl.List(pl.Float32)),
        ("double array", pl.List(pl.Float64)),
        ("array[bool]", pl.List(pl.Boolean)),
        ("array of nchar(8)", pl.List(pl.String)),
        ("array[array[int8]]", pl.List(pl.List(pl.Int64))),
        # --- numeric type names: decimal, float, bool, int, serial ---
        ("numeric[10,5]", pl.Decimal(10, 5)),
        ("bigdecimal", pl.Decimal),
        ("decimal128(10,5)", pl.Decimal(10, 5)),
        ("double precision", pl.Float64),
        ("floating point", pl.Float64),
        ("numeric", pl.Float64),
        ("real", pl.Float64),
        ("boolean", pl.Boolean),
        ("tinyint", pl.Int8),
        ("smallint", pl.Int16),
        ("int", pl.Int64),
        ("int4", pl.Int32),
        ("int2", pl.Int16),
        ("int(16)", pl.Int16),
        ("uint32", pl.UInt32),
        ("int128", pl.Int128),
        ("HUGEINT", pl.Int128),
        ("ROWID", pl.UInt64),
        ("mediumint", pl.Int32),
        ("unsigned mediumint", pl.UInt32),
        ("cardinal_number", pl.UInt64),
        ("smallserial", pl.Int16),
        ("serial", pl.Int32),
        ("bigserial", pl.Int64),
        # --- temporal type names (precision drives the time unit) ---
        ("timestamp(3)", pl.Datetime("ms")),
        ("timestamp(5)", pl.Datetime("us")),
        ("timestamp(7)", pl.Datetime("ns")),
        ("datetime without tz", pl.Datetime("us")),
        ("duration(2)", pl.Duration("ms")),
        ("interval", pl.Duration("us")),
        ("date", pl.Date),
        ("time", pl.Time),
        ("date32", pl.Date),
        ("time64", pl.Time),
        # --- binary type names ---
        ("BYTEA", pl.Binary),
        ("BLOB", pl.Binary),
        # --- everything else ---
        ("NULL", pl.Null),
    ],
)
def test_dtype_inference_from_string(
    value: str,
    expected_dtype: PolarsDataType,
) -> None:
    """Each database type name must resolve to the listed polars dtype."""
    assert dtype_from_database_typename(value) == expected_dtype  # type: ignore[operator]
85
86
87
@pytest.mark.parametrize(
    "value",
    [
        "FooType",
        "Unknown",
        "MISSING",
        "XML",
        # "number" is deliberately left unmapped because it is ambiguous:
        # it could refer to any size of int, float, or decimal dtype.
        "Number",
    ],
)
def test_dtype_inference_from_invalid_string(value: str) -> None:
    """Unrecognised type names raise by default, or return None on request."""
    # Default mode: an unmatched type name is reported as a ValueError.
    with pytest.raises(ValueError, match="cannot infer dtype"):
        dtype_from_database_typename(value)

    # Lenient mode: with raise_unmatched=False the lookup yields None instead.
    assert (
        dtype_from_database_typename(
            value=value,
            raise_unmatched=False,
        )
        is None
    )
106
107
108
def test_infer_schema_length(tmp_sqlite_inference_db: Path) -> None:
    """Check how `infer_schema_length` affects reading a SQLite table.

    With 2, 100 or None rows available for inference, the frame's schema is
    String/Float64. With only 1 row, reading fails with a ComputeError whose
    message mentions the string value "foo" and `infer_schema_length`
    (presumably because only the all-NULL first row is sampled).
    """
    # note: first row of this test database contains only NULL values
    conn = sqlite3.connect(tmp_sqlite_inference_db)
    try:
        for infer_len in (2, 100, None):
            df = pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data",
                infer_schema_length=infer_len,
            )
            assert df.schema == {"name": pl.String, "value": pl.Float64}

        with pytest.raises(
            ComputeError,
            match='could not append value: "foo" of type: str.*`infer_schema_length`',
        ):
            pl.read_database(
                connection=conn,
                query="SELECT * FROM test_data",
                infer_schema_length=1,
            )
    finally:
        # sqlite3.Connection's own context manager only commits/rolls back and
        # does not close the connection, so close it explicitly here to avoid
        # leaking the database handle (including when an assertion fails).
        conn.close()
128
129