Path: blob/main/py-polars/tests/unit/io/test_scan_options.py
6939 views
from __future__ import annotations12import io3from datetime import datetime4from typing import IO, Any, Callable5from zoneinfo import ZoneInfo67import pytest89import polars as pl10from polars.testing import assert_frame_equal111213@pytest.mark.parametrize(14("literal_values", "expected", "cast_options"),15[16(17(pl.lit(1, dtype=pl.Int64), pl.lit(2, dtype=pl.Int32)),18pl.Series([1, 2], dtype=pl.Int64),19pl.ScanCastOptions(integer_cast="upcast"),20),21(22(pl.lit(1.0, dtype=pl.Float64), pl.lit(2.0, dtype=pl.Float32)),23pl.Series([1, 2], dtype=pl.Float64),24pl.ScanCastOptions(float_cast="upcast"),25),26(27(pl.lit(1.0, dtype=pl.Float32), pl.lit(2.0, dtype=pl.Float64)),28pl.Series([1, 2], dtype=pl.Float32),29pl.ScanCastOptions(float_cast=["upcast", "downcast"]),30),31(32(33pl.lit(datetime(2025, 1, 1), dtype=pl.Datetime(time_unit="ms")),34pl.lit(datetime(2025, 1, 2), dtype=pl.Datetime(time_unit="ns")),35),36pl.Series(37[datetime(2025, 1, 1), datetime(2025, 1, 2)],38dtype=pl.Datetime(time_unit="ms"),39),40pl.ScanCastOptions(datetime_cast="nanosecond-downcast"),41),42(43(44pl.lit(45datetime(2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")),46dtype=pl.Datetime(time_unit="ms", time_zone="Europe/Amsterdam"),47),48pl.lit(49datetime(2025, 1, 2, tzinfo=ZoneInfo("Australia/Sydney")),50dtype=pl.Datetime(time_unit="ns", time_zone="Australia/Sydney"),51),52),53pl.Series(54[55datetime(2025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")),56datetime(2025, 1, 1, 14, tzinfo=ZoneInfo("Europe/Amsterdam")),57],58dtype=pl.Datetime(time_unit="ms", time_zone="Europe/Amsterdam"),59),60pl.ScanCastOptions(61datetime_cast=["nanosecond-downcast", "convert-timezone"]62),63),64(65( # We also test nested primitive upcast policy with this one66pl.lit(67{"a": [[1]], "b": 1},68dtype=pl.Struct(69{"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}70),71),72pl.lit(73{"a": [[2]]},74dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int8, 1))}),75),76),77pl.Series(78[{"a": [[1]], "b": 1}, {"a": [[2]], "b": None}],79dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}),80),81pl.ScanCastOptions(82integer_cast="upcast",83missing_struct_fields="insert",84),85),86(87( # Test same set of struct fields but in different order88pl.lit(89{"a": [[1]], "b": 1},90dtype=pl.Struct(91{"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}92),93),94pl.lit(95{"b": None, "a": [[2]]},96dtype=pl.Struct(97{"b": pl.Int32, "a": pl.List(pl.Array(pl.Int32, 1))}98),99),100),101pl.Series(102[{"a": [[1]], "b": 1}, {"a": [[2]], "b": None}],103dtype=pl.Struct({"a": pl.List(pl.Array(pl.Int32, 1)), "b": pl.Int32}),104),105None,106),107# Test logical (datetime) type under list108(109(110pl.lit(111[112{113"field": datetime(1142025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")115)116}117],118dtype=pl.List(119pl.Struct(120{121"field": pl.Datetime(122time_unit="ms", time_zone="Europe/Amsterdam"123)124}125)126),127),128pl.lit(129[130{131"field": datetime(1322025, 1, 2, tzinfo=ZoneInfo("Australia/Sydney")133)134}135],136dtype=pl.List(137pl.Struct(138{139"field": pl.Datetime(140time_unit="ns", time_zone="Australia/Sydney"141)142}143)144),145),146),147pl.Series(148[149[150{151"field": datetime(1522025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")153)154}155],156[157{158"field": datetime(1592025, 1, 1, 14, tzinfo=ZoneInfo("Europe/Amsterdam")160)161}162],163],164dtype=pl.List(165pl.Struct(166{167"field": pl.Datetime(168time_unit="ms", time_zone="Europe/Amsterdam"169)170}171)172),173),174pl.ScanCastOptions(175datetime_cast=["nanosecond-downcast", "convert-timezone"]176),177),178(179(180pl.lit(181[182{183"field": datetime(1842025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")185)186}187],188dtype=pl.Array(189pl.Struct(190{191"field": pl.Datetime(192time_unit="ms", time_zone="Europe/Amsterdam"193)194}195),196shape=1,197),198),199pl.lit(200[201{202"field": datetime(2032025, 1, 2, tzinfo=ZoneInfo("Australia/Sydney")204)205}206],207dtype=pl.Array(208pl.Struct(209{210"field": pl.Datetime(211time_unit="ns", time_zone="Australia/Sydney"212)213}214),215shape=1,216),217),218),219pl.Series(220[221[222{223"field": datetime(2242025, 1, 1, tzinfo=ZoneInfo("Europe/Amsterdam")225)226}227],228[229{230"field": datetime(2312025, 1, 1, 14, tzinfo=ZoneInfo("Europe/Amsterdam")232)233}234],235],236dtype=pl.Array(237pl.Struct(238{239"field": pl.Datetime(240time_unit="ms", time_zone="Europe/Amsterdam"241)242}243),244shape=1,245),246),247pl.ScanCastOptions(248datetime_cast=["nanosecond-downcast", "convert-timezone"]249),250),251# Test outer validity252(253(254pl.lit(255None,256dtype=pl.List(257pl.Struct(258{259"field": pl.Datetime(260time_unit="ms", time_zone="Europe/Amsterdam"261)262}263)264),265),266pl.lit(267[None],268dtype=pl.List(269pl.Struct(270{271"field": pl.Datetime(272time_unit="ns", time_zone="Australia/Sydney"273)274}275)276),277),278),279pl.Series(280[None, [None]],281dtype=pl.List(282pl.Struct(283{284"field": pl.Datetime(285time_unit="ms", time_zone="Europe/Amsterdam"286)287}288)289),290),291pl.ScanCastOptions(292datetime_cast=["nanosecond-downcast", "convert-timezone"]293),294),295(296(297pl.lit(298None,299dtype=pl.Array(300pl.Struct(301{302"field": pl.Datetime(303time_unit="ms", time_zone="Europe/Amsterdam"304)305}306),307shape=1,308),309),310pl.lit(311[None],312dtype=pl.Array(313pl.Struct(314{315"field": pl.Datetime(316time_unit="ns", time_zone="Australia/Sydney"317)318}319),320shape=1,321),322),323),324pl.Series(325[None, [None]],326dtype=pl.Array(327pl.Struct(328{329"field": pl.Datetime(330time_unit="ms", time_zone="Europe/Amsterdam"331)332}333),334shape=1,335),336),337pl.ScanCastOptions(338datetime_cast=["nanosecond-downcast", "convert-timezone"]339),340),341],342)343def test_scan_cast_options(344literal_values: tuple[pl.Expr, pl.Expr],345expected: pl.Series,346cast_options: pl.ScanCastOptions | None,347) -> None:348expected = expected.alias("literal")349lv1, lv2 = literal_values350351df1 = pl.select(lv1)352df2 = pl.select(lv2)353354# `cast()` from the Python API should give the same results.355assert_frame_equal(356pl.concat(357[358df1.cast(expected.dtype),359df2.cast(expected.dtype),360]361),362expected.to_frame(),363)364365files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]366367df1.write_parquet(files[0])368df2.write_parquet(files[1])369370for f in files:371f.seek(0)372373# Note: Schema is taken from the first file374375if cast_options is not None:376q = pl.scan_parquet(files)377378with pytest.raises(pl.exceptions.SchemaError, match=r"hint: .*pass"):379q.collect()380381assert_frame_equal(382pl.scan_parquet(files, cast_options=cast_options).collect(),383expected.to_frame(),384)385386387def test_scan_cast_options_forbid_int_downcast() -> None:388# Test to ensure that passing `integer_cast='upcast'` does not accidentally389# permit casting to smaller integer types.390lv1, lv2 = pl.lit(1, dtype=pl.Int8), pl.lit(2, dtype=pl.Int32)391392files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]393394df1 = pl.select(lv1)395df2 = pl.select(lv2)396397df1.write_parquet(files[0])398df2.write_parquet(files[1])399400for f in files:401f.seek(0)402403q = pl.scan_parquet(files)404405with pytest.raises(pl.exceptions.SchemaError):406q.collect()407408for f in files:409f.seek(0)410411q = pl.scan_parquet(412files,413cast_options=pl.ScanCastOptions(integer_cast="upcast"),414)415416with pytest.raises(pl.exceptions.SchemaError):417q.collect()418419420def test_scan_cast_options_extra_struct_fields() -> None:421cast_options = pl.ScanCastOptions(extra_struct_fields="ignore")422423expected = pl.Series([{"a": 1}, {"a": 2}], dtype=pl.Struct({"a": pl.Int32}))424expected = expected.alias("literal")425426lv1, lv2 = (427pl.lit({"a": 1}, dtype=pl.Struct({"a": pl.Int32})),428pl.lit(429{"a": 2, "extra_field": 1},430dtype=pl.Struct({"a": pl.Int32, "extra_field": pl.Int32}),431),432)433434files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]435436df1 = pl.select(lv1)437df2 = pl.select(lv2)438439df1.write_parquet(files[0])440df2.write_parquet(files[1])441442for f in files:443f.seek(0)444445q = pl.scan_parquet(files)446447with pytest.raises(pl.exceptions.SchemaError, match=r"hint: specify .*or pass"):448q.collect()449450assert_frame_equal(451pl.scan_parquet(files, cast_options=cast_options).collect(),452expected.to_frame(),453)454455456def test_cast_options_ignore_extra_columns() -> None:457files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]458459pl.DataFrame({"a": 1}).write_parquet(files[0])460pl.DataFrame({"a": 2, "b": 1}).write_parquet(files[1])461462with pytest.raises(463pl.exceptions.SchemaError,464match="extra column in file outside of expected schema: b, hint: specify.* or pass",465):466pl.scan_parquet(files, schema={"a": pl.Int64}).collect()467468assert_frame_equal(469pl.scan_parquet(470files,471schema={"a": pl.Int64},472extra_columns="ignore",473).collect(),474pl.DataFrame({"a": [1, 2]}),475)476477478@pytest.mark.parametrize(479("scan_func", "write_func"),480[481(pl.scan_parquet, pl.DataFrame.write_parquet),482# TODO: Fix for all other formats483# (pl.scan_ipc, pl.DataFrame.write_ipc),484# (pl.scan_csv, pl.DataFrame.write_csv),485# (pl.scan_ndjson, pl.DataFrame.write_ndjson),486],487)488def test_scan_extra_columns(489scan_func: Callable[[Any], pl.LazyFrame],490write_func: Callable[[pl.DataFrame, io.BytesIO], None],491) -> None:492dfs = [pl.DataFrame({"a": 1, "b": 1}), pl.DataFrame({"a": 2, "b": 2, "c": 2})]493files = [io.BytesIO(), io.BytesIO()]494495write_func(dfs[0], files[0])496write_func(dfs[1], files[1])497498with pytest.raises(499pl.exceptions.SchemaError,500match=r"extra column in file outside of expected schema: c, hint: ",501):502scan_func(files).collect()503504assert_frame_equal(505scan_func(files, extra_columns="ignore").collect(), # type: ignore[call-arg]506pl.DataFrame({"a": [1, 2], "b": [1, 2]}),507)508509510