# Path: blob/main/py-polars/tests/unit/io/database/test_inference.py
# 6939 views
from __future__ import annotations12import sqlite33from typing import TYPE_CHECKING45import pytest67import polars as pl8from polars.exceptions import ComputeError9from polars.io.database._inference import dtype_from_database_typename1011if TYPE_CHECKING:12from pathlib import Path1314from polars._typing import PolarsDataType151617@pytest.mark.parametrize(18("value", "expected_dtype"),19[20# string types21("UTF16", pl.String),22("char(8)", pl.String),23("BPCHAR", pl.String),24("nchar[128]", pl.String),25("varchar", pl.String),26("CHARACTER VARYING(64)", pl.String),27("nvarchar(32)", pl.String),28("TEXT", pl.String),29# array types30("float32[]", pl.List(pl.Float32)),31("double array", pl.List(pl.Float64)),32("array[bool]", pl.List(pl.Boolean)),33("array of nchar(8)", pl.List(pl.String)),34("array[array[int8]]", pl.List(pl.List(pl.Int64))),35# numeric types36("numeric[10,5]", pl.Decimal(10, 5)),37("bigdecimal", pl.Decimal),38("decimal128(10,5)", pl.Decimal(10, 5)),39("double precision", pl.Float64),40("floating point", pl.Float64),41("numeric", pl.Float64),42("real", pl.Float64),43("boolean", pl.Boolean),44("tinyint", pl.Int8),45("smallint", pl.Int16),46("int", pl.Int64),47("int4", pl.Int32),48("int2", pl.Int16),49("int(16)", pl.Int16),50("uint32", pl.UInt32),51("int128", pl.Int128),52("HUGEINT", pl.Int128),53("ROWID", pl.UInt64),54("mediumint", pl.Int32),55("unsigned mediumint", pl.UInt32),56("cardinal_number", pl.UInt64),57("smallserial", pl.Int16),58("serial", pl.Int32),59("bigserial", pl.Int64),60# temporal types61("timestamp(3)", pl.Datetime("ms")),62("timestamp(5)", pl.Datetime("us")),63("timestamp(7)", pl.Datetime("ns")),64("datetime without tz", pl.Datetime("us")),65("duration(2)", pl.Duration("ms")),66("interval", pl.Duration("us")),67("date", pl.Date),68("time", pl.Time),69("date32", pl.Date),70("time64", pl.Time),71# binary types72("BYTEA", pl.Binary),73("BLOB", pl.Binary),74# miscellaneous75("NULL", pl.Null),76],77)78def 
test_dtype_inference_from_string(79value: str,80expected_dtype: PolarsDataType,81) -> None:82inferred_dtype = dtype_from_database_typename(value)83assert inferred_dtype == expected_dtype # type: ignore[operator]848586@pytest.mark.parametrize(87"value",88[89"FooType",90"Unknown",91"MISSING",92"XML", # note: we deliberately exclude "number" as it is ambiguous.93"Number", # (could refer to any size of int, float, or decimal dtype)94],95)96def test_dtype_inference_from_invalid_string(value: str) -> None:97with pytest.raises(ValueError, match="cannot infer dtype"):98dtype_from_database_typename(value)99100inferred_dtype = dtype_from_database_typename(101value=value,102raise_unmatched=False,103)104assert inferred_dtype is None105106107def test_infer_schema_length(tmp_sqlite_inference_db: Path) -> None:108# note: first row of this test database contains only NULL values109conn = sqlite3.connect(tmp_sqlite_inference_db)110for infer_len in (2, 100, None):111df = pl.read_database(112connection=conn,113query="SELECT * FROM test_data",114infer_schema_length=infer_len,115)116assert df.schema == {"name": pl.String, "value": pl.Float64}117118with pytest.raises(119ComputeError,120match='could not append value: "foo" of type: str.*`infer_schema_length`',121):122pl.read_database(123connection=conn,124query="SELECT * FROM test_data",125infer_schema_length=1,126)127128129