Path: blob/main/py-polars/tests/unit/io/test_scan_options.py
8424 views
from __future__ import annotations

import io
from datetime import datetime
from typing import IO, TYPE_CHECKING, Any
from zoneinfo import ZoneInfo

import pytest

import polars as pl
from polars.datatypes.group import FLOAT_DTYPES
from polars.exceptions import SchemaError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from collections.abc import Callable


# Shared fixtures for the timezone-aware datetime cases below.
_AMSTERDAM = ZoneInfo("Europe/Amsterdam")
_SYDNEY = ZoneInfo("Australia/Sydney")

_MS_AMSTERDAM = pl.Datetime(time_unit="ms", time_zone="Europe/Amsterdam")
_NS_SYDNEY = pl.Datetime(time_unit="ns", time_zone="Australia/Sydney")

_JAN1_AMSTERDAM = datetime(2025, 1, 1, tzinfo=_AMSTERDAM)
_JAN2_SYDNEY = datetime(2025, 1, 2, tzinfo=_SYDNEY)
# The same instant as `_JAN2_SYDNEY`, expressed in the Amsterdam time zone.
_JAN2_AS_AMSTERDAM = datetime(2025, 1, 1, 14, tzinfo=_AMSTERDAM)


def _tz_downcast_options() -> pl.ScanCastOptions:
    """Return fresh options allowing ns->coarser and time zone conversion."""
    return pl.ScanCastOptions(
        datetime_cast=["nanosecond-downcast", "convert-timezone"]
    )


def _rewind(files: list[IO[bytes]]) -> None:
    """Seek every buffer in `files` back to offset 0."""
    for f in files:
        f.seek(0)


def _parquet_buffers(*frames: pl.DataFrame) -> list[IO[bytes]]:
    """
    Write each frame to its own in-memory Parquet buffer.

    The buffers are returned in the order of `frames`, each one rewound to
    offset 0 so it can be handed straight to `pl.scan_parquet`.
    """
    files: list[IO[bytes]] = []
    for frame in frames:
        buf = io.BytesIO()
        frame.write_parquet(buf)
        files.append(buf)
    _rewind(files)
    return files


@pytest.mark.parametrize(
    ("literal_values", "expected", "cast_options"),
    [
        (
            (pl.lit(1, dtype=pl.Int64), pl.lit(2, dtype=pl.Int32)),
            pl.Series([1, 2], dtype=pl.Int64),
            pl.ScanCastOptions(integer_cast="upcast"),
        ),
        (
            (pl.lit(1.0, dtype=pl.Float64), pl.lit(2.0, dtype=pl.Float32)),
            pl.Series([1, 2], dtype=pl.Float64),
            pl.ScanCastOptions(float_cast="upcast"),
        ),
        (
            (pl.lit(1.0, dtype=pl.Float32), pl.lit(2.0, dtype=pl.Float64)),
            pl.Series([1, 2], dtype=pl.Float32),
            pl.ScanCastOptions(float_cast=["upcast", "downcast"]),
        ),
        (
            (
                pl.lit(datetime(2025, 1, 1), dtype=pl.Datetime(time_unit="ms")),
                pl.lit(datetime(2025, 1, 2), dtype=pl.Datetime(time_unit="ns")),
            ),
            pl.Series(
                [datetime(2025, 1, 1), datetime(2025, 1, 2)],
                dtype=pl.Datetime(time_unit="ms"),
            ),
            pl.ScanCastOptions(datetime_cast="nanosecond-downcast"),
        ),
        (
            (
                pl.lit(_JAN1_AMSTERDAM, dtype=_MS_AMSTERDAM),
                pl.lit(_JAN2_SYDNEY, dtype=_NS_SYDNEY),
            ),
            pl.Series(
                [_JAN1_AMSTERDAM, _JAN2_AS_AMSTERDAM],
                dtype=_MS_AMSTERDAM,
            ),
            _tz_downcast_options(),
        ),
        (
            (  # We also test nested primitive upcast policy with this one
                pl.lit(
                    {"a": [[1]], "b": 1},
                    dtype=pl.Struct(
                        {"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}
                    ),
                ),
                pl.lit(
                    {"a": [[2]]},
                    dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int8, 1))}),
                ),
            ),
            pl.Series(
                [{"a": [[1]], "b": 1}, {"a": [[2]], "b": None}],
                dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}),
            ),
            pl.ScanCastOptions(
                integer_cast="upcast",
                missing_struct_fields="insert",
            ),
        ),
        (
            (  # Test same set of struct fields but in different order
                pl.lit(
                    {"a": [[1]], "b": 1},
                    dtype=pl.Struct(
                        {"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}
                    ),
                ),
                pl.lit(
                    {"b": None, "a": [[2]]},
                    dtype=pl.Struct(
                        {"b": pl.Int32, "a": pl.List(pl.Array(pl.Int32, 1))}
                    ),
                ),
            ),
            pl.Series(
                [{"a": [[1]], "b": 1}, {"a": [[2]], "b": None}],
                dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}),
            ),
            None,
        ),
        # Test logical (datetime) type under list
        (
            (
                pl.lit(
                    [{"field": _JAN1_AMSTERDAM}],
                    dtype=pl.List(pl.Struct({"field": _MS_AMSTERDAM})),
                ),
                pl.lit(
                    [{"field": _JAN2_SYDNEY}],
                    dtype=pl.List(pl.Struct({"field": _NS_SYDNEY})),
                ),
            ),
            pl.Series(
                [
                    [{"field": _JAN1_AMSTERDAM}],
                    [{"field": _JAN2_AS_AMSTERDAM}],
                ],
                dtype=pl.List(pl.Struct({"field": _MS_AMSTERDAM})),
            ),
            _tz_downcast_options(),
        ),
        (
            (
                pl.lit(
                    [{"field": _JAN1_AMSTERDAM}],
                    dtype=pl.Array(pl.Struct({"field": _MS_AMSTERDAM}), shape=1),
                ),
                pl.lit(
                    [{"field": _JAN2_SYDNEY}],
                    dtype=pl.Array(pl.Struct({"field": _NS_SYDNEY}), shape=1),
                ),
            ),
            pl.Series(
                [
                    [{"field": _JAN1_AMSTERDAM}],
                    [{"field": _JAN2_AS_AMSTERDAM}],
                ],
                dtype=pl.Array(pl.Struct({"field": _MS_AMSTERDAM}), shape=1),
            ),
            _tz_downcast_options(),
        ),
        # Test outer validity
        (
            (
                pl.lit(
                    None,
                    dtype=pl.List(pl.Struct({"field": _MS_AMSTERDAM})),
                ),
                pl.lit(
                    [None],
                    dtype=pl.List(pl.Struct({"field": _NS_SYDNEY})),
                ),
            ),
            pl.Series(
                [None, [None]],
                dtype=pl.List(pl.Struct({"field": _MS_AMSTERDAM})),
            ),
            _tz_downcast_options(),
        ),
        (
            (
                pl.lit(
                    None,
                    dtype=pl.Array(pl.Struct({"field": _MS_AMSTERDAM}), shape=1),
                ),
                pl.lit(
                    [None],
                    dtype=pl.Array(pl.Struct({"field": _NS_SYDNEY}), shape=1),
                ),
            ),
            pl.Series(
                [None, [None]],
                dtype=pl.Array(pl.Struct({"field": _MS_AMSTERDAM}), shape=1),
            ),
            _tz_downcast_options(),
        ),
    ],
)
def test_scan_cast_options(
    literal_values: tuple[pl.Expr, pl.Expr],
    expected: pl.Series,
    cast_options: pl.ScanCastOptions | None,
) -> None:
    """
    Scan two single-row files with differing dtypes under `cast_options`.

    Each literal is written to its own Parquet buffer. When `cast_options` is
    given, the scan must raise a `SchemaError` without it and yield `expected`
    with it. When it is `None`, only the final scan is checked.
    """
    expected = expected.alias("literal")
    lv1, lv2 = literal_values

    df1 = pl.select(lv1)
    df2 = pl.select(lv2)

    # `cast()` from the Python API should give the same results.
    assert_frame_equal(
        pl.concat(
            [
                df1.cast(expected.dtype),
                df2.cast(expected.dtype),
            ]
        ),
        expected.to_frame(),
    )

    files = _parquet_buffers(df1, df2)

    # Note: Schema is taken from the first file

    if cast_options is not None:
        q = pl.scan_parquet(files)

        with pytest.raises(SchemaError, match=r"hint: .*pass"):
            q.collect()

        _rewind(files)

    assert_frame_equal(
        pl.scan_parquet(files, cast_options=cast_options).collect(),
        expected.to_frame(),
    )


def test_scan_cast_options_forbid_int_downcast() -> None:
    """Check that `integer_cast="upcast"` still rejects Int32 -> Int8."""
    # Test to ensure that passing `integer_cast='upcast'` does not accidentally
    # permit casting to smaller integer types.
    lv1, lv2 = pl.lit(1, dtype=pl.Int8), pl.lit(2, dtype=pl.Int32)

    files = _parquet_buffers(pl.select(lv1), pl.select(lv2))

    q = pl.scan_parquet(files)

    with pytest.raises(SchemaError):
        q.collect()

    _rewind(files)

    q = pl.scan_parquet(
        files,
        cast_options=pl.ScanCastOptions(integer_cast="upcast"),
    )

    with pytest.raises(SchemaError):
        q.collect()


def test_scan_cast_options_extra_struct_fields() -> None:
    """Check that `extra_struct_fields="ignore"` drops an unexpected field."""
    cast_options = pl.ScanCastOptions(extra_struct_fields="ignore")

    expected = pl.Series([{"a": 1}, {"a": 2}], dtype=pl.Struct({"a": pl.Int32}))
    expected = expected.alias("literal")

    lv1, lv2 = (
        pl.lit({"a": 1}, dtype=pl.Struct({"a": pl.Int32})),
        pl.lit(
            {"a": 2, "extra_field": 1},
            dtype=pl.Struct({"a": pl.Int32, "extra_field": pl.Int32}),
        ),
    )

    files = _parquet_buffers(pl.select(lv1), pl.select(lv2))

    q = pl.scan_parquet(files)

    with pytest.raises(SchemaError, match=r"hint: specify .*or pass"):
        q.collect()

    _rewind(files)

    assert_frame_equal(
        pl.scan_parquet(files, cast_options=cast_options).collect(),
        expected.to_frame(),
    )


def test_cast_options_ignore_extra_columns() -> None:
    """Check `extra_columns="ignore"` against an explicitly passed schema."""
    files = _parquet_buffers(
        pl.DataFrame({"a": 1}),
        pl.DataFrame({"a": 2, "b": 1}),
    )

    with pytest.raises(
        SchemaError,
        match=r"extra column in file outside of expected schema: b, hint: specify.* or pass",
    ):
        pl.scan_parquet(files, schema={"a": pl.Int64}).collect()

    _rewind(files)

    assert_frame_equal(
        pl.scan_parquet(
            files,
            schema={"a": pl.Int64},
            extra_columns="ignore",
        ).collect(),
        pl.DataFrame({"a": [1, 2]}),
    )


@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        # TODO: Fix for all other formats
        # (pl.scan_ipc, pl.DataFrame.write_ipc),
        # (pl.scan_csv, pl.DataFrame.write_csv),
        # (pl.scan_ndjson, pl.DataFrame.write_ndjson),
    ],
)
def test_scan_cast_options_extra_columns(
    scan_func: Callable[[Any], pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, io.BytesIO], None],
) -> None:
    """
    Check `extra_columns="ignore"` when the schema comes from the first file.

    The second file carries an additional column `c`; scanning must raise by
    default and succeed, without `c`, once extra columns are ignored.
    """
    dfs = [pl.DataFrame({"a": 1, "b": 1}), pl.DataFrame({"a": 2, "b": 2, "c": 2})]
    files = [io.BytesIO(), io.BytesIO()]

    write_func(dfs[0], files[0])
    write_func(dfs[1], files[1])

    with pytest.raises(
        SchemaError,
        match=r"extra column in file outside of expected schema: c, hint: ",
    ):
        scan_func(files).collect()

    assert_frame_equal(
        scan_func(files, extra_columns="ignore").collect(),  # type: ignore[call-arg]
        pl.DataFrame({"a": [1, 2], "b": [1, 2]}),
    )


@pytest.mark.parametrize("float_dtype", sorted(FLOAT_DTYPES, key=repr))
def test_scan_cast_options_integer_to_float(float_dtype: pl.DataType) -> None:
    """
    Check that `integer_cast="allow-float"` permits reading Int64 as a float.

    Requesting a float schema for an Int64 column must raise a `SchemaError`
    unless the cast option is supplied.
    """
    df = pl.DataFrame({"a": [1]}, schema={"a": pl.Int64})
    f = io.BytesIO()
    df.write_parquet(f)

    f.seek(0)

    assert_frame_equal(
        pl.scan_parquet(f).collect(),
        pl.DataFrame({"a": [1]}, schema={"a": pl.Int64}),
    )

    # Rewind before every scan, matching the other tests in this module.
    f.seek(0)

    q = pl.scan_parquet(f, schema={"a": float_dtype})

    with pytest.raises(SchemaError):
        q.collect()

    f.seek(0)

    assert_frame_equal(
        pl.scan_parquet(
            f,
            schema={"a": float_dtype},
            cast_options=pl.ScanCastOptions(integer_cast="allow-float"),
        ).collect(),
        pl.DataFrame({"a": [1.0]}, schema={"a": float_dtype}),
    )