# Path: blob/main/py-polars/tests/unit/test_row_encoding.py
# (page header from the code viewer this file was copied from: 6939 views)
"""Round-trip and ordering tests for the private `_row_encode` / `_row_decode` API."""

from __future__ import annotations

from decimal import Decimal as D
from typing import TYPE_CHECKING

import pytest
from hypothesis import given

import polars as pl
import polars.selectors as cs
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric import dataframes, series
from polars.testing.parametric.strategies.dtype import dtypes
from tests.unit.conftest import FLOAT_DTYPES, INTEGER_DTYPES

if TYPE_CHECKING:
    from typing import Any

    from polars._typing import PolarsDataType

# Every (descending, nulls_last) combination with `unordered=False`, plus a
# single unordered entry in which both ordering flags are left unset.
FIELD_COMBS = [
    (descending, nulls_last, False)
    for descending in [False, True]
    for nulls_last in [False, True]
] + [(None, None, True)]

# The same combinations as keyword-argument dicts, ready for `**field`.
FIELD_COMBS_ARGS = [
    {
        "unordered": unordered,
        "descending": descending,
        "nulls_last": nulls_last,
    }
    for descending, nulls_last, unordered in FIELD_COMBS
]


def roundtrip_re(
    df: pl.DataFrame,
    *,
    unordered: bool = False,
    descending: list[bool] | None = None,
    nulls_last: list[bool] | None = None,
) -> None:
    """
    Row-encode `df`, decode it again and assert the result equals `df`.

    `descending` and `nulls_last` are per-column lists (or None) and are passed
    unchanged to both the encode and the decode call. When `unordered` is true
    only the encode step is exercised: the function returns before decoding.
    """
    row_encoded = df._row_encode(
        unordered=unordered,
        descending=descending,
        nulls_last=nulls_last,
    )

    if unordered:
        return

    column_names = df.columns
    # Not called `dtypes`: that would shadow the hypothesis strategy imported above.
    column_dtypes = df.dtypes
    result = row_encoded._row_decode(
        column_names,
        column_dtypes,
        unordered=unordered,
        descending=descending,
        nulls_last=nulls_last,
    ).struct.unnest()

    assert_frame_equal(df, result)


def roundtrip_series_re(
    values: pl.series.series.ArrayLike,
    dtype: PolarsDataType,
    *,
    unordered: bool = False,
    descending: bool | None = None,
    nulls_last: bool | None = False,
) -> None:
    """
    Round-trip a single column built from `values` with the given `dtype`.

    The scalar `descending` / `nulls_last` flags are wrapped into one-element
    lists (None stays None) and forwarded to `roundtrip_re`.
    """
    descending_lst = None if descending is None else [descending]
    nulls_last_lst = None if nulls_last is None else [nulls_last]

    roundtrip_re(
        pl.Series("series", values, dtype).to_frame(),
        unordered=unordered,
        descending=descending_lst,
        nulls_last=nulls_last_lst,
    )


@given(
    df=dataframes(
        excluded_dtypes=[
            pl.Categorical,
            pl.Decimal,  # Bug: see https://github.com/pola-rs/polars/issues/20308
        ]
    )
)
@pytest.mark.parametrize(("descending", "nulls_last", "unordered"), FIELD_COMBS)
def test_row_encoding_parametric(
    df: pl.DataFrame,
    unordered: bool,
    descending: bool | None,
    nulls_last: bool | None,
) -> None:
    """Round-trip generated frames, applying the same flags to every column."""
    roundtrip_re(
        df,
        unordered=unordered,
        descending=None if descending is None else [descending] * df.width,
        nulls_last=None if nulls_last is None else [nulls_last] * df.width,
    )


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_nulls(field: Any) -> None:
    """Round-trip `Null` columns of several lengths, including empty."""
    roundtrip_series_re([], pl.Null, **field)
    roundtrip_series_re([None], pl.Null, **field)
    roundtrip_series_re([None] * 2, pl.Null, **field)
    roundtrip_series_re([None] * 13, pl.Null, **field)
    roundtrip_series_re([None] * 42, pl.Null, **field)


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_bool(field: Any) -> None:
    """Round-trip boolean columns."""
    roundtrip_series_re([], pl.Boolean, **field)
    roundtrip_series_re([False], pl.Boolean, **field)
    roundtrip_series_re([True], pl.Boolean, **field)
    roundtrip_series_re([False, True], pl.Boolean, **field)
    roundtrip_series_re([True, False], pl.Boolean, **field)


@pytest.mark.parametrize("dtype", INTEGER_DTYPES)
@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_int(dtype: pl.DataType, field: Any) -> None:
    """Round-trip integer columns, including each dtype's min and max value."""
    # Named `min_value` / `max_value` so the `min` / `max` builtins stay usable.
    min_value = pl.select(x=dtype.min()).item()  # type: ignore[attr-defined]
    max_value = pl.select(x=dtype.max()).item()  # type: ignore[attr-defined]

    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([0], dtype, **field)
    roundtrip_series_re([min_value], dtype, **field)
    roundtrip_series_re([max_value], dtype, **field)

    roundtrip_series_re([1, 2, 3], dtype, **field)
    roundtrip_series_re([0, 1, 2, 3], dtype, **field)
    roundtrip_series_re([min_value, 0, max_value], dtype, **field)


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_float(dtype: pl.DataType, field: Any) -> None:
    """Round-trip float columns, including positive and negative infinity."""
    inf = float("inf")
    # Pass the negative value as-is: negating it again would yield +inf and
    # leave negative infinity untested.
    neg_inf = float("-inf")

    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([0.0], dtype, **field)
    roundtrip_series_re([inf], dtype, **field)
    roundtrip_series_re([neg_inf], dtype, **field)

    roundtrip_series_re([1.0, 2.0, 3.0], dtype, **field)
    roundtrip_series_re([0.0, 1.0, 2.0, 3.0], dtype, **field)
    roundtrip_series_re([inf, 0, neg_inf], dtype, **field)


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_str(field: Any) -> None:
    """Round-trip string columns, including empty strings and mixed lengths."""
    dtype = pl.String
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([""], dtype, **field)

    roundtrip_series_re(["a", "b", "c"], dtype, **field)
    roundtrip_series_re(["", "a", "b", "c"], dtype, **field)

    roundtrip_series_re(
        ["different", "length", "strings"],
        dtype,
        **field,
    )
    roundtrip_series_re(
        ["different", "", "length", "", "strings"],
        dtype,
        **field,
    )


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_struct(field: Any) -> None:
    """Round-trip struct columns with zero, one and two fields."""
    dtype = pl.Struct({})
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([None], dtype, **field)
    roundtrip_series_re([{}], dtype, **field)
    roundtrip_series_re([{}, {}, {}], dtype, **field)
    roundtrip_series_re([{}, None, {}], dtype, **field)

    dtype = pl.Struct({"x": pl.Int32})
    roundtrip_series_re([{"x": 1}], dtype, **field)
    roundtrip_series_re([None], dtype, **field)
    roundtrip_series_re([{"x": 1}] * 3, dtype, **field)
    roundtrip_series_re([{"x": 1}, {"x": None}, None], dtype, **field)

    dtype = pl.Struct({"x": pl.Int32, "y": pl.Int32})
    roundtrip_series_re(
        [{"x": 1}, {"y": 2}],
        dtype,
        **field,
    )
    roundtrip_series_re([None], dtype, **field)


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_list(field: Any) -> None:
    """Round-trip list columns of integers and of strings."""
    dtype = pl.List(pl.Int32)
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([[]], dtype, **field)
    roundtrip_series_re([[1], [2]], dtype, **field)
    roundtrip_series_re([[1, 2], [3]], dtype, **field)
    roundtrip_series_re([[1, 2], [], [3]], dtype, **field)
    roundtrip_series_re([None, [1, 2], None, [], [3]], dtype, **field)

    dtype = pl.List(pl.String)
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([[]], dtype, **field)
    roundtrip_series_re([[""], [""]], dtype, **field)
    roundtrip_series_re([["abc"], ["xyzw"]], dtype, **field)
    roundtrip_series_re([["x", "yx"], ["abc"]], dtype, **field)
    roundtrip_series_re([["wow", "this is"], [], ["cool"]], dtype, **field)
    roundtrip_series_re(
        [None, ["very", "very"], None, [], ["cool"]],
        dtype,
        **field,
    )


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_array(field: Any) -> None:
    """Round-trip fixed-size array columns, including width zero."""
    dtype = pl.Array(pl.Int32, 0)
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([[]], dtype, **field)
    roundtrip_series_re([None, [], None], dtype, **field)
    roundtrip_series_re([None], dtype, **field)

    dtype = pl.Array(pl.Int32, 2)
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([[5, 6]], dtype, **field)
    roundtrip_series_re([[1, 2], [2, 3]], dtype, **field)
    roundtrip_series_re([[1, 2], [3, 7]], dtype, **field)
    roundtrip_series_re([[1, 2], [13, 11], [3, 7]], dtype, **field)
    roundtrip_series_re(
        [None, [1, 2], None, [13, 11], [5, 7]],
        dtype,
        **field,
    )

    dtype = pl.Array(pl.String, 2)
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([["a", "b"]], dtype, **field)
    roundtrip_series_re([["", ""], ["", "a"]], dtype, **field)
    roundtrip_series_re([["abc", "def"], ["ghi", "xyzw"]], dtype, **field)
    roundtrip_series_re([["x", "yx"], ["abc", "xxx"]], dtype, **field)
    roundtrip_series_re(
        [["wow", "this is"], ["soo", "so"], ["veryyy", "cool"]],
        dtype,
        **field,
    )
    roundtrip_series_re(
        [None, ["very", "very"], None, [None, None], ["verryy", "cool"]],
        dtype,
        **field,
    )


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
@pytest.mark.parametrize("precision", range(1, 38))
def test_decimal(field: Any, precision: int) -> None:
    """Round-trip scale-0 decimals, including the largest magnitudes per precision."""
    dtype = pl.Decimal(precision=precision, scale=0)
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([None], dtype, **field)
    roundtrip_series_re([D("1")], dtype, **field)
    roundtrip_series_re([D("-1")], dtype, **field)
    roundtrip_series_re([D("9" * precision)], dtype, **field)
    roundtrip_series_re([D("-" + "9" * precision)], dtype, **field)
    roundtrip_series_re([None, D("-1"), None], dtype, **field)
    roundtrip_series_re([D("-1"), D("0"), D("1")], dtype, **field)


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_enum(field: Any) -> None:
    """Round-trip enum columns with no categories and with three categories."""
    dtype = pl.Enum([])

    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([None], dtype, **field)
    roundtrip_series_re([None, None], dtype, **field)

    dtype = pl.Enum(["a", "x", "b"])

    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([None], dtype, **field)
    roundtrip_series_re(["a"], dtype, **field)
    roundtrip_series_re(["x"], dtype, **field)
    roundtrip_series_re(["b"], dtype, **field)
    roundtrip_series_re(["b", "x", "a"], dtype, **field)
    roundtrip_series_re([None, "b", None], dtype, **field)
    roundtrip_series_re([None, "a", None], dtype, **field)


@pytest.mark.parametrize("size", [127, 128, 255, 256, 2**15, 2**15 + 1])
@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
@pytest.mark.slow
def test_large_enum(size: int, field: Any) -> None:
    """Round-trip enums whose category counts sit around powers of two."""
    dtype = pl.Enum([str(i) for i in range(size)])
    roundtrip_series_re([None, "1"], dtype, **field)
    roundtrip_series_re(["1", None], dtype, **field)

    # A strided sample of the categories rather than all of them.
    roundtrip_series_re(
        [str(i) for i in range(3, size, int(7 * size / (2**8)))], dtype, **field
    )


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_list_arr(field: Any) -> None:
    """Round-trip lists of fixed-size string arrays."""
    dtype = pl.List(pl.Array(pl.String, 2))
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([None], dtype, **field)
    roundtrip_series_re([[None]], dtype, **field)
    roundtrip_series_re([[[None, None]]], dtype, **field)
    roundtrip_series_re([[["a", "b"]]], dtype, **field)
    roundtrip_series_re([[["a", "b"], ["xyz", "wowie"]]], dtype, **field)
    roundtrip_series_re([[["a", "b"]], None, [None, None]], dtype, **field)


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_list_struct_arr(field: Any) -> None:
    """Round-trip lists of structs whose fields are fixed-size arrays."""
    dtype = pl.List(
        pl.Struct({"x": pl.Array(pl.String, 2), "y": pl.Array(pl.Int64, 3)})
    )
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([None], dtype, **field)
    roundtrip_series_re([[None]], dtype, **field)
    roundtrip_series_re([[{"x": None, "y": None}]], dtype, **field)
    roundtrip_series_re([[{"x": ["a", None], "y": [1, None, 3]}]], dtype, **field)
    roundtrip_series_re([[{"x": ["a", "xyz"], "y": [1, 7, 3]}]], dtype, **field)
    roundtrip_series_re([[{"x": ["a", "xyz"], "y": [1, 7, 3]}], []], dtype, **field)


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_list_nulls(field: Any) -> None:
    """Round-trip lists whose inner dtype is `Null`."""
    dtype = pl.List(pl.Null)
    roundtrip_series_re([], dtype, **field)
    roundtrip_series_re([[]], dtype, **field)
    roundtrip_series_re([None], dtype, **field)
    roundtrip_series_re([[None]], dtype, **field)
    roundtrip_series_re([[None, None, None]], dtype, **field)
    roundtrip_series_re([[None], [None, None], [None, None, None]], dtype, **field)


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
def test_masked_out_list_20151(field: Any) -> None:
    """Round-trip a list column obtained by casting an array column with a null row."""
    dtype = pl.List(pl.Int64())

    values = [[1, 2], None, [4, 5], [None, 3]]

    array_series = pl.Series(values, dtype=pl.Array(pl.Int64(), 2))
    list_from_array_series = array_series.cast(dtype)

    roundtrip_series_re(list_from_array_series, dtype, **field)


def test_int_after_null() -> None:
    """Round-trip a `Null` column followed by an all-null `Int8` column."""
    roundtrip_re(
        pl.DataFrame(
            [
                pl.Series("a", [None], pl.Null),
                pl.Series("b", [None], pl.Int8),
            ]
        ),
        nulls_last=[True, True],
    )


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
@given(s=series(allow_null=False, allow_chunks=False, excluded_dtypes=[pl.Categorical]))
def test_optional_eq_non_optional_20320(field: Any, s: pl.Series) -> None:
    """Appending a null row must not change the encoding of the preceding rows."""
    n_rows = s.len()
    # `Series.extend` mutates its receiver in place and returns it. Extend a
    # clone so `s` keeps its null-free contents; otherwise both encodings below
    # would come from the very same series and the comparison would be vacuous.
    with_null = s.clone().extend(pl.Series([None], dtype=s.dtype))

    re_without_null = s._row_encode(**field)
    re_with_null = with_null._row_encode(**field)

    re_without_null = re_without_null.cast(pl.Binary)
    re_with_null = re_with_null.cast(pl.Binary)

    assert_series_equal(re_with_null.head(n_rows), re_without_null)


@pytest.mark.parametrize("field", FIELD_COMBS_ARGS)
@given(dtype=dtypes(excluded_dtypes=[pl.Categorical]))
def test_null(
    field: Any,
    dtype: pl.DataType,
) -> None:
    """A single null of any generated dtype survives encode followed by decode."""
    s = pl.Series("a", [None], dtype)

    assert_series_equal(
        s._row_encode(**field)
        ._row_decode(
            ["a"],
            [dtype],
            descending=None if field["descending"] is None else [field["descending"]],
            nulls_last=None if field["nulls_last"] is None else [field["nulls_last"]],
            unordered=field["unordered"],
        )
        .struct.unnest()
        .to_series(),
        s,
    )


@pytest.mark.parametrize(
    ("dtype", "vs"),
    [
        (pl.List(pl.String), [[None], ["A"], ["B"]]),
        (pl.Array(pl.String, 1), [[None], ["A"], ["B"]]),
        (pl.Struct({"x": pl.String}), [{"x": None}, {"x": "A"}, {"x": "B"}]),
        (pl.Array(pl.String, 2), [[None, "Z"], ["A", "C"], ["B", "B"]]),
    ],
)
def test_nested_sorting_22557(dtype: pl.DataType, vs: list[Any]) -> None:
    """
    Check sort order of nested values against the order of their encodings.

    `vs` lists the values in their expected ascending order; `s` holds them
    shuffled together with one top-level null.
    """
    s = pl.Series("a", [vs[1], None, vs[0], vs[2]], dtype)

    assert_series_equal(
        s.sort(descending=False, nulls_last=False), pl.Series("a", [None] + vs, dtype)
    )
    assert_series_equal(
        s.sort(descending=False, nulls_last=True), pl.Series("a", vs + [None], dtype)
    )
    assert_series_equal(
        s.sort(descending=True, nulls_last=False),
        pl.Series("a", [None] + vs[::-1], dtype),
    )
    assert_series_equal(
        s.sort(descending=True, nulls_last=True),
        pl.Series("a", vs[::-1] + [None], dtype),
    )

    roundtrip_series_re(vs, dtype, descending=False, nulls_last=False)
    roundtrip_series_re(vs, dtype, descending=False, nulls_last=True)
    roundtrip_series_re(vs, dtype, descending=True, nulls_last=False)
    roundtrip_series_re(vs, dtype, descending=True, nulls_last=True)

    # Arg-sorting the encoded rows must give the same permutation as the
    # corresponding `sort` call above.
    assert_series_equal(
        s._row_encode(descending=False, nulls_last=False).arg_sort(),
        pl.Series("a", [1, 2, 0, 3], pl.get_index_type()),
        check_names=False,
    )
    assert_series_equal(
        s._row_encode(descending=False, nulls_last=True).arg_sort(),
        pl.Series("a", [2, 0, 3, 1], pl.get_index_type()),
        check_names=False,
    )
    assert_series_equal(
        s._row_encode(descending=True, nulls_last=False).arg_sort(),
        pl.Series("a", [1, 3, 0, 2], pl.get_index_type()),
        check_names=False,
    )
    assert_series_equal(
        s._row_encode(descending=True, nulls_last=True).arg_sort(),
        pl.Series("a", [3, 0, 2, 1], pl.get_index_type()),
        check_names=False,
    )


def test_row_encoding_null_chunks() -> None:
    """Encode and decode a concatenation of an Int64 frame and an all-null frame."""
    lf1 = pl.select(a=pl.lit(1, pl.Int64)).lazy()
    lf2 = pl.select(a=None).lazy()

    lf = pl.concat([lf1, lf2]).select(pl.col.a._row_encode())

    out = (
        lf.select(cs.all()._row_decode(["a"], [pl.Int64]))
        .unnest(cs.all())
        .collect(engine="streaming")
    )

    assert_frame_equal(
        pl.concat([lf1, lf2]).collect(),
        out,
    )