# Path: blob/main/py-polars/tests/unit/io/test_skip_batch_predicate.py
# 6939 views
from __future__ import annotations

import contextlib
import datetime
from typing import TYPE_CHECKING, Any, TypedDict

from hypothesis import Phase, given, settings

import polars as pl
from polars.meta import get_index_type
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric.strategies import series

if TYPE_CHECKING:
    from collections.abc import Sequence

    from polars._typing import PythonLiteral


class Case(TypedDict):
    """A test case for Skip Batch Predicate."""

    # Statistics describing one batch; any of them may be missing (None).
    min: Any | None
    max: Any | None
    null_count: int | None
    len: int | None
    # Expected outcome: True when the predicate should report the batch as
    # skippable.
    can_skip: bool


def assert_skp_series(
    name: str,
    dtype: pl.DataType,
    expr: pl.Expr,
    cases: Sequence[Case],
) -> None:
    """
    Check the skip-batch predicate of `expr` against a list of cases.

    The predicate is obtained with `expr._skip_batch_predicate({name: dtype})`
    and evaluated on a frame holding one row per case, with the columns
    `{name}_min`, `{name}_max`, `{name}_nc` and `len`. A null result is treated
    as "cannot skip" before it is compared with each case's `can_skip`.

    Parameters
    ----------
    name
        Name of the column that `expr` refers to.
    dtype
        Data type of that column; also used for the min / max columns.
    expr
        Expression whose skip-batch predicate is tested.
    cases
        Batch statistics together with the expected `can_skip` value.

    Raises
    ------
    AssertionError
        If the evaluated predicate differs from the expected mask. The
        predicate is printed first to ease debugging.
    """
    sbp = expr._skip_batch_predicate({name: dtype})

    stats = pl.DataFrame(
        [
            pl.Series(f"{name}_min", [case["min"] for case in cases], dtype),
            pl.Series(f"{name}_max", [case["max"] for case in cases], dtype),
            pl.Series(
                f"{name}_nc",
                [case["null_count"] for case in cases],
                get_index_type(),
            ),
            pl.Series("len", [case["len"] for case in cases], get_index_type()),
        ]
    )
    expected = pl.Series(
        "can_skip", [case["can_skip"] for case in cases], pl.Boolean
    )

    # A null predicate result is mapped to False, i.e. "do not skip".
    result = df_result = stats.select(can_skip=sbp).to_series()
    result = df_result.replace(None, False)

    try:
        assert_series_equal(result, expected)
    except AssertionError:
        print(sbp)
        raise


def test_true_false_predicate() -> None:
    """Literal predicates: `True` never skips; `False` and null always skip."""
    true_sbp = pl.lit(True)._skip_batch_predicate({})
    false_sbp = pl.lit(False)._skip_batch_predicate({})
    null_sbp = pl.lit(None)._skip_batch_predicate({})

    df = pl.DataFrame({"len": [1]})

    out = df.select(
        true=true_sbp,
        false=false_sbp,
        null=null_sbp,
    )

    assert_frame_equal(
        out,
        pl.DataFrame(
            {
                "true": [False],
                "false": [True],
                "null": [True],
            }
        ),
    )


def test_equality() -> None:
    """Skip-batch predicates for `==` and `!=` on an Int64 column."""
    assert_skp_series(
        "a",
        pl.Int64(),
        pl.col("a") == 5,
        [
            {"min": 1, "max": 2, "null_count": 0, "len": 42, "can_skip": True},
            {"min": 6, "max": 7, "null_count": 0, "len": 42, "can_skip": True},
            {"min": 1, "max": 7, "null_count": 0, "len": 42, "can_skip": False},
            {
                "min": None,
                "max": None,
                "null_count": 42,
                "len": 42,
                "can_skip": True,
            },
        ],
    )

    assert_skp_series(
        "a",
        pl.Int64(),
        pl.col("a") != 0,
        [
            {"min": 0, "max": 0, "null_count": 6, "len": 7, "can_skip": False},
        ],
    )


def test_datetimes() -> None:
    """Skip-batch predicate for `==` on a timezone-aware Datetime column."""
    d = datetime.datetime(2023, 4, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
    td = datetime.timedelta

    assert_skp_series(
        "a",
        pl.Datetime(time_zone=datetime.timezone.utc),
        pl.col("a") == d,
        [
            # Batch lies entirely before `d`.
            {
                "min": d - td(days=2),
                "max": d - td(days=1),
                "null_count": 0,
                "len": 42,
                "can_skip": True,
            },
            # Batch lies entirely after `d`. The upper bound used to be
            # `d - td(days=2)`, which is below the minimum and cannot occur
            # in real statistics.
            {
                "min": d + td(days=1),
                "max": d + td(days=2),
                "null_count": 0,
                "len": 42,
                "can_skip": True,
            },
            # Every value is null.
            {"min": d, "max": d, "null_count": 42, "len": 42, "can_skip": True},
            # Every value equals `d`.
            {"min": d, "max": d, "null_count": 0, "len": 42, "can_skip": False},
            # The range contains `d`.
            {
                "min": d - td(days=2),
                "max": d + td(days=2),
                "null_count": 0,
                "len": 42,
                "can_skip": False,
            },
            # Only the minimum is known.
            {
                "min": d + td(days=1),
                "max": None,
                "null_count": None,
                "len": None,
                "can_skip": True,
            },
        ],
    )


@given(
    s=series(
        name="x",
        min_size=1,
    ),
)
@settings(
    report_multiple_bugs=False,
    phases=(Phase.explicit, Phase.reuse, Phase.generate, Phase.target, Phase.explain),
)
def test_skip_batch_predicate_parametric(s: pl.Series) -> None:
    """
    Check that a skip-batch predicate never skips a batch with matches.

    For a generated series, several predicates are built from its first (and,
    when available, second) value. Whenever the skip-batch predicate evaluated
    on the statistics of the series reports "can skip", filtering the series
    with the original expression must return no rows.
    """
    name = "x"
    dtype = s.dtype

    value_a = s.slice(0, 1)

    lit_a = pl.lit(value_a[0], dtype)

    exprs = [
        pl.col.x == lit_a,
        pl.col.x != lit_a,
        pl.col.x.eq_missing(lit_a),
        pl.col.x.ne_missing(lit_a),
        pl.col.x.is_null(),
        pl.col.x.is_not_null(),
    ]

    # Best effort: the comparison below acts as a probe. If it, or building
    # any of the expressions after it, raises, the remaining expressions are
    # left out of the test.
    with contextlib.suppress(Exception):
        _ = s > value_a
        exprs += [
            pl.col.x > lit_a,
            pl.col.x >= lit_a,
            pl.col.x < lit_a,
            pl.col.x <= lit_a,
            pl.col.x.is_in(pl.Series([None, value_a[0]], dtype=dtype)),
        ]

        if s.len() > 1:
            value_b = s.slice(1, 1)
            lit_b = pl.lit(value_b[0], dtype)

            exprs += [
                pl.col.x.is_between(lit_a, lit_b),
                pl.col.x.is_in(pl.Series([value_a[0], value_b[0]], dtype=dtype)),
            ]

    # The statistics depend only on `s`, so they are computed once rather than
    # once per expression. A failing min / max falls back to a null statistic.
    mins: list[PythonLiteral | None] = [None]
    with contextlib.suppress(Exception):
        mins = [s.min()]

    maxs: list[PythonLiteral | None] = [None]
    with contextlib.suppress(Exception):
        maxs = [s.max()]

    null_counts = [s.null_count()]
    lengths = [s.len()]

    for expr in exprs:
        sbp = expr._skip_batch_predicate({name: dtype})

        # No skip-batch predicate was produced for this expression.
        if sbp is None:
            continue

        df = pl.DataFrame(
            [
                pl.Series(f"{name}_min", mins, dtype),
                pl.Series(f"{name}_max", maxs, dtype),
                pl.Series(f"{name}_nc", null_counts, get_index_type()),
                pl.Series("len", lengths, get_index_type()),
            ]
        )

        # A null predicate result is treated as "cannot skip".
        can_skip = df.select(can_skip=sbp).fill_null(False).to_series()[0]
        if can_skip:
            try:
                assert s.to_frame().filter(expr).height == 0
            except Exception:
                # Print everything needed to reproduce, then propagate.
                print(expr)
                print(sbp)
                print(df)
                print(s.to_frame().filter(expr))

                raise