# Path: blob/main/py-polars/tests/unit/io/test_io_plugin.py
# 6939 views
from __future__ import annotations12import datetime3import io4import subprocess5import sys6from typing import TYPE_CHECKING78import numpy as np9import pytest1011import polars as pl12from polars.io.plugins import register_io_source13from polars.testing import assert_frame_equal, assert_series_equal1415if TYPE_CHECKING:16from collections.abc import Iterator171819def test_io_plugin_predicate_no_serialization_21130() -> None:20def custom_io() -> pl.LazyFrame:21def source_generator(22with_columns: list[str] | None,23predicate: pl.Expr | None,24n_rows: int | None,25batch_size: int | None,26) -> Iterator[pl.DataFrame]:27df = pl.DataFrame(28{"json_val": ['{"a":"1"}', None, '{"a":2}', '{"a":2.1}', '{"a":true}']}29)30if predicate is not None:31df = df.filter(predicate)32if batch_size and df.height > batch_size:33yield from df.iter_slices(n_rows=batch_size)34else:35yield df3637return register_io_source(38io_source=source_generator, schema={"json_val": pl.String}39)4041lf = custom_io()42assert lf.filter(43pl.col("json_val").str.json_path_match("$.a").is_in(["1"])44).collect().to_dict(as_series=False) == {"json_val": ['{"a":"1"}']}454647def test_defer_validate_true() -> None:48lf = pl.defer(49lambda: pl.DataFrame({"a": np.ones(3)}),50schema={"a": pl.Boolean},51validate_schema=True,52)53with pytest.raises(pl.exceptions.SchemaError):54lf.collect()555657@pytest.mark.may_fail_cloud58@pytest.mark.may_fail_auto_streaming # IO plugin validate=False schema mismatch59def test_defer_validate_false() -> None:60lf = pl.defer(61lambda: pl.DataFrame({"a": np.ones(3)}),62schema={"a": pl.Boolean},63validate_schema=False,64)65assert lf.collect().to_dict(as_series=False) == {"a": [1.0, 1.0, 1.0]}666768def test_empty_iterator_io_plugin() -> None:69def _io_source(70with_columns: list[str] | None,71predicate: pl.Expr | None,72n_rows: int | None,73batch_size: int | None,74) -> Iterator[pl.DataFrame]:75yield from []7677schema = pl.Schema([("a", pl.Int64)])78df = register_io_source(_io_source, 
schema=schema)79assert df.collect().schema == schema808182def test_scan_lines() -> None:83def scan_lines(f: io.BytesIO) -> pl.LazyFrame:84schema = pl.Schema({"lines": pl.String()})8586def generator(87with_columns: list[str] | None,88predicate: pl.Expr | None,89n_rows: int | None,90batch_size: int | None,91) -> Iterator[pl.DataFrame]:92x = f93if batch_size is None:94batch_size = 100_0009596batch_lines: list[str] = []97while n_rows != 0:98batch_lines.clear()99remaining_rows = batch_size100if n_rows is not None:101remaining_rows = min(remaining_rows, n_rows)102n_rows -= remaining_rows103104while remaining_rows != 0 and (line := x.readline().rstrip()):105if isinstance(line, str):106batch_lines += [batch_lines]107else:108batch_lines += [line.decode()]109remaining_rows -= 1110111df = pl.Series("lines", batch_lines, pl.String()).to_frame()112113if with_columns is not None:114df = df.select(with_columns)115if predicate is not None:116df = df.filter(predicate)117118yield df119120if remaining_rows != 0:121break122123return register_io_source(io_source=generator, schema=schema)124125text = """126Hello127This is some text128It is spread over multiple lines129This allows it to read into multiple rows.130""".strip()131f = io.BytesIO(bytes(text, encoding="utf-8"))132133assert_series_equal(134scan_lines(f).collect().to_series(),135pl.Series("lines", text.splitlines(), pl.String()),136)137138139@pytest.mark.may_fail_cloud140@pytest.mark.may_fail_auto_streaming # IO plugin validate=False schema mismatch141def test_datetime_io_predicate_pushdown_21790() -> None:142recorded: dict[str, pl.Expr | None] = {"predicate": None}143df = pl.DataFrame(144{145"timestamp": [146datetime.datetime(2024, 1, 1, 0),147datetime.datetime(2024, 1, 3, 0),148]149}150)151152def _source(153with_columns: list[str] | None,154predicate: pl.Expr | None,155n_rows: int | None,156batch_size: int | None,157) -> Iterator[pl.DataFrame]:158# capture the predicate passed in159recorded["predicate"] = predicate160inner_df 
= df.clone()161if with_columns is not None:162inner_df = inner_df.select(with_columns)163if predicate is not None:164inner_df = inner_df.filter(predicate)165166yield inner_df167168schema = {"timestamp": pl.Datetime(time_unit="ns")}169lf = register_io_source(io_source=_source, schema=schema)170171cutoff = datetime.datetime(2024, 1, 4)172expr = pl.col("timestamp") < cutoff173filtered_df = lf.filter(expr).collect()174175pushed_predicate = recorded["predicate"]176assert pushed_predicate is not None177assert_series_equal(filtered_df.to_series(), df.filter(expr).to_series())178179# check the expression directly180dt_val, column_cast = pushed_predicate.meta.pop()181# Extract the datetime value from the expression182assert pl.DataFrame({}).select(dt_val).item() == cutoff183184column = column_cast.meta.pop()[0]185assert column.meta == pl.col("timestamp")186187188@pytest.mark.parametrize(("validate"), [(True), (False)])189def test_reordered_columns_22731(validate: bool) -> None:190def my_scan() -> pl.LazyFrame:191schema = pl.Schema({"a": pl.Int64, "b": pl.Int64})192193def source_generator(194with_columns: list[str] | None,195predicate: pl.Expr | None,196n_rows: int | None,197batch_size: int | None,198) -> Iterator[pl.DataFrame]:199df = pl.DataFrame({"a": [1, 2, 3], "b": [42, 13, 37]})200201if n_rows is not None:202df = df.head(min(n_rows, df.height))203204maxrows = 1205if batch_size is not None:206maxrows = batch_size207208while df.height > 0:209maxrows = min(maxrows, df.height)210cur = df.head(maxrows)211df = df.slice(maxrows)212213if predicate is not None:214cur = cur.filter(predicate)215if with_columns is not None:216cur = cur.select(with_columns)217218yield cur219220return register_io_source(221io_source=source_generator, schema=schema, validate_schema=validate222)223224expected_select = pl.DataFrame({"b": [42, 13, 37], "a": [1, 2, 3]})225assert_frame_equal(my_scan().select("b", "a").collect(), expected_select)226227expected_ri = pl.DataFrame({"b": [42, 13, 37], "a": [1, 
2, 3]}).with_row_index()228assert_frame_equal(229my_scan().select("b", "a").with_row_index().collect(),230expected_ri,231)232233expected_with_columns = pl.DataFrame({"a": [1, 2, 3], "b": [42, 13, 37]})234assert_frame_equal(235my_scan().with_columns("b", "a").collect(), expected_with_columns236)237238239def test_io_plugin_reentrant_deadlock() -> None:240out = subprocess.check_output(241[242sys.executable,243"-c",244"""\245from __future__ import annotations246247import os248import sys249250os.environ["POLARS_MAX_THREADS"] = "1"251252import polars as pl253from polars.io.plugins import register_io_source254255assert pl.thread_pool_size() == 1256257n = 3258i = 0259260261def reentrant(262with_columns: list[str] | None,263predicate: pl.Expr | None,264n_rows: int | None,265batch_size: int | None,266):267global i268269df = pl.DataFrame({"x": 1})270271if i < n:272i += 1273yield register_io_source(io_source=reentrant, schema={"x": pl.Int64}).collect()274275yield df276277278register_io_source(io_source=reentrant, schema={"x": pl.Int64}).collect()279280print("OK", end="", file=sys.stderr)281""",282],283stderr=subprocess.STDOUT,284timeout=7,285)286287assert out == b"OK"288289290def test_io_plugin_categorical_24172() -> None:291schema = {"cat": pl.Categorical}292293df = pl.concat(294[295pl.DataFrame({"cat": ["X", "Y"]}, schema=schema),296pl.DataFrame({"cat": ["X", "Y"]}, schema=schema),297],298rechunk=False,299)300301assert df.n_chunks() == 2302303assert_frame_equal(304register_io_source(lambda *_: iter([df]), schema=df.schema).collect(),305df,306)307308309