Path: blob/main/py-polars/tests/unit/io/test_multiscan.py

from __future__ import annotations

import io
from functools import partial
from typing import IO, TYPE_CHECKING, Any, Callable

import pyarrow.parquet as pq
import pytest
from hypothesis import given
from hypothesis import strategies as st

import polars as pl
from polars.meta.index_type import get_index_type
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from pathlib import Path

SCAN_AND_WRITE_FUNCS = [
    (pl.scan_ipc, pl.DataFrame.write_ipc),
    (pl.scan_parquet, pl.DataFrame.write_parquet),
    (pl.scan_csv, pl.DataFrame.write_csv),
    (pl.scan_ndjson, pl.DataFrame.write_ndjson),
]


@pytest.mark.write_disk
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_include_file_paths(tmp_path: Path, scan: Any, write: Any) -> None:
    a_path = tmp_path / "a"
    b_path = tmp_path / "b"

    write(pl.DataFrame({"a": [5, 10]}), a_path)
    write(pl.DataFrame({"a": [1996]}), b_path)

    out = scan([a_path, b_path], include_file_paths="f")

    assert_frame_equal(
        out.collect(),
        pl.DataFrame(
            {
                "a": [5, 10, 1996],
                "f": [str(a_path), str(a_path), str(b_path)],
            }
        ),
    )
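

# test_multiscan_projection below cross-products every optional output column
# (an inserted missing column, a row index, file paths, hive columns, and the
# regular file column) and checks that pushing the projection into the scan
# matches selecting the same columns after a full streaming collect.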
@pytest.mark.parametrize(
    ("scan", "write", "ext", "supports_missing_columns", "supports_hive_partitioning"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc", False, True),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet", True, True),
        (pl.scan_csv, pl.DataFrame.write_csv, "csv", False, False),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl", False, False),
    ],
)
@pytest.mark.parametrize("missing_column", [False, True])
@pytest.mark.parametrize("row_index", [False, True])
@pytest.mark.parametrize("include_file_paths", [False, True])
@pytest.mark.parametrize("hive", [False, True])
@pytest.mark.parametrize("col", [False, True])
@pytest.mark.write_disk
def test_multiscan_projection(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
    supports_missing_columns: bool,
    supports_hive_partitioning: bool,
    missing_column: bool,
    row_index: bool,
    include_file_paths: bool,
    hive: bool,
    col: bool,
) -> None:
    a = pl.DataFrame({"col": [5, 10, 1996]})
    b = pl.DataFrame({"col": [13, 37]})

    if missing_column and supports_missing_columns:
        a = a.with_columns(missing=pl.Series([420, 2000, 9]))

    a_path: Path
    b_path: Path
    multiscan_path: Path

    if hive and supports_hive_partitioning:
        (tmp_path / "hive_col=0").mkdir()
        a_path = tmp_path / "hive_col=0" / f"a.{ext}"
        (tmp_path / "hive_col=1").mkdir()
        b_path = tmp_path / "hive_col=1" / f"b.{ext}"

        multiscan_path = tmp_path

    else:
        a_path = tmp_path / f"a.{ext}"
        b_path = tmp_path / f"b.{ext}"

        multiscan_path = tmp_path / f"*.{ext}"

    write(a, a_path)
    write(b, b_path)

    base_projection = []
    if missing_column and supports_missing_columns:
        base_projection += ["missing"]
    if row_index:
        base_projection += ["row_index"]
    if include_file_paths:
        base_projection += ["file_path"]
    if hive and supports_hive_partitioning:
        base_projection += ["hive_col"]
    if col:
        base_projection += ["col"]

    ifp = "file_path" if include_file_paths else None
    ri = "row_index" if row_index else None

    args = {
        "missing_columns": "insert" if missing_column else "raise",
        "include_file_paths": ifp,
        "row_index_name": ri,
        "hive_partitioning": hive,
    }

    if not supports_missing_columns:
        del args["missing_columns"]
    if not supports_hive_partitioning:
        del args["hive_partitioning"]

    for projection in [
        base_projection,
        base_projection[::-1],
    ]:
        assert_frame_equal(
            scan(multiscan_path, **args).collect(engine="streaming").select(projection),
            scan(multiscan_path, **args).select(projection).collect(engine="streaming"),
        )

    for remove in range(len(base_projection)):
        new_projection = base_projection.copy()
        new_projection.pop(remove)

        for projection in [
            new_projection,
            new_projection[::-1],
        ]:
            print(projection)
            assert_frame_equal(
                scan(multiscan_path, **args)
                .collect(engine="streaming")
                .select(projection),
                scan(multiscan_path, **args)
                .select(projection)
                .collect(engine="streaming"),
            )


@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"),
    ],
)
@pytest.mark.write_disk
def test_multiscan_hive_predicate(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
) -> None:
    a = pl.DataFrame({"col": [5, 10, 1996]})
    b = pl.DataFrame({"col": [13, 37]})
    c = pl.DataFrame({"col": [3, 5, 2024]})

    (tmp_path / "hive_col=0").mkdir()
    a_path = tmp_path / "hive_col=0" / f"0.{ext}"
    (tmp_path / "hive_col=1").mkdir()
    b_path = tmp_path / "hive_col=1" / f"0.{ext}"
    (tmp_path / "hive_col=2").mkdir()
    c_path = tmp_path / "hive_col=2" / f"0.{ext}"

    multiscan_path = tmp_path

    write(a, a_path)
    write(b, b_path)
    write(c, c_path)

    full = scan(multiscan_path).collect(engine="streaming")
    full_ri = full.with_row_index("ri", 42)

    last_pred = None
    try:
        for pred in [
            pl.col.hive_col == 0,
            pl.col.hive_col == 1,
            pl.col.hive_col == 2,
            pl.col.hive_col < 2,
            pl.col.hive_col > 0,
            pl.col.hive_col != 1,
            pl.col.hive_col != 3,
            pl.col.col == 13,
            pl.col.col != 13,
            (pl.col.col != 13) & (pl.col.hive_col == 1),
            (pl.col.col != 13) & (pl.col.hive_col != 1),
        ]:
            last_pred = pred
            assert_frame_equal(
                full.filter(pred),
                scan(multiscan_path).filter(pred).collect(engine="streaming"),
            )

            assert_frame_equal(
                full_ri.filter(pred),
                scan(multiscan_path)
                .with_row_index("ri", 42)
                .filter(pred)
                .collect(engine="streaming"),
            )
    except Exception as _:
        print(last_pred)
        raise
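

# Row indices over a multiscan are global across all files, not per file: with
# three files of 3, 1, and 2 rows, `row_index_name="ri"` yields 0..5, and
# `row_index_offset=42` shifts that to 42..47. Slices and filters must keep
# the global index, which the next test checks. A minimal sketch (file names
# hypothetical):
#
#   pl.scan_parquet(["a.parquet", "b.parquet"], row_index_name="ri")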
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
@pytest.mark.write_disk
def test_multiscan_row_index(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
) -> None:
    a = pl.DataFrame({"col": [5, 10, 1996]})
    b = pl.DataFrame({"col": [42]})
    c = pl.DataFrame({"col": [13, 37]})

    write(a, tmp_path / "a")
    write(b, tmp_path / "b")
    write(c, tmp_path / "c")

    col = pl.concat([a, b, c]).to_series()
    g = tmp_path / "*"

    assert_frame_equal(
        scan(g, row_index_name="ri").collect(),
        pl.DataFrame(
            [
                pl.Series("ri", range(6), get_index_type()),
                col,
            ]
        ),
    )

    start = 42
    assert_frame_equal(
        scan(g, row_index_name="ri", row_index_offset=start).collect(),
        pl.DataFrame(
            [
                pl.Series("ri", range(start, start + 6), get_index_type()),
                col,
            ]
        ),
    )

    start = 42
    assert_frame_equal(
        scan(g, row_index_name="ri", row_index_offset=start).slice(3, 3).collect(),
        pl.DataFrame(
            [
                pl.Series("ri", range(start + 3, start + 6), get_index_type()),
                col.slice(3, 3),
            ]
        ),
    )

    start = 42
    assert_frame_equal(
        scan(g, row_index_name="ri", row_index_offset=start)
        .filter(pl.col("col") < 15)
        .collect(),
        pl.DataFrame(
            [
                pl.Series("ri", [start + 0, start + 1, start + 4], get_index_type()),
                pl.Series("col", [5, 10, 13]),
            ]
        ),
    )


@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"),
        pytest.param(
            pl.scan_csv,
            pl.DataFrame.write_csv,
            "csv",
            marks=pytest.mark.xfail(
                reason="See https://github.com/pola-rs/polars/issues/21211"
            ),
        ),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"),
    ],
)
@pytest.mark.write_disk
def test_schema_mismatch_type_mismatch(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
) -> None:
    a = pl.DataFrame({"xyz_col": [5, 10, 1996]})
    b = pl.DataFrame({"xyz_col": ["a", "b", "c"]})

    a_path = tmp_path / f"a.{ext}"
    b_path = tmp_path / f"b.{ext}"

    multiscan_path = tmp_path / f"*.{ext}"

    write(a, a_path)
    write(b, b_path)

    q = scan(multiscan_path)

    # NDJSON will just parse according to `projected_schema`
    cx = (
        pytest.raises(pl.exceptions.ComputeError, match="cannot parse 'a' as Int64")
        if scan is pl.scan_ndjson
        else pytest.raises(
            pl.exceptions.SchemaError,  # type: ignore[arg-type]
            match=(
                "data type mismatch for column xyz_col: "
                "incoming: String != target: Int64"
            ),
        )
    )

    with cx:
        q.collect(engine="streaming")


@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        # (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"), # TODO: _
        # (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"), # TODO: _
        pytest.param(
            pl.scan_csv,
            pl.DataFrame.write_csv,
            "csv",
            marks=pytest.mark.xfail(
                reason="See https://github.com/pola-rs/polars/issues/21211"
            ),
        ),
        # (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"), # TODO: _
    ],
)
@pytest.mark.write_disk
def test_schema_mismatch_order_mismatch(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
) -> None:
    a = pl.DataFrame({"x": [5, 10, 1996], "y": ["a", "b", "c"]})
    b = pl.DataFrame({"y": ["x", "y"], "x": [1, 2]})

    a_path = tmp_path / f"a.{ext}"
    b_path = tmp_path / f"b.{ext}"

    multiscan_path = tmp_path / f"*.{ext}"

    write(a, a_path)
    write(b, b_path)

    q = scan(multiscan_path)

    with pytest.raises(pl.exceptions.SchemaError):
        q.collect(engine="streaming")


@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_multiscan_head(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
) -> None:
    a = io.BytesIO()
    b = io.BytesIO()
    for f in [a, b]:
        write(pl.Series("c1", range(10)).to_frame(), f)
        f.seek(0)

    assert_frame_equal(
        scan([a, b]).head(5).collect(engine="streaming"),
        pl.Series("c1", range(5)).to_frame(),
    )


@pytest.mark.parametrize(
    ("scan", "write"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc),
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson),
        (
            pl.scan_csv,
            pl.DataFrame.write_csv,
        ),
    ],
)
def test_multiscan_tail(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
) -> None:
    a = io.BytesIO()
    b = io.BytesIO()
    for f in [a, b]:
        write(pl.Series("c1", range(10)).to_frame(), f)
        f.seek(0)

    assert_frame_equal(
        scan([a, b]).tail(5).collect(engine="streaming"),
        pl.Series("c1", range(5, 10)).to_frame(),
    )
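

# test_multiscan_slice_middle writes 13 files of 7 rows each and slices 17
# rows starting at offset 5 * 7 - 5 = 30: the last 5 rows of fs[4], all 7
# rows of fs[5], and the first 5 rows of fs[6]. The negative offset
# -(13 * 7 - 30) = -61 addresses the same window counted from the end.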
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_multiscan_slice_middle(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
) -> None:
    fs = [io.BytesIO() for _ in range(13)]
    for f in fs:
        write(pl.Series("c1", range(7)).to_frame(), f)
        f.seek(0)

    offset = 5 * 7 - 5
    expected = (
        list(range(2, 7))  # fs[4]
        + list(range(7))  # fs[5]
        + list(range(5))  # fs[6]
    )
    expected_series = [pl.Series("c1", expected)]
    ri_expected_series = [
        pl.Series("ri", range(offset, offset + 17), get_index_type())
    ] + expected_series

    assert_frame_equal(
        scan(fs).slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(expected_series),
    )
    assert_frame_equal(
        scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(ri_expected_series),
    )

    # Negative slices
    offset = -(13 * 7 - offset)
    assert_frame_equal(
        scan(fs).slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(expected_series),
    )
    assert_frame_equal(
        scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(ri_expected_series),
    )


@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
@given(offset=st.integers(-100, 100), length=st.integers(0, 101))
def test_multiscan_slice_parametric(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
    offset: int,
    length: int,
) -> None:
    ref = io.BytesIO()
    write(pl.Series("c1", [i % 7 for i in range(13 * 7)]).to_frame(), ref)
    ref.seek(0)

    fs = [io.BytesIO() for _ in range(13)]
    for f in fs:
        write(pl.Series("c1", range(7)).to_frame(), f)
        f.seek(0)

    assert_frame_equal(
        scan(ref).slice(offset, length).collect(),
        scan(fs).slice(offset, length).collect(engine="streaming"),
    )

    ref.seek(0)
    for f in fs:
        f.seek(0)

    assert_frame_equal(
        scan(ref, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .collect(),
        scan(fs, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .collect(engine="streaming"),
    )

    assert_frame_equal(
        scan(ref, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .select("ri")
        .collect(),
        scan(fs, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .select("ri")
        .collect(engine="streaming"),
    )


@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_many_files(scan: Any, write: Any) -> None:
    f = io.BytesIO()
    write(pl.DataFrame({"a": [5, 10, 1996]}), f)
    bs = f.getvalue()

    out = scan([bs] * 1023)

    assert_frame_equal(
        out.collect(),
        pl.DataFrame(
            {
                "a": [5, 10, 1996] * 1023,
            }
        ),
    )


def test_deadlock_stop_requested(monkeypatch: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        }
    )

    f = io.BytesIO()
    df.write_parquet(f, row_group_size=1)

    monkeypatch.setenv("POLARS_MAX_THREADS", "2")
    monkeypatch.setenv("POLARS_JOIN_SAMPLE_LIMIT", "1")

    left_fs = [io.BytesIO(f.getbuffer()) for _ in range(10)]
    right_fs = [io.BytesIO(f.getbuffer()) for _ in range(10)]

    left = pl.scan_parquet(left_fs)  # type: ignore[arg-type]
    right = pl.scan_parquet(right_fs)  # type: ignore[arg-type]

    left.join(right, pl.col.a == pl.col.a).collect(engine="streaming").height
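

# Regression test: with slice pushdown disabled, head(100) cannot stop the
# scan early, so all 10 sources must be read and their output linearized in
# order; judging by the test name, this path previously risked a deadlock.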
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_deadlock_linearize(scan: Any, write: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        }
    )

    f = io.BytesIO()
    write(df, f)
    fs = [io.BytesIO(f.getbuffer()) for _ in range(10)]
    lf = scan(fs).head(100)

    assert_frame_equal(
        lf.collect(
            engine="streaming", optimizations=pl.QueryOptFlags(slice_pushdown=False)
        ),
        pl.concat([df] * 10),
    )


@pytest.mark.parametrize(
    ("scan", "write"),
    SCAN_AND_WRITE_FUNCS,
)
def test_row_index_filter_22612(scan: Any, write: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        }
    )

    f = io.BytesIO()

    if write is pl.DataFrame.write_parquet:
        df.write_parquet(f, row_group_size=5)
        assert pq.read_metadata(f).num_row_groups == 2
    else:
        write(df, f)

    for end in range(2, 10):
        assert_frame_equal(
            scan(f)
            .with_row_index()
            .filter(pl.col("index") >= end - 2, pl.col("index") <= end)
            .collect(),
            df.with_row_index().slice(end - 2, 3),
        )

        assert_frame_equal(
            scan(f)
            .with_row_index()
            .filter(pl.col("index").is_between(end - 2, end))
            .collect(),
            df.with_row_index().slice(end - 2, 3),
        )


@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_row_index_name_in_file(scan: Any, write: Any) -> None:
    f = io.BytesIO()
    write(pl.DataFrame({"index": 1}), f)

    with pytest.raises(
        pl.exceptions.DuplicateError,
        match="cannot add row_index with name 'index': column already exists in file",
    ):
        scan(f).with_row_index().collect()


def test_extra_columns_not_ignored_22218() -> None:
    dfs = [pl.DataFrame({"a": 1, "b": 1}), pl.DataFrame({"a": 2, "c": 2})]

    files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    dfs[0].write_parquet(files[0])
    dfs[1].write_parquet(files[1])

    with pytest.raises(
        pl.exceptions.SchemaError,
        match="extra column in file outside of expected schema: c, hint: specify .*or pass",
    ):
        (pl.scan_parquet(files, missing_columns="insert").select(pl.all()).collect())

    assert_frame_equal(
        pl.scan_parquet(
            files,
            missing_columns="insert",
            extra_columns="ignore",
        )
        .select(pl.all())
        .collect(),
        pl.DataFrame({"a": [1, 2], "b": [1, None]}),
    )


@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_scan_null_upcast(scan: Any, write: Any) -> None:
    dfs = [
        pl.DataFrame({"a": [1, 2, 3]}),
        pl.select(a=pl.lit(None, dtype=pl.Null)),
    ]

    files = [io.BytesIO(), io.BytesIO()]

    write(dfs[0], files[0])
    write(dfs[1], files[1])

    # Prevent CSV schema inference from loading as string (it looks at multiple
    # files).
    if scan is pl.scan_csv:
        scan = partial(scan, schema=dfs[0].schema)

    assert_frame_equal(
        scan(files).collect(),
        pl.DataFrame({"a": [1, 2, 3, None]}),
    )
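

# Same upcast as above, but the all-Null column must widen to a nested
# List(Struct) dtype rather than a flat integer column. CSV is left out of
# the parametrization because it cannot represent nested types.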
@pytest.mark.parametrize(
    ("scan", "write"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc),
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson),
    ],
)
def test_scan_null_upcast_to_nested(scan: Any, write: Any) -> None:
    schema = {"a": pl.List(pl.Struct({"field": pl.Int64}))}

    dfs = [
        pl.DataFrame(
            {"a": [[{"field": 1}], [{"field": 2}], []]},
            schema=schema,
        ),
        pl.select(a=pl.lit(None, dtype=pl.Null)),
    ]

    files = [io.BytesIO(), io.BytesIO()]

    write(dfs[0], files[0])
    write(dfs[1], files[1])

    # Prevent CSV schema inference from loading as string (it looks at multiple
    # files).
    if scan is pl.scan_csv:
        scan = partial(scan, schema=schema)

    assert_frame_equal(
        scan(files).collect(),
        pl.DataFrame(
            {"a": [[{"field": 1}], [{"field": 2}], [], None]},
            schema=schema,
        ),
    )
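

# A minimal sketch of the multiscan pattern exercised throughout this file
# (paths and column names are hypothetical):
#
#   lf = pl.scan_parquet(
#       ["data/a.parquet", "data/b.parquet"],
#       row_index_name="ri",
#       include_file_paths="path",
#   )
#   df = lf.filter(pl.col("ri") > 10).collect(engine="streaming")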